diff --git a/src/models.go b/src/models.go index 0609a25..fa455ea 100644 --- a/src/models.go +++ b/src/models.go @@ -5,6 +5,8 @@ import ( "net" "sync" "time" + + "github.com/gorilla/websocket" ) const ( @@ -24,8 +26,8 @@ type relayState struct { type channelData struct { // Public - // Name is the name of the channel - Name string `json:"name,omitempty"` + // Channel is the name of the channel + Channel string `json:"channel,omitempty"` // State contains state variables that are public to both parties State map[string][]byte `json:"state"` // TransferReady is set by the relaying when both parties have connected @@ -43,6 +45,8 @@ type channelData struct { curve elliptic.Curve // connection information is stored when the clients do connect over TCP connection [2]net.Conn + // websocket connections + websocketConn [2]*websocket.Conn // startTime is the time that the channel was opened startTime time.Time } @@ -58,7 +62,9 @@ type response struct { Message string `json:"message"` } -type payloadOpen struct { +type payload struct { + // Open set to true when trying to open + Open bool `json:"open"` // Channel is used to designate the channel of interest Channel string `json:"channel"` // Role designates which role the person will take; @@ -66,13 +72,14 @@ type payloadOpen struct { Role int `json:"role"` // Curve is the curve to be used. Curve string `json:"curve"` -} -type payloadChannel struct { - Channel string `json:"channel" binding:"required"` - UUID string `json:"uuid" binding:"required"` - State map[string][]byte `json:"state"` - Close bool `json:"close"` + // Update set to true when updating + Update bool `json:"update"` + // State is the state information to be updated + State map[string][]byte `json:"state"` + + // Close set to true when closing: + Close bool `json:"close"` } func newChannelData(name string) (cd *channelData) { diff --git a/src/server.go b/src/server.go index ee7a704..1ed66b3 100644 --- a/src/server.go +++ b/src/server.go @@ -13,168 +13,118 @@ import ( ) func (c *Croc) updateChannel(p payloadChannel) (r response, err error) { + c.rs.Lock() + defer c.rs.Unlock() + r.Success = true + // determine if channel is invalid + if _, ok := c.rs.channel[p.Channel]; !ok { + err = errors.Errorf("channel '%s' does not exist", p.Channel) + return + } + + // determine if UUID is invalid for channel + if p.UUID != c.rs.channel[p.Channel].uuids[0] && + p.UUID != c.rs.channel[p.Channel].uuids[1] { + err = errors.Errorf("uuid '%s' is invalid", p.UUID) + return + } + + // check if the action is to close the channel + if p.Close { + delete(c.rs.channel, p.Channel) + r.Message = "deleted " + p.Channel + return + } + + // assign each key provided + assignedKeys := []string{} + for key := range p.State { + // TODO: + // add a check that the value of key is not enormous + + // add only if it is a valid key + if _, ok := c.rs.channel[p.Channel].State[key]; ok { + assignedKeys = append(assignedKeys, key) + c.rs.channel[p.Channel].State[key] = p.State[key] + } + } + + // return the current state + r.Data = c.rs.channel[p.Channel] + + r.Message = fmt.Sprintf("assigned %d keys: %v", len(assignedKeys), assignedKeys) + return +} + +func (c *Croc) joinChannel(p payloadChannel) (r response, err error) { + c.rs.Lock() + defer c.rs.Unlock() + r.Success = true + + // determine if sender or recipient + if p.Role != 0 && p.Role != 1 { + err = errors.Errorf("no such role of %d", p.Role) + return + } + + // determine channel + if p.Channel == "" { + // TODO: + // find an empty channel + p.Channel = "chou" + } + if _, ok := c.rs.channel[p.Channel]; ok { + // channel is not empty + if c.rs.channel[p.Channel].uuids[p.Role] != "" { + err = errors.Errorf("channel '%s' already occupied by role %d", p.Channel, p.Role) + return + } + } + r.Channel = p.Channel + if _, ok := c.rs.channel[r.Channel]; !ok { + c.rs.channel[r.Channel] = newChannelData(r.Channel) + } + + // assign UUID for the role in the channel + c.rs.channel[r.Channel].uuids[p.Role] = uuid4.New().String() + r.UUID = c.rs.channel[r.Channel].uuids[p.Role] + log.Debugf("(%s) %s has joined as role %d", r.Channel, r.UUID, p.Role) + + // if channel is not open, set initial parameters + if !c.rs.channel[r.Channel].isopen { + c.rs.channel[r.Channel].isopen = true + c.rs.channel[r.Channel].Ports = tcpPorts + c.rs.channel[r.Channel].startTime = time.Now() + switch curve := p.Curve; curve { + case "p224": + c.rs.channel[r.Channel].curve = elliptic.P224() + case "p256": + c.rs.channel[r.Channel].curve = elliptic.P256() + case "p384": + c.rs.channel[r.Channel].curve = elliptic.P384() + case "p521": + c.rs.channel[r.Channel].curve = elliptic.P521() + default: + // TODO: + // add SIEC + p.Curve = "p256" + c.rs.channel[r.Channel].curve = elliptic.P256() + } + log.Debugf("(%s) using curve '%s'", r.Channel, p.Curve) + c.rs.channel[r.Channel].State["curve"] = []byte(p.Curve) + } + + r.Message = fmt.Sprintf("assigned role %d in channel '%s'", p.Role, r.Channel) + return } func (c *Croc) startServer(tcpPorts []string, port string) (err error) { // start cleanup on dangling channels go c.channelCleanup() - // start server - gin.SetMode(gin.ReleaseMode) - r := gin.New() - r.Use(middleWareHandler(), gin.Recovery()) - r.POST("/channel", func(cg *gin.Context) { - r, err := func(cg *gin.Context) (r response, err error) { - c.rs.Lock() - defer c.rs.Unlock() - r.Success = true - var p payloadChannel - err = cg.ShouldBindJSON(&p) - if err != nil { - log.Errorf("failed on payload %+v", p) - err = errors.Wrap(err, "problem parsing /channel") - return - } - - // determine if channel is invalid - if _, ok := c.rs.channel[p.Channel]; !ok { - err = errors.Errorf("channel '%s' does not exist", p.Channel) - return - } - - // determine if UUID is invalid for channel - if p.UUID != c.rs.channel[p.Channel].uuids[0] && - p.UUID != c.rs.channel[p.Channel].uuids[1] { - err = errors.Errorf("uuid '%s' is invalid", p.UUID) - return - } - - // check if the action is to close the channel - if p.Close { - delete(c.rs.channel, p.Channel) - r.Message = "deleted " + p.Channel - return - } - - // assign each key provided - assignedKeys := []string{} - for key := range p.State { - // TODO: - // add a check that the value of key is not enormous - - // add only if it is a valid key - if _, ok := c.rs.channel[p.Channel].State[key]; ok { - assignedKeys = append(assignedKeys, key) - c.rs.channel[p.Channel].State[key] = p.State[key] - } - } - - // return the current state - r.Data = c.rs.channel[p.Channel] - - r.Message = fmt.Sprintf("assigned %d keys: %v", len(assignedKeys), assignedKeys) - return - }(cg) - if err != nil { - log.Debugf("bad /channel: %s", err.Error()) - r.Message = err.Error() - r.Success = false - } - bR, _ := json.Marshal(r) - cg.Data(200, "application/json", bR) - }) - r.POST("/join", func(cg *gin.Context) { - r, err := func(cg *gin.Context) (r response, err error) { - c.rs.Lock() - defer c.rs.Unlock() - r.Success = true - - var p payloadOpen - err = cg.ShouldBindJSON(&p) - if err != nil { - log.Errorf("failed on payload %+v", p) - err = errors.Wrap(err, "problem parsing") - return - } - - // determine if sender or recipient - if p.Role != 0 && p.Role != 1 { - err = errors.Errorf("no such role of %d", p.Role) - return - } - - // determine channel - if p.Channel == "" { - // TODO: - // find an empty channel - p.Channel = "chou" - } - if _, ok := c.rs.channel[p.Channel]; ok { - // channel is not empty - if c.rs.channel[p.Channel].uuids[p.Role] != "" { - err = errors.Errorf("channel '%s' already occupied by role %d", p.Channel, p.Role) - return - } - } - r.Channel = p.Channel - if _, ok := c.rs.channel[r.Channel]; !ok { - c.rs.channel[r.Channel] = newChannelData(r.Channel) - } - - // assign UUID for the role in the channel - c.rs.channel[r.Channel].uuids[p.Role] = uuid4.New().String() - r.UUID = c.rs.channel[r.Channel].uuids[p.Role] - log.Debugf("(%s) %s has joined as role %d", r.Channel, r.UUID, p.Role) - - // if channel is not open, set initial parameters - if !c.rs.channel[r.Channel].isopen { - c.rs.channel[r.Channel].isopen = true - c.rs.channel[r.Channel].Ports = tcpPorts - c.rs.channel[r.Channel].startTime = time.Now() - switch curve := p.Curve; curve { - case "p224": - c.rs.channel[r.Channel].curve = elliptic.P224() - case "p256": - c.rs.channel[r.Channel].curve = elliptic.P256() - case "p384": - c.rs.channel[r.Channel].curve = elliptic.P384() - case "p521": - c.rs.channel[r.Channel].curve = elliptic.P521() - default: - // TODO: - // add SIEC - p.Curve = "p256" - c.rs.channel[r.Channel].curve = elliptic.P256() - } - log.Debugf("(%s) using curve '%s'", r.Channel, p.Curve) - c.rs.channel[r.Channel].State["curve"] = []byte(p.Curve) - } - - r.Message = fmt.Sprintf("assigned role %d in channel '%s'", p.Role, r.Channel) - return - }(cg) - if err != nil { - log.Debugf("bad /join: %s", err.Error()) - r.Message = err.Error() - r.Success = false - } - bR, _ := json.Marshal(r) - cg.Data(200, "application/json", bR) - }) - log.Infof("Running at http://0.0.0.0:" + port) - err = r.Run(":" + port) - return -} - -func middleWareHandler() gin.HandlerFunc { - return func(cg *gin.Context) { - t := time.Now() - // Run next function - cg.Next() - // Log request - log.Infof("%v %v %v %s", cg.Request.RemoteAddr, cg.Request.Method, cg.Request.URL, time.Since(t)) - } + // TODO: + // insert websockets here } func (c *Croc) channelCleanup() {