Fixed connection pooling with cancelable timeout (#763)

This commit is contained in:
Ice3man 2022-09-15 22:57:26 +05:30 committed by GitHub
parent 9b4a2ecb0f
commit 674b792770
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 56 additions and 60 deletions

View File

@ -1,47 +0,0 @@
package jarm
import (
"net"
"sync"
"go.uber.org/multierr"
)
type inFlightConns struct {
sync.RWMutex
inflightConns map[net.Conn]struct{}
}
func newInFlightConns() (*inFlightConns, error) {
return &inFlightConns{inflightConns: make(map[net.Conn]struct{})}, nil
}
func (i *inFlightConns) Add(conn net.Conn) {
i.Lock()
defer i.Unlock()
i.inflightConns[conn] = struct{}{}
}
func (i *inFlightConns) Remove(conn net.Conn) {
i.Lock()
defer i.Unlock()
delete(i.inflightConns, conn)
}
func (i *inFlightConns) Close() error {
i.Lock()
defer i.Unlock()
var errs []error
for conn := range i.inflightConns {
if err := conn.Close(); err != nil {
errs = append(errs, err)
}
delete(i.inflightConns, conn)
}
return multierr.Combine(errs...)
}

View File

@ -13,7 +13,10 @@ import (
"github.com/projectdiscovery/fastdialer/fastdialer"
)
const defaultPort int = 443
const (
poolCount = 3
defaultPort = 443
)
type target struct {
Host string
@ -21,11 +24,15 @@ type target struct {
}
// fingerprint probes a single host/port
func fingerprint(dialer *fastdialer.Dialer, t target, timeout time.Duration) string {
func fingerprint(dialer *fastdialer.Dialer, t target, duration int) string {
results := []string{}
addr := net.JoinHostPort(t.Host, fmt.Sprintf("%d", t.Port))
// using connection pool as we need multiple probes
pool, err := newOneTimePool(context.Background(), addr, 3)
timeout := time.Duration(duration) * time.Second
ctx, cancel := context.WithTimeout(context.Background(), (time.Duration(duration*poolCount) * time.Second))
defer cancel()
pool, err := newOneTimePool(ctx, addr, poolCount)
if err != nil {
return ""
}
@ -35,19 +42,18 @@ func fingerprint(dialer *fastdialer.Dialer, t target, timeout time.Duration) str
go pool.Run() //nolint
for _, probe := range jarm.GetProbes(t.Host, t.Port) {
conn, err := pool.Acquire(context.Background())
conn, err := pool.Acquire(ctx)
if err != nil {
continue
return ""
}
if conn == nil {
continue
return ""
}
_ = conn.SetWriteDeadline(time.Now().Add(timeout))
_, err = conn.Write(jarm.BuildProbe(probe))
if err != nil {
results = append(results, "")
_ = conn.Close()
continue
return ""
}
_ = conn.SetReadDeadline(time.Now().Add(timeout))
buff := make([]byte, 1484)
@ -55,8 +61,7 @@ func fingerprint(dialer *fastdialer.Dialer, t target, timeout time.Duration) str
_ = conn.Close()
ans, err := jarm.ParseServerHello(buff, probe)
if err != nil {
results = append(results, "")
continue
return ""
}
results = append(results, ans)
}
@ -76,6 +81,5 @@ func Jarm(dialer *fastdialer.Dialer, host string, duration int) string {
if t.Port == 0 {
t.Port = defaultPort
}
timeout := time.Duration(duration) * time.Second
return fingerprint(dialer, t, timeout)
return fingerprint(dialer, t, duration)
}

View File

@ -3,8 +3,10 @@ package jarm
import (
"context"
"net"
"sync"
"github.com/projectdiscovery/fastdialer/fastdialer"
"go.uber.org/multierr"
)
// oneTimePool is a pool designed to create continous bare connections that are for one time only usage
@ -75,3 +77,40 @@ func (p *oneTimePool) Close() error {
p.cancel()
return p.InFlightConns.Close()
}
type inFlightConns struct {
sync.Mutex
inflightConns map[net.Conn]struct{}
}
func newInFlightConns() (*inFlightConns, error) {
return &inFlightConns{inflightConns: make(map[net.Conn]struct{})}, nil
}
func (i *inFlightConns) Add(conn net.Conn) {
i.Lock()
defer i.Unlock()
i.inflightConns[conn] = struct{}{}
}
func (i *inFlightConns) Remove(conn net.Conn) {
i.Lock()
defer i.Unlock()
delete(i.inflightConns, conn)
}
func (i *inFlightConns) Close() error {
i.Lock()
defer i.Unlock()
var errs []error
for conn := range i.inflightConns {
if err := conn.Close(); err != nil {
errs = append(errs, err)
}
delete(i.inflightConns, conn)
}
return multierr.Combine(errs...)
}