mirror of
https://github.com/schollz/croc.git
synced 2024-11-24 08:02:33 +03:00
remove globals
This commit is contained in:
parent
b2dc1f32f8
commit
ffebb472b9
12
src/api.go
12
src/api.go
@ -10,6 +10,10 @@ type Croc struct {
|
||||
UseCompression bool
|
||||
CurveType string
|
||||
AllowLocalDiscovery bool
|
||||
|
||||
// private variables
|
||||
// rs relay state is only for the relay
|
||||
rs relayState
|
||||
}
|
||||
|
||||
// Init will initialize the croc relay
|
||||
@ -27,11 +31,15 @@ func Init() (c *Croc) {
|
||||
|
||||
// Relay initiates a relay
|
||||
func (c *Croc) Relay() error {
|
||||
c.rs.Lock()
|
||||
c.rs.channel = make(map[string]*channelData)
|
||||
c.rs.Unlock()
|
||||
|
||||
// start relay
|
||||
go startRelay(c.TcpPorts)
|
||||
go c.startRelay(c.TcpPorts)
|
||||
|
||||
// start server
|
||||
return startServer(c.TcpPorts, c.ServerPort)
|
||||
return c.startServer(c.TcpPorts, c.ServerPort)
|
||||
}
|
||||
|
||||
// Send will take an existing file or folder and send it through the croc relay
|
||||
|
@ -3,6 +3,7 @@ package croc
|
||||
import (
|
||||
"crypto/elliptic"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
@ -16,6 +17,11 @@ var (
|
||||
availableStates = []string{"curve", "h_k", "hh_k", "x", "y"}
|
||||
)
|
||||
|
||||
type relayState struct {
|
||||
channel map[string]*channelData
|
||||
sync.RWMutex
|
||||
}
|
||||
|
||||
type channelData struct {
|
||||
// Public
|
||||
// Name is the name of the channel
|
||||
|
38
src/relay.go
38
src/relay.go
@ -10,14 +10,14 @@ import (
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
func startRelay(ports []string) {
|
||||
func (c *Croc) startRelay(ports []string) {
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(len(ports))
|
||||
for _, port := range ports {
|
||||
go func(port string, wg *sync.WaitGroup) {
|
||||
defer wg.Done()
|
||||
log.Debugf("listening on port %s", port)
|
||||
if err := listener(port); err != nil {
|
||||
if err := c.listener(port); err != nil {
|
||||
log.Error(err)
|
||||
return
|
||||
}
|
||||
@ -26,7 +26,7 @@ func startRelay(ports []string) {
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func listener(port string) (err error) {
|
||||
func (c *Croc) listener(port string) (err error) {
|
||||
server, err := net.Listen("tcp", "0.0.0.0:"+port)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "Error listening on :"+port)
|
||||
@ -40,7 +40,7 @@ func listener(port string) (err error) {
|
||||
}
|
||||
log.Debugf("client %s connected", connection.RemoteAddr().String())
|
||||
go func(port string, connection net.Conn) {
|
||||
errCommunication := clientCommuncation(port, connection)
|
||||
errCommunication := c.clientCommuncation(port, connection)
|
||||
if errCommunication != nil {
|
||||
log.Warnf("relay-%s: %s", connection.RemoteAddr().String(), errCommunication.Error())
|
||||
}
|
||||
@ -48,7 +48,7 @@ func listener(port string) (err error) {
|
||||
}
|
||||
}
|
||||
|
||||
func clientCommuncation(port string, connection net.Conn) (err error) {
|
||||
func (c *Croc) clientCommuncation(port string, connection net.Conn) (err error) {
|
||||
var con1, con2 net.Conn
|
||||
|
||||
// get the channel and UUID from the client
|
||||
@ -67,27 +67,27 @@ func clientCommuncation(port string, connection net.Conn) (err error) {
|
||||
log.Debugf("%s connected with channel %s and uuid %s", connection.RemoteAddr().String(), channel, uuid)
|
||||
|
||||
// validate channel and UUID
|
||||
rs.Lock()
|
||||
if _, ok := rs.channel[channel]; !ok {
|
||||
rs.Unlock()
|
||||
c.rs.Lock()
|
||||
if _, ok := c.rs.channel[channel]; !ok {
|
||||
c.rs.Unlock()
|
||||
err = errors.Errorf("channel %s does not exist", channel)
|
||||
return
|
||||
}
|
||||
if uuid != rs.channel[channel].uuids[0] &&
|
||||
uuid != rs.channel[channel].uuids[1] {
|
||||
rs.Unlock()
|
||||
if uuid != c.rs.channel[channel].uuids[0] &&
|
||||
uuid != c.rs.channel[channel].uuids[1] {
|
||||
c.rs.Unlock()
|
||||
err = errors.Errorf("uuid '%s' is invalid", uuid)
|
||||
return
|
||||
}
|
||||
role := 0
|
||||
if uuid == rs.channel[channel].uuids[1] {
|
||||
if uuid == c.rs.channel[channel].uuids[1] {
|
||||
role = 1
|
||||
}
|
||||
rs.channel[channel].connection[role] = connection
|
||||
c.rs.channel[channel].connection[role] = connection
|
||||
|
||||
con1 = rs.channel[channel].connection[0]
|
||||
con2 = rs.channel[channel].connection[1]
|
||||
rs.Unlock()
|
||||
con1 = c.rs.channel[channel].connection[0]
|
||||
con2 = c.rs.channel[channel].connection[1]
|
||||
c.rs.Unlock()
|
||||
|
||||
if con1 != nil && con2 != nil {
|
||||
var wg sync.WaitGroup
|
||||
@ -100,9 +100,9 @@ func clientCommuncation(port string, connection net.Conn) (err error) {
|
||||
// then set transfer ready
|
||||
go func(channel string, wg *sync.WaitGroup) {
|
||||
// set the channels to ready
|
||||
rs.Lock()
|
||||
rs.channel[channel].TransferReady = true
|
||||
rs.Unlock()
|
||||
c.rs.Lock()
|
||||
c.rs.channel[channel].TransferReady = true
|
||||
c.rs.Unlock()
|
||||
wg.Done()
|
||||
}(channel, &wg)
|
||||
wg.Wait()
|
||||
|
112
src/server.go
112
src/server.go
@ -4,7 +4,6 @@ import (
|
||||
"crypto/elliptic"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/cihub/seelog"
|
||||
@ -13,34 +12,21 @@ import (
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type relayState struct {
|
||||
channel map[string]*channelData
|
||||
sync.RWMutex
|
||||
}
|
||||
|
||||
var rs relayState
|
||||
|
||||
func init() {
|
||||
rs.Lock()
|
||||
rs.channel = make(map[string]*channelData)
|
||||
rs.Unlock()
|
||||
}
|
||||
|
||||
func startServer(tcpPorts []string, port string) (err error) {
|
||||
func (c *Croc) startServer(tcpPorts []string, port string) (err error) {
|
||||
// start cleanup on dangling channels
|
||||
go channelCleanup()
|
||||
go c.channelCleanup()
|
||||
|
||||
// start server
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
r := gin.New()
|
||||
r.Use(middleWareHandler(), gin.Recovery())
|
||||
r.POST("/channel", func(c *gin.Context) {
|
||||
r, err := func(c *gin.Context) (r response, err error) {
|
||||
rs.Lock()
|
||||
defer rs.Unlock()
|
||||
r.POST("/channel", func(cg *gin.Context) {
|
||||
r, err := func(cg *gin.Context) (r response, err error) {
|
||||
c.rs.Lock()
|
||||
defer c.rs.Unlock()
|
||||
r.Success = true
|
||||
var p payloadChannel
|
||||
err = c.ShouldBindJSON(&p)
|
||||
err = cg.ShouldBindJSON(&p)
|
||||
if err != nil {
|
||||
log.Errorf("failed on payload %+v", p)
|
||||
err = errors.Wrap(err, "problem parsing /channel")
|
||||
@ -48,21 +34,21 @@ func startServer(tcpPorts []string, port string) (err error) {
|
||||
}
|
||||
|
||||
// determine if channel is invalid
|
||||
if _, ok := rs.channel[p.Channel]; !ok {
|
||||
if _, ok := c.rs.channel[p.Channel]; !ok {
|
||||
err = errors.Errorf("channel '%s' does not exist", p.Channel)
|
||||
return
|
||||
}
|
||||
|
||||
// determine if UUID is invalid for channel
|
||||
if p.UUID != rs.channel[p.Channel].uuids[0] &&
|
||||
p.UUID != rs.channel[p.Channel].uuids[1] {
|
||||
if p.UUID != c.rs.channel[p.Channel].uuids[0] &&
|
||||
p.UUID != c.rs.channel[p.Channel].uuids[1] {
|
||||
err = errors.Errorf("uuid '%s' is invalid", p.UUID)
|
||||
return
|
||||
}
|
||||
|
||||
// check if the action is to close the channel
|
||||
if p.Close {
|
||||
delete(rs.channel, p.Channel)
|
||||
delete(c.rs.channel, p.Channel)
|
||||
r.Message = "deleted " + p.Channel
|
||||
return
|
||||
}
|
||||
@ -74,34 +60,34 @@ func startServer(tcpPorts []string, port string) (err error) {
|
||||
// add a check that the value of key is not enormous
|
||||
|
||||
// add only if it is a valid key
|
||||
if _, ok := rs.channel[p.Channel].State[key]; ok {
|
||||
if _, ok := c.rs.channel[p.Channel].State[key]; ok {
|
||||
assignedKeys = append(assignedKeys, key)
|
||||
rs.channel[p.Channel].State[key] = p.State[key]
|
||||
c.rs.channel[p.Channel].State[key] = p.State[key]
|
||||
}
|
||||
}
|
||||
|
||||
// return the current state
|
||||
r.Data = rs.channel[p.Channel]
|
||||
r.Data = c.rs.channel[p.Channel]
|
||||
|
||||
r.Message = fmt.Sprintf("assigned %d keys: %v", len(assignedKeys), assignedKeys)
|
||||
return
|
||||
}(c)
|
||||
}(cg)
|
||||
if err != nil {
|
||||
log.Debugf("bad /channel: %s", err.Error())
|
||||
r.Message = err.Error()
|
||||
r.Success = false
|
||||
}
|
||||
bR, _ := json.Marshal(r)
|
||||
c.Data(200, "application/json", bR)
|
||||
cg.Data(200, "application/json", bR)
|
||||
})
|
||||
r.POST("/join", func(c *gin.Context) {
|
||||
r, err := func(c *gin.Context) (r response, err error) {
|
||||
rs.Lock()
|
||||
defer rs.Unlock()
|
||||
r.POST("/join", func(cg *gin.Context) {
|
||||
r, err := func(cg *gin.Context) (r response, err error) {
|
||||
c.rs.Lock()
|
||||
defer c.rs.Unlock()
|
||||
r.Success = true
|
||||
|
||||
var p payloadOpen
|
||||
err = c.ShouldBindJSON(&p)
|
||||
err = cg.ShouldBindJSON(&p)
|
||||
if err != nil {
|
||||
log.Errorf("failed on payload %+v", p)
|
||||
err = errors.Wrap(err, "problem parsing")
|
||||
@ -120,57 +106,57 @@ func startServer(tcpPorts []string, port string) (err error) {
|
||||
// find an empty channel
|
||||
p.Channel = "chou"
|
||||
}
|
||||
if _, ok := rs.channel[p.Channel]; ok {
|
||||
if _, ok := c.rs.channel[p.Channel]; ok {
|
||||
// channel is not empty
|
||||
if rs.channel[p.Channel].uuids[p.Role] != "" {
|
||||
if c.rs.channel[p.Channel].uuids[p.Role] != "" {
|
||||
err = errors.Errorf("channel '%s' already occupied by role %d", p.Channel, p.Role)
|
||||
return
|
||||
}
|
||||
}
|
||||
r.Channel = p.Channel
|
||||
if _, ok := rs.channel[r.Channel]; !ok {
|
||||
rs.channel[r.Channel] = newChannelData(r.Channel)
|
||||
if _, ok := c.rs.channel[r.Channel]; !ok {
|
||||
c.rs.channel[r.Channel] = newChannelData(r.Channel)
|
||||
}
|
||||
|
||||
// assign UUID for the role in the channel
|
||||
rs.channel[r.Channel].uuids[p.Role] = uuid4.New().String()
|
||||
r.UUID = rs.channel[r.Channel].uuids[p.Role]
|
||||
c.rs.channel[r.Channel].uuids[p.Role] = uuid4.New().String()
|
||||
r.UUID = c.rs.channel[r.Channel].uuids[p.Role]
|
||||
log.Debugf("(%s) %s has joined as role %d", r.Channel, r.UUID, p.Role)
|
||||
|
||||
// if channel is not open, set initial parameters
|
||||
if !rs.channel[r.Channel].isopen {
|
||||
rs.channel[r.Channel].isopen = true
|
||||
rs.channel[r.Channel].Ports = tcpPorts
|
||||
rs.channel[r.Channel].startTime = time.Now()
|
||||
if !c.rs.channel[r.Channel].isopen {
|
||||
c.rs.channel[r.Channel].isopen = true
|
||||
c.rs.channel[r.Channel].Ports = tcpPorts
|
||||
c.rs.channel[r.Channel].startTime = time.Now()
|
||||
switch curve := p.Curve; curve {
|
||||
case "p224":
|
||||
rs.channel[r.Channel].curve = elliptic.P224()
|
||||
c.rs.channel[r.Channel].curve = elliptic.P224()
|
||||
case "p256":
|
||||
rs.channel[r.Channel].curve = elliptic.P256()
|
||||
c.rs.channel[r.Channel].curve = elliptic.P256()
|
||||
case "p384":
|
||||
rs.channel[r.Channel].curve = elliptic.P384()
|
||||
c.rs.channel[r.Channel].curve = elliptic.P384()
|
||||
case "p521":
|
||||
rs.channel[r.Channel].curve = elliptic.P521()
|
||||
c.rs.channel[r.Channel].curve = elliptic.P521()
|
||||
default:
|
||||
// TODO:
|
||||
// add SIEC
|
||||
p.Curve = "p256"
|
||||
rs.channel[r.Channel].curve = elliptic.P256()
|
||||
c.rs.channel[r.Channel].curve = elliptic.P256()
|
||||
}
|
||||
log.Debugf("(%s) using curve '%s'", r.Channel, p.Curve)
|
||||
rs.channel[r.Channel].State["curve"] = []byte(p.Curve)
|
||||
c.rs.channel[r.Channel].State["curve"] = []byte(p.Curve)
|
||||
}
|
||||
|
||||
r.Message = fmt.Sprintf("assigned role %d in channel '%s'", p.Role, r.Channel)
|
||||
return
|
||||
}(c)
|
||||
}(cg)
|
||||
if err != nil {
|
||||
log.Debugf("bad /join: %s", err.Error())
|
||||
r.Message = err.Error()
|
||||
r.Success = false
|
||||
}
|
||||
bR, _ := json.Marshal(r)
|
||||
c.Data(200, "application/json", bR)
|
||||
cg.Data(200, "application/json", bR)
|
||||
})
|
||||
log.Infof("Running at http://0.0.0.0:" + port)
|
||||
err = r.Run(":" + port)
|
||||
@ -178,32 +164,32 @@ func startServer(tcpPorts []string, port string) (err error) {
|
||||
}
|
||||
|
||||
func middleWareHandler() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
return func(cg *gin.Context) {
|
||||
t := time.Now()
|
||||
// Run next function
|
||||
c.Next()
|
||||
cg.Next()
|
||||
// Log request
|
||||
log.Infof("%v %v %v %s", c.Request.RemoteAddr, c.Request.Method, c.Request.URL, time.Since(t))
|
||||
log.Infof("%v %v %v %s", cg.Request.RemoteAddr, cg.Request.Method, cg.Request.URL, time.Since(t))
|
||||
}
|
||||
}
|
||||
|
||||
func channelCleanup() {
|
||||
func (c *Croc) channelCleanup() {
|
||||
maximumWait := 10 * time.Minute
|
||||
for {
|
||||
rs.Lock()
|
||||
keys := make([]string, len(rs.channel))
|
||||
c.rs.Lock()
|
||||
keys := make([]string, len(c.rs.channel))
|
||||
i := 0
|
||||
for key := range rs.channel {
|
||||
for key := range c.rs.channel {
|
||||
keys[i] = key
|
||||
i++
|
||||
}
|
||||
for _, key := range keys {
|
||||
if time.Since(rs.channel[key].startTime) > maximumWait {
|
||||
if time.Since(c.rs.channel[key].startTime) > maximumWait {
|
||||
log.Debugf("channel %s has exceeded time, deleting", key)
|
||||
delete(rs.channel, key)
|
||||
delete(c.rs.channel, key)
|
||||
}
|
||||
}
|
||||
rs.Unlock()
|
||||
c.rs.Unlock()
|
||||
time.Sleep(1 * time.Minute)
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user