Merge pull request #1762 from projectdiscovery/http_based_speed_control

add http based speed control
This commit is contained in:
Mzack9999 2024-06-12 15:31:56 +02:00 committed by GitHub
commit 8f9bb51b64
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 217 additions and 16 deletions

View File

@ -34,9 +34,12 @@ jobs:
PDCP_API_KEY: "${{ secrets.PDCP_API_KEY }}"
- name: Running example
- name: Testing Example - Simple
run: go run .
working-directory: examples/
working-directory: examples/simple/
- name: Testing Example - Speed Control
run: go run .
working-directory: examples/speed_control/
- name: Integration Tests Linux, macOS
if: runner.os == 'Linux' || runner.os == 'macOS'

View File

@ -16,7 +16,6 @@ func main() {
options := runner.Options{
Methods: "GET",
InputTargetHost: goflags.StringSlice{"scanme.sh", "projectdiscovery.io", "localhost"},
//InputFile: "./targetDomains.txt", // path to file containing the target domains list
OnResult: func(r runner.Result) {
// handle error
if r.Err != nil {

View File

@ -0,0 +1,99 @@
package main
import (
"bytes"
"encoding/json"
"fmt"
"log"
"net/http"
"time"
"github.com/projectdiscovery/goflags"
"github.com/projectdiscovery/gologger"
"github.com/projectdiscovery/gologger/levels"
"github.com/projectdiscovery/httpx/runner"
)
func main() {
gologger.DefaultLogger.SetMaxLevel(levels.LevelVerbose) // increase the verbosity (optional)
// generate urls
var urls []string
for i := 0; i < 100; i++ {
urls = append(urls, fmt.Sprintf("https://scanme.sh/a=%d", i))
}
apiEndpoint := "127.0.0.1:31234"
options := runner.Options{
Methods: "GET",
InputTargetHost: goflags.StringSlice(urls),
Threads: 1,
HttpApiEndpoint: apiEndpoint,
OnResult: func(r runner.Result) {
// handle error
if r.Err != nil {
fmt.Printf("[Err] %s: %s\n", r.Input, r.Err)
return
}
fmt.Printf("%s %s %d\n", r.Input, r.Host, r.StatusCode)
},
}
// after 3 seconds increase the speed to 50
time.AfterFunc(3*time.Second, func() {
client := &http.Client{}
concurrencySettings := runner.Concurrency{Threads: 50}
requestBody, err := json.Marshal(concurrencySettings)
if err != nil {
log.Fatalf("Error creating request body: %v", err)
}
req, err := http.NewRequest("PUT", fmt.Sprintf("http://%s/api/concurrency", apiEndpoint), bytes.NewBuffer(requestBody))
if err != nil {
log.Fatalf("Error creating PUT request: %v", err)
}
req.Header.Set("Content-Type", "application/json")
resp, err := client.Do(req)
if err != nil {
log.Fatalf("Error sending PUT request: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
log.Printf("Failed to update threads, status code: %d", resp.StatusCode)
} else {
log.Println("Threads updated to 50 successfully")
}
})
if err := options.ValidateOptions(); err != nil {
log.Fatal(err)
}
httpxRunner, err := runner.New(&options)
if err != nil {
log.Fatal(err)
}
defer httpxRunner.Close()
httpxRunner.RunEnumeration()
// check the threads
req, err := http.Get(fmt.Sprintf("http://%s/api/concurrency", apiEndpoint))
if err != nil {
log.Fatalf("Error creating GET request: %v", err)
}
var concurrencySettings runner.Concurrency
if err := json.NewDecoder(req.Body).Decode(&concurrencySettings); err != nil {
log.Fatalf("Error decoding response body: %v", err)
}
if concurrencySettings.Threads == 50 {
log.Println("Threads are set to 50")
} else {
log.Fatalf("Fatal error: Threads are not set to 50, current value: %d", concurrencySettings.Threads)
}
}

1
go.mod
View File

@ -39,7 +39,6 @@ require (
github.com/projectdiscovery/useragent v0.0.54
github.com/projectdiscovery/utils v0.1.1
github.com/projectdiscovery/wappalyzergo v0.1.4
github.com/remeh/sizedwaitgroup v1.0.0
github.com/rs/xid v1.5.0
github.com/spaolacci/murmur3 v1.1.0
github.com/stretchr/testify v1.9.0

4
go.sum
View File

@ -62,6 +62,8 @@ github.com/dsnet/compress v0.0.2-0.20210315054119-f66993602bf5/go.mod h1:qssHWj6
github.com/dsnet/golib v0.0.0-20171103203638-1ea166775780/go.mod h1:Lj+Z9rebOhdfkVLjJ8T6VcRQv3SXugXy999NBtR9aFY=
github.com/fatih/color v1.15.0 h1:kOqh6YHBtK8aywxGerMG2Eq3H6Qgoqeo13Bk2Mv/nBs=
github.com/fatih/color v1.15.0/go.mod h1:0h5ZqXfHYED7Bhv2ZJamyIOUej9KtShiJESRwBDUSsw=
github.com/fortytw2/leaktest v1.3.0 h1:u8491cBMTQ8ft8aeV+adlcytMZylmA5nnwwkRZjI8vw=
github.com/fortytw2/leaktest v1.3.0/go.mod h1:jDsjWgpAGjm2CA7WthBh/CdZYEPF31XHquHwclZch5g=
github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
github.com/fsnotify/fsnotify v1.6.0 h1:n+5WquG0fcWoWp6xPWfHdbskMCQaFnG6PfBrh1Ky4HY=
github.com/fsnotify/fsnotify v1.6.0/go.mod h1:sl3t1tCWJFWoRz9R8WJCbQihKKwmorjAbSClcnxKAGw=
@ -268,8 +270,6 @@ github.com/quic-go/quic-go v0.42.0 h1:uSfdap0eveIl8KXnipv9K7nlwZ5IqLlYOpJ58u5utp
github.com/quic-go/quic-go v0.42.0/go.mod h1:132kz4kL3F9vxhW3CtQJLDVwcFe5wdWeJXXijhsO57M=
github.com/refraction-networking/utls v1.5.4 h1:9k6EO2b8TaOGsQ7Pl7p9w6PUhx18/ZCeT0WNTZ7Uw4o=
github.com/refraction-networking/utls v1.5.4/go.mod h1:SPuDbBmgLGp8s+HLNc83FuavwZCFoMmExj+ltUHiHUw=
github.com/remeh/sizedwaitgroup v1.0.0 h1:VNGGFwNo/R5+MJBf6yrsr110p0m4/OX4S3DCy7Kyl5E=
github.com/remeh/sizedwaitgroup v1.0.0/go.mod h1:3j2R4OIe/SeS6YDhICBy22RWjJC5eNCJ1V+9+NVNYlo=
github.com/rivo/uniseg v0.1.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/rivo/uniseg v0.4.4 h1:8TfxU8dW6PdqD27gjM8MVNuicgxIjxpm4K7x4jp8sis=

73
runner/apiendpoint.go Normal file
View File

@ -0,0 +1,73 @@
// TODO: move this to internal package
package runner
import (
"encoding/json"
"net/http"
)
type Concurrency struct {
Threads int `json:"threads"`
}
// Server represents the HTTP server that handles the concurrency settings endpoints.
type Server struct {
addr string
config *Options
}
// New creates a new instance of Server.
func NewServer(addr string, config *Options) *Server {
return &Server{
addr: addr,
config: config,
}
}
// Start initializes the server and its routes, then starts listening on the specified address.
func (s *Server) Start() error {
http.HandleFunc("/api/concurrency", s.handleConcurrency)
if err := http.ListenAndServe(s.addr, nil); err != nil {
return err
}
return nil
}
// handleConcurrency routes the request based on its method to the appropriate handler.
func (s *Server) handleConcurrency(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case http.MethodGet:
s.getSettings(w, r)
case http.MethodPut:
s.updateSettings(w, r)
default:
http.Error(w, "Unsupported HTTP method", http.StatusMethodNotAllowed)
}
}
// GetSettings handles GET requests and returns the current concurrency settings
func (s *Server) getSettings(w http.ResponseWriter, _ *http.Request) {
concurrencySettings := Concurrency{
Threads: s.config.Threads,
}
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(concurrencySettings); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
}
// UpdateSettings handles PUT requests to update the concurrency settings
func (s *Server) updateSettings(w http.ResponseWriter, r *http.Request) {
var newSettings Concurrency
if err := json.NewDecoder(r.Body).Decode(&newSettings); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
if newSettings.Threads > 0 {
s.config.Threads = newSettings.Threads
}
w.WriteHeader(http.StatusOK)
}

View File

@ -34,6 +34,7 @@ import (
const (
two = 2
defaultThreads = 50
DefaultResumeFile = "resume.cfg"
DefaultOutputDirectory = "output"
)
@ -295,6 +296,7 @@ type Options struct {
UseInstalledChrome bool
TlsImpersonate bool
DisableStdin bool
HttpApiEndpoint string
NoScreenshotBytes bool
NoHeadlessBody bool
ScreenshotTimeout int
@ -385,7 +387,7 @@ func ParseOptions() *Options {
)
flagSet.CreateGroup("rate-limit", "Rate-Limit",
flagSet.IntVarP(&options.Threads, "threads", "t", 50, "number of threads to use"),
flagSet.IntVarP(&options.Threads, "threads", "t", defaultThreads, "number of threads to use"),
flagSet.IntVarP(&options.RateLimit, "rate-limit", "rl", 150, "maximum requests to send per second"),
flagSet.IntVarP(&options.RateLimitMinute, "rate-limit-minute", "rlm", 0, "maximum number of requests to send per minute"),
)
@ -451,6 +453,7 @@ func ParseOptions() *Options {
flagSet.BoolVar(&options.NoDecode, "no-decode", false, "avoid decoding body"),
flagSet.BoolVarP(&options.TlsImpersonate, "tls-impersonate", "tlsi", false, "enable experimental client hello (ja3) tls randomization"),
flagSet.BoolVar(&options.DisableStdin, "no-stdin", false, "Disable Stdin processing"),
flagSet.StringVarP(&options.HttpApiEndpoint, "http-api-endpoint", "hae", "", "experimental http api endpoint"),
)
flagSet.CreateGroup("debug", "Debug",
@ -678,6 +681,11 @@ func (options *Options) ValidateOptions() error {
return fmt.Errorf("invalid protocol: %s", options.Protocol)
}
if options.Threads == 0 {
gologger.Info().Msgf("Threads automatically set to %d", defaultThreads)
options.Threads = defaultThreads
}
return nil
}

View File

@ -22,6 +22,7 @@ import (
"sort"
"strconv"
"strings"
"sync"
"time"
"golang.org/x/exp/maps"
@ -52,7 +53,6 @@ import (
urlutil "github.com/projectdiscovery/utils/url"
"github.com/projectdiscovery/ratelimit"
"github.com/remeh/sizedwaitgroup"
// automatic fd max increase if running as root
_ "github.com/projectdiscovery/fdmax/autofdmax"
@ -68,6 +68,7 @@ import (
fileutil "github.com/projectdiscovery/utils/file"
pdhttputil "github.com/projectdiscovery/utils/http"
iputil "github.com/projectdiscovery/utils/ip"
syncutil "github.com/projectdiscovery/utils/sync"
wappalyzer "github.com/projectdiscovery/wappalyzergo"
)
@ -85,6 +86,7 @@ type Runner struct {
browser *Browser
errorPageClassifier *errorpageclassifier.ErrorPageClassifier
pHashClusters []pHashCluster
httpApiEndpoint *Server
}
// picked based on try-fail but it seems to close to one it's used https://www.hackerfactor.com/blog/index.php?/archives/432-Looks-Like-It.html#c1992
@ -364,6 +366,17 @@ func New(options *Options) (*Runner, error) {
runner.errorPageClassifier = errorpageclassifier.New()
if options.HttpApiEndpoint != "" {
apiServer := NewServer(options.HttpApiEndpoint, options)
gologger.Info().Msgf("Listening api endpoint on: %s", options.HttpApiEndpoint)
runner.httpApiEndpoint = apiServer
go func() {
if err := apiServer.Start(); err != nil {
gologger.Error().Msgf("Failed to start API server: %s", err)
}
}()
}
return runner, nil
}
@ -680,12 +693,12 @@ func (r *Runner) RunEnumeration() {
}
// output routine
wgoutput := sizedwaitgroup.New(2)
wgoutput.Add()
var wgoutput sync.WaitGroup
output := make(chan Result)
nextStep := make(chan Result)
wgoutput.Add(1)
go func(output chan Result, nextSteps ...chan Result) {
defer wgoutput.Done()
@ -1065,7 +1078,7 @@ func (r *Runner) RunEnumeration() {
// HTML Summary
// - needs output of previous routine
// - separate goroutine due to incapability of go templates to render from file
wgoutput.Add()
wgoutput.Add(1)
go func(output chan Result) {
defer wgoutput.Done()
@ -1109,7 +1122,7 @@ func (r *Runner) RunEnumeration() {
}
}(nextStep)
wg := sizedwaitgroup.New(r.options.Threads)
wg, _ := syncutil.New(syncutil.WithSize(r.options.Threads))
processItem := func(k string) error {
if r.options.resumeCfg != nil {
@ -1132,10 +1145,10 @@ func (r *Runner) RunEnumeration() {
for _, p := range r.options.requestURIs {
scanopts := r.scanopts.Clone()
scanopts.RequestURI = p
r.process(k, &wg, r.hp, protocol, scanopts, output)
r.process(k, wg, r.hp, protocol, scanopts, output)
}
} else {
r.process(k, &wg, r.hp, protocol, &r.scanopts, output)
r.process(k, wg, r.hp, protocol, &r.scanopts, output)
}
return nil
@ -1224,11 +1237,18 @@ func (r *Runner) GetScanOpts() ScanOptions {
return r.scanopts
}
func (r *Runner) Process(t string, wg *sizedwaitgroup.SizedWaitGroup, protocol string, scanopts *ScanOptions, output chan Result) {
func (r *Runner) Process(t string, wg *syncutil.AdaptiveWaitGroup, protocol string, scanopts *ScanOptions, output chan Result) {
r.process(t, wg, r.hp, protocol, scanopts, output)
}
func (r *Runner) process(t string, wg *sizedwaitgroup.SizedWaitGroup, hp *httpx.HTTPX, protocol string, scanopts *ScanOptions, output chan Result) {
func (r *Runner) process(t string, wg *syncutil.AdaptiveWaitGroup, hp *httpx.HTTPX, protocol string, scanopts *ScanOptions, output chan Result) {
// attempts to set the workpool size to the number of threads
if r.options.Threads > 0 && wg.Size != r.options.Threads {
if err := wg.Resize(context.Background(), r.options.Threads); err != nil {
gologger.Error().Msgf("Could not resize workpool: %s\n", err)
}
}
protocols := []string{protocol}
if scanopts.NoFallback || protocol == httpx.HTTPandHTTPS {
protocols = []string{httpx.HTTPS, httpx.HTTP}