diff --git a/connect.go b/connect.go index b01ff56..08970a2 100644 --- a/connect.go +++ b/connect.go @@ -53,18 +53,6 @@ func NewConnection(flags *Flags) *Connection { } else { c.IsSender = false } - return c -} - -func (c *Connection) Run() { - if len(c.Code) == 0 { - if !c.IsSender { - c.Code = getInput("Enter receive code: ") - } - if len(c.Code) < 5 { - c.Code = GetRandomName() - } - } log.SetFormatter(&log.TextFormatter{}) if c.Debug { @@ -73,6 +61,41 @@ func (c *Connection) Run() { log.SetLevel(log.WarnLevel) } + return c +} + +func (c *Connection) Run() { + log.Debug("checking code validity") + for { + // check code + goodCode := true + m := strings.Split(c.Code, "-") + numThreads, errParse := strconv.Atoi(m[0]) + if len(m) < 2 { + goodCode = false + } else if numThreads > MAX_NUMBER_THREADS || numThreads < 1 { + c.NumberOfConnections = MAX_NUMBER_THREADS + goodCode = false + } else if errParse != nil { + goodCode = false + } + log.Debug(m) + if !goodCode { + if c.IsSender { + c.Code = strconv.Itoa(c.NumberOfConnections) + "-" + GetRandomName() + } else { + if len(c.Code) != 0 { + fmt.Println("Code must begin with number of threads (e.g. 3-some-code)") + } + c.Code = getInput("Enter receive code: ") + } + } else { + break + } + } + // assign number of connections + c.NumberOfConnections, _ = strconv.Atoi(strings.Split(c.Code, "-")[0]) + if c.IsSender { // encrypt the file log.Debug("encrypting...") @@ -112,6 +135,7 @@ func (c *Connection) runClient() { c.bars = make([]*uiprogress.Bar, c.NumberOfConnections) } gotOK := false + gotResponse := false for id := 0; id < c.NumberOfConnections; id++ { go func(id int) { defer wg.Done() @@ -168,32 +192,35 @@ func (c *Connection) runClient() { // have the main thread ask for the okay if id == 0 { fmt.Printf("Receiving file (%d bytes) into: %s\n", c.File.Size, c.File.Name) - getOk := getInput("ok? (y/n): ") - if getOk == "y" { + getOK := getInput("ok? (y/n): ") + if getOK == "y" { gotOK = true - } else { - return } + gotResponse = true } // wait for the main thread to get the okay for limit := 0; limit < 1000; limit++ { - if gotOK { + if gotResponse { break } time.Sleep(10 * time.Millisecond) } if !gotOK { - return + sendMessage("not ok", connection) + } else { + sendMessage("ok", connection) + logger.Debug("receive file") + c.receiveFile(id, connection) } - sendMessage("ok", connection) - logger.Debug("receive file") - c.receiveFile(id, connection) } }(id) } wg.Wait() if !c.IsSender { + if !gotOK { + return + } c.catFile(c.File.Name) encrypted, err := ioutil.ReadFile(c.File.Name + ".encrypted") if err != nil { diff --git a/relay.go b/relay.go index 434f6b4..f558fa6 100644 --- a/relay.go +++ b/relay.go @@ -11,6 +11,8 @@ import ( log "github.com/sirupsen/logrus" ) +const MAX_NUMBER_THREADS = 8 + type connectionMap struct { reciever map[string]net.Conn sender map[string]net.Conn @@ -27,7 +29,7 @@ type Relay struct { func NewRelay(flags *Flags) *Relay { r := new(Relay) r.Debug = flags.Debug - r.NumberOfConnections = flags.NumberOfConnections + r.NumberOfConnections = MAX_NUMBER_THREADS log.SetFormatter(&log.TextFormatter{}) if r.Debug { log.SetLevel(log.DebugLevel) @@ -149,7 +151,6 @@ func (r *Relay) clientCommuncation(id int, connection net.Conn) { } r.connections.RUnlock() time.Sleep(100 * time.Millisecond) - logger.Debug("waiting for metadata") } // send meta data r.connections.RLock() @@ -157,11 +158,13 @@ func (r *Relay) clientCommuncation(id int, connection net.Conn) { r.connections.RUnlock() // check for receiver's consent consent := receiveMessage(connection) - logger.Debug("consent: %s", consent) - logger.Debug("got reciever") - r.connections.Lock() - r.connections.reciever[key] = connection - r.connections.Unlock() + logger.Debugf("consent: %s", consent) + if consent == "ok" { + logger.Debug("got consent") + r.connections.Lock() + r.connections.reciever[key] = connection + r.connections.Unlock() + } } return }