diff --git a/common/hashes/jarm/connpool.go b/common/hashes/jarm/connpool.go deleted file mode 100644 index 56bc6b4..0000000 --- a/common/hashes/jarm/connpool.go +++ /dev/null @@ -1,47 +0,0 @@ -package jarm - -import ( - "net" - "sync" - - "go.uber.org/multierr" -) - -type inFlightConns struct { - sync.RWMutex - inflightConns map[net.Conn]struct{} -} - -func newInFlightConns() (*inFlightConns, error) { - return &inFlightConns{inflightConns: make(map[net.Conn]struct{})}, nil -} - -func (i *inFlightConns) Add(conn net.Conn) { - i.Lock() - defer i.Unlock() - - i.inflightConns[conn] = struct{}{} -} - -func (i *inFlightConns) Remove(conn net.Conn) { - i.Lock() - defer i.Unlock() - - delete(i.inflightConns, conn) -} - -func (i *inFlightConns) Close() error { - i.Lock() - defer i.Unlock() - - var errs []error - - for conn := range i.inflightConns { - if err := conn.Close(); err != nil { - errs = append(errs, err) - } - delete(i.inflightConns, conn) - } - - return multierr.Combine(errs...) -} diff --git a/common/hashes/jarm/jarmhash.go b/common/hashes/jarm/jarmhash.go index ad75640..5b45656 100644 --- a/common/hashes/jarm/jarmhash.go +++ b/common/hashes/jarm/jarmhash.go @@ -13,7 +13,10 @@ import ( "github.com/projectdiscovery/fastdialer/fastdialer" ) -const defaultPort int = 443 +const ( + poolCount = 3 + defaultPort = 443 +) type target struct { Host string @@ -21,11 +24,15 @@ type target struct { } // fingerprint probes a single host/port -func fingerprint(dialer *fastdialer.Dialer, t target, timeout time.Duration) string { +func fingerprint(dialer *fastdialer.Dialer, t target, duration int) string { results := []string{} addr := net.JoinHostPort(t.Host, fmt.Sprintf("%d", t.Port)) - // using connection pool as we need multiple probes - pool, err := newOneTimePool(context.Background(), addr, 3) + timeout := time.Duration(duration) * time.Second + + ctx, cancel := context.WithTimeout(context.Background(), (time.Duration(duration*poolCount) * time.Second)) + defer cancel() + + pool, err := newOneTimePool(ctx, addr, poolCount) if err != nil { return "" } @@ -35,19 +42,18 @@ func fingerprint(dialer *fastdialer.Dialer, t target, timeout time.Duration) str go pool.Run() //nolint for _, probe := range jarm.GetProbes(t.Host, t.Port) { - conn, err := pool.Acquire(context.Background()) + conn, err := pool.Acquire(ctx) if err != nil { - continue + return "" } if conn == nil { - continue + return "" } _ = conn.SetWriteDeadline(time.Now().Add(timeout)) _, err = conn.Write(jarm.BuildProbe(probe)) if err != nil { - results = append(results, "") _ = conn.Close() - continue + return "" } _ = conn.SetReadDeadline(time.Now().Add(timeout)) buff := make([]byte, 1484) @@ -55,8 +61,7 @@ func fingerprint(dialer *fastdialer.Dialer, t target, timeout time.Duration) str _ = conn.Close() ans, err := jarm.ParseServerHello(buff, probe) if err != nil { - results = append(results, "") - continue + return "" } results = append(results, ans) } @@ -76,6 +81,5 @@ func Jarm(dialer *fastdialer.Dialer, host string, duration int) string { if t.Port == 0 { t.Port = defaultPort } - timeout := time.Duration(duration) * time.Second - return fingerprint(dialer, t, timeout) + return fingerprint(dialer, t, duration) } diff --git a/common/hashes/jarm/onetimepool.go b/common/hashes/jarm/onetimepool.go index c625ec8..a30fe6e 100644 --- a/common/hashes/jarm/onetimepool.go +++ b/common/hashes/jarm/onetimepool.go @@ -3,8 +3,10 @@ package jarm import ( "context" "net" + "sync" "github.com/projectdiscovery/fastdialer/fastdialer" + "go.uber.org/multierr" ) // oneTimePool is a pool designed to create continous bare connections that are for one time only usage @@ -75,3 +77,40 @@ func (p *oneTimePool) Close() error { p.cancel() return p.InFlightConns.Close() } + +type inFlightConns struct { + sync.Mutex + inflightConns map[net.Conn]struct{} +} + +func newInFlightConns() (*inFlightConns, error) { + return &inFlightConns{inflightConns: make(map[net.Conn]struct{})}, nil +} + +func (i *inFlightConns) Add(conn net.Conn) { + i.Lock() + defer i.Unlock() + + i.inflightConns[conn] = struct{}{} +} + +func (i *inFlightConns) Remove(conn net.Conn) { + i.Lock() + defer i.Unlock() + + delete(i.inflightConns, conn) +} + +func (i *inFlightConns) Close() error { + i.Lock() + defer i.Unlock() + + var errs []error + for conn := range i.inflightConns { + if err := conn.Close(); err != nil { + errs = append(errs, err) + } + delete(i.inflightConns, conn) + } + return multierr.Combine(errs...) +}