1
0
mirror of https://github.com/schollz/croc.git synced 2024-11-24 08:02:33 +03:00

remove globals

This commit is contained in:
Zack Scholl 2018-06-28 10:12:04 -07:00
parent b2dc1f32f8
commit ffebb472b9
4 changed files with 84 additions and 84 deletions

View File

@ -10,6 +10,10 @@ type Croc struct {
UseCompression bool
CurveType string
AllowLocalDiscovery bool
// private variables
// rs relay state is only for the relay
rs relayState
}
// Init will initialize the croc relay
@ -27,11 +31,15 @@ func Init() (c *Croc) {
// Relay initiates a relay
func (c *Croc) Relay() error {
c.rs.Lock()
c.rs.channel = make(map[string]*channelData)
c.rs.Unlock()
// start relay
go startRelay(c.TcpPorts)
go c.startRelay(c.TcpPorts)
// start server
return startServer(c.TcpPorts, c.ServerPort)
return c.startServer(c.TcpPorts, c.ServerPort)
}
// Send will take an existing file or folder and send it through the croc relay

View File

@ -3,6 +3,7 @@ package croc
import (
"crypto/elliptic"
"net"
"sync"
"time"
)
@ -16,6 +17,11 @@ var (
availableStates = []string{"curve", "h_k", "hh_k", "x", "y"}
)
type relayState struct {
channel map[string]*channelData
sync.RWMutex
}
type channelData struct {
// Public
// Name is the name of the channel

View File

@ -10,14 +10,14 @@ import (
"github.com/pkg/errors"
)
func startRelay(ports []string) {
func (c *Croc) startRelay(ports []string) {
var wg sync.WaitGroup
wg.Add(len(ports))
for _, port := range ports {
go func(port string, wg *sync.WaitGroup) {
defer wg.Done()
log.Debugf("listening on port %s", port)
if err := listener(port); err != nil {
if err := c.listener(port); err != nil {
log.Error(err)
return
}
@ -26,7 +26,7 @@ func startRelay(ports []string) {
wg.Wait()
}
func listener(port string) (err error) {
func (c *Croc) listener(port string) (err error) {
server, err := net.Listen("tcp", "0.0.0.0:"+port)
if err != nil {
return errors.Wrap(err, "Error listening on :"+port)
@ -40,7 +40,7 @@ func listener(port string) (err error) {
}
log.Debugf("client %s connected", connection.RemoteAddr().String())
go func(port string, connection net.Conn) {
errCommunication := clientCommuncation(port, connection)
errCommunication := c.clientCommuncation(port, connection)
if errCommunication != nil {
log.Warnf("relay-%s: %s", connection.RemoteAddr().String(), errCommunication.Error())
}
@ -48,7 +48,7 @@ func listener(port string) (err error) {
}
}
func clientCommuncation(port string, connection net.Conn) (err error) {
func (c *Croc) clientCommuncation(port string, connection net.Conn) (err error) {
var con1, con2 net.Conn
// get the channel and UUID from the client
@ -67,27 +67,27 @@ func clientCommuncation(port string, connection net.Conn) (err error) {
log.Debugf("%s connected with channel %s and uuid %s", connection.RemoteAddr().String(), channel, uuid)
// validate channel and UUID
rs.Lock()
if _, ok := rs.channel[channel]; !ok {
rs.Unlock()
c.rs.Lock()
if _, ok := c.rs.channel[channel]; !ok {
c.rs.Unlock()
err = errors.Errorf("channel %s does not exist", channel)
return
}
if uuid != rs.channel[channel].uuids[0] &&
uuid != rs.channel[channel].uuids[1] {
rs.Unlock()
if uuid != c.rs.channel[channel].uuids[0] &&
uuid != c.rs.channel[channel].uuids[1] {
c.rs.Unlock()
err = errors.Errorf("uuid '%s' is invalid", uuid)
return
}
role := 0
if uuid == rs.channel[channel].uuids[1] {
if uuid == c.rs.channel[channel].uuids[1] {
role = 1
}
rs.channel[channel].connection[role] = connection
c.rs.channel[channel].connection[role] = connection
con1 = rs.channel[channel].connection[0]
con2 = rs.channel[channel].connection[1]
rs.Unlock()
con1 = c.rs.channel[channel].connection[0]
con2 = c.rs.channel[channel].connection[1]
c.rs.Unlock()
if con1 != nil && con2 != nil {
var wg sync.WaitGroup
@ -100,9 +100,9 @@ func clientCommuncation(port string, connection net.Conn) (err error) {
// then set transfer ready
go func(channel string, wg *sync.WaitGroup) {
// set the channels to ready
rs.Lock()
rs.channel[channel].TransferReady = true
rs.Unlock()
c.rs.Lock()
c.rs.channel[channel].TransferReady = true
c.rs.Unlock()
wg.Done()
}(channel, &wg)
wg.Wait()

View File

@ -4,7 +4,6 @@ import (
"crypto/elliptic"
"encoding/json"
"fmt"
"sync"
"time"
log "github.com/cihub/seelog"
@ -13,34 +12,21 @@ import (
"github.com/pkg/errors"
)
type relayState struct {
channel map[string]*channelData
sync.RWMutex
}
var rs relayState
func init() {
rs.Lock()
rs.channel = make(map[string]*channelData)
rs.Unlock()
}
func startServer(tcpPorts []string, port string) (err error) {
func (c *Croc) startServer(tcpPorts []string, port string) (err error) {
// start cleanup on dangling channels
go channelCleanup()
go c.channelCleanup()
// start server
gin.SetMode(gin.ReleaseMode)
r := gin.New()
r.Use(middleWareHandler(), gin.Recovery())
r.POST("/channel", func(c *gin.Context) {
r, err := func(c *gin.Context) (r response, err error) {
rs.Lock()
defer rs.Unlock()
r.POST("/channel", func(cg *gin.Context) {
r, err := func(cg *gin.Context) (r response, err error) {
c.rs.Lock()
defer c.rs.Unlock()
r.Success = true
var p payloadChannel
err = c.ShouldBindJSON(&p)
err = cg.ShouldBindJSON(&p)
if err != nil {
log.Errorf("failed on payload %+v", p)
err = errors.Wrap(err, "problem parsing /channel")
@ -48,21 +34,21 @@ func startServer(tcpPorts []string, port string) (err error) {
}
// determine if channel is invalid
if _, ok := rs.channel[p.Channel]; !ok {
if _, ok := c.rs.channel[p.Channel]; !ok {
err = errors.Errorf("channel '%s' does not exist", p.Channel)
return
}
// determine if UUID is invalid for channel
if p.UUID != rs.channel[p.Channel].uuids[0] &&
p.UUID != rs.channel[p.Channel].uuids[1] {
if p.UUID != c.rs.channel[p.Channel].uuids[0] &&
p.UUID != c.rs.channel[p.Channel].uuids[1] {
err = errors.Errorf("uuid '%s' is invalid", p.UUID)
return
}
// check if the action is to close the channel
if p.Close {
delete(rs.channel, p.Channel)
delete(c.rs.channel, p.Channel)
r.Message = "deleted " + p.Channel
return
}
@ -74,34 +60,34 @@ func startServer(tcpPorts []string, port string) (err error) {
// add a check that the value of key is not enormous
// add only if it is a valid key
if _, ok := rs.channel[p.Channel].State[key]; ok {
if _, ok := c.rs.channel[p.Channel].State[key]; ok {
assignedKeys = append(assignedKeys, key)
rs.channel[p.Channel].State[key] = p.State[key]
c.rs.channel[p.Channel].State[key] = p.State[key]
}
}
// return the current state
r.Data = rs.channel[p.Channel]
r.Data = c.rs.channel[p.Channel]
r.Message = fmt.Sprintf("assigned %d keys: %v", len(assignedKeys), assignedKeys)
return
}(c)
}(cg)
if err != nil {
log.Debugf("bad /channel: %s", err.Error())
r.Message = err.Error()
r.Success = false
}
bR, _ := json.Marshal(r)
c.Data(200, "application/json", bR)
cg.Data(200, "application/json", bR)
})
r.POST("/join", func(c *gin.Context) {
r, err := func(c *gin.Context) (r response, err error) {
rs.Lock()
defer rs.Unlock()
r.POST("/join", func(cg *gin.Context) {
r, err := func(cg *gin.Context) (r response, err error) {
c.rs.Lock()
defer c.rs.Unlock()
r.Success = true
var p payloadOpen
err = c.ShouldBindJSON(&p)
err = cg.ShouldBindJSON(&p)
if err != nil {
log.Errorf("failed on payload %+v", p)
err = errors.Wrap(err, "problem parsing")
@ -120,57 +106,57 @@ func startServer(tcpPorts []string, port string) (err error) {
// find an empty channel
p.Channel = "chou"
}
if _, ok := rs.channel[p.Channel]; ok {
if _, ok := c.rs.channel[p.Channel]; ok {
// channel is not empty
if rs.channel[p.Channel].uuids[p.Role] != "" {
if c.rs.channel[p.Channel].uuids[p.Role] != "" {
err = errors.Errorf("channel '%s' already occupied by role %d", p.Channel, p.Role)
return
}
}
r.Channel = p.Channel
if _, ok := rs.channel[r.Channel]; !ok {
rs.channel[r.Channel] = newChannelData(r.Channel)
if _, ok := c.rs.channel[r.Channel]; !ok {
c.rs.channel[r.Channel] = newChannelData(r.Channel)
}
// assign UUID for the role in the channel
rs.channel[r.Channel].uuids[p.Role] = uuid4.New().String()
r.UUID = rs.channel[r.Channel].uuids[p.Role]
c.rs.channel[r.Channel].uuids[p.Role] = uuid4.New().String()
r.UUID = c.rs.channel[r.Channel].uuids[p.Role]
log.Debugf("(%s) %s has joined as role %d", r.Channel, r.UUID, p.Role)
// if channel is not open, set initial parameters
if !rs.channel[r.Channel].isopen {
rs.channel[r.Channel].isopen = true
rs.channel[r.Channel].Ports = tcpPorts
rs.channel[r.Channel].startTime = time.Now()
if !c.rs.channel[r.Channel].isopen {
c.rs.channel[r.Channel].isopen = true
c.rs.channel[r.Channel].Ports = tcpPorts
c.rs.channel[r.Channel].startTime = time.Now()
switch curve := p.Curve; curve {
case "p224":
rs.channel[r.Channel].curve = elliptic.P224()
c.rs.channel[r.Channel].curve = elliptic.P224()
case "p256":
rs.channel[r.Channel].curve = elliptic.P256()
c.rs.channel[r.Channel].curve = elliptic.P256()
case "p384":
rs.channel[r.Channel].curve = elliptic.P384()
c.rs.channel[r.Channel].curve = elliptic.P384()
case "p521":
rs.channel[r.Channel].curve = elliptic.P521()
c.rs.channel[r.Channel].curve = elliptic.P521()
default:
// TODO:
// add SIEC
p.Curve = "p256"
rs.channel[r.Channel].curve = elliptic.P256()
c.rs.channel[r.Channel].curve = elliptic.P256()
}
log.Debugf("(%s) using curve '%s'", r.Channel, p.Curve)
rs.channel[r.Channel].State["curve"] = []byte(p.Curve)
c.rs.channel[r.Channel].State["curve"] = []byte(p.Curve)
}
r.Message = fmt.Sprintf("assigned role %d in channel '%s'", p.Role, r.Channel)
return
}(c)
}(cg)
if err != nil {
log.Debugf("bad /join: %s", err.Error())
r.Message = err.Error()
r.Success = false
}
bR, _ := json.Marshal(r)
c.Data(200, "application/json", bR)
cg.Data(200, "application/json", bR)
})
log.Infof("Running at http://0.0.0.0:" + port)
err = r.Run(":" + port)
@ -178,32 +164,32 @@ func startServer(tcpPorts []string, port string) (err error) {
}
func middleWareHandler() gin.HandlerFunc {
return func(c *gin.Context) {
return func(cg *gin.Context) {
t := time.Now()
// Run next function
c.Next()
cg.Next()
// Log request
log.Infof("%v %v %v %s", c.Request.RemoteAddr, c.Request.Method, c.Request.URL, time.Since(t))
log.Infof("%v %v %v %s", cg.Request.RemoteAddr, cg.Request.Method, cg.Request.URL, time.Since(t))
}
}
func channelCleanup() {
func (c *Croc) channelCleanup() {
maximumWait := 10 * time.Minute
for {
rs.Lock()
keys := make([]string, len(rs.channel))
c.rs.Lock()
keys := make([]string, len(c.rs.channel))
i := 0
for key := range rs.channel {
for key := range c.rs.channel {
keys[i] = key
i++
}
for _, key := range keys {
if time.Since(rs.channel[key].startTime) > maximumWait {
if time.Since(c.rs.channel[key].startTime) > maximumWait {
log.Debugf("channel %s has exceeded time, deleting", key)
delete(rs.channel, key)
delete(c.rs.channel, key)
}
}
rs.Unlock()
c.rs.Unlock()
time.Sleep(1 * time.Minute)
}
}