package proxy import ( "crypto/sha256" "crypto/subtle" "io" "net/http" "strings" "github.com/didip/tollbooth/v7" "github.com/didip/tollbooth/v7/libstring" log "github.com/go-pkgz/lgr" R "github.com/go-pkgz/rest" "github.com/gorilla/handlers" "golang.org/x/crypto/bcrypt" "github.com/umputun/reproxy/app/discovery" ) func headersHandler(addHeaders, dropHeaders []string) func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if len(addHeaders) == 0 && len(dropHeaders) == 0 { next.ServeHTTP(w, r) return } // add headers to response for _, h := range addHeaders { elems := strings.Split(h, ":") if len(elems) != 2 { continue } w.Header().Set(strings.TrimSpace(elems[0]), strings.TrimSpace(elems[1])) } // drop headers from request for _, h := range dropHeaders { r.Header.Del(h) } next.ServeHTTP(w, r) }) } } func maxReqSizeHandler(maxSize int64) func(next http.Handler) http.Handler { if maxSize <= 0 { return passThroughHandler } log.Printf("[DEBUG] request size limited to %d", maxSize) return func(next http.Handler) http.Handler { fn := func(w http.ResponseWriter, r *http.Request) { // check ContentLength if r.ContentLength > maxSize { w.WriteHeader(http.StatusRequestEntityTooLarge) return } // check query string size if int64(len(r.URL.RawQuery)) > maxSize { w.WriteHeader(http.StatusRequestURITooLong) return } r.Body = http.MaxBytesReader(w, r.Body, maxSize) next.ServeHTTP(w, r) } return http.HandlerFunc(fn) } } func accessLogHandler(wr io.Writer) func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler { return handlers.CombinedLoggingHandler(wr, next) } } func stdoutLogHandler(enable bool, lh func(next http.Handler) http.Handler) func(next http.Handler) http.Handler { if !enable { return passThroughHandler } log.Printf("[DEBUG] stdout logging enabled") return func(next http.Handler) http.Handler { fn := func(w http.ResponseWriter, r *http.Request) { // don't log to stdout GET ~/(.*)/ping$ requests if r.Method == "GET" && strings.HasSuffix(r.URL.Path, "/ping") { next.ServeHTTP(w, r) return } lh(next).ServeHTTP(w, r) } return http.HandlerFunc(fn) } } func gzipHandler(enabled bool) func(next http.Handler) http.Handler { if !enabled { return passThroughHandler } log.Printf("[DEBUG] gzip enabled") return handlers.CompressHandler } func signatureHandler(enabled bool, version string) func(next http.Handler) http.Handler { if !enabled { return passThroughHandler } log.Printf("[DEBUG] signature headers enabled") return R.AppInfo("reproxy", "umputun", version) } // limiterSystemHandler throttles overall activity of reproxy server, 0 means disabled func limiterSystemHandler(reqSec int) func(next http.Handler) http.Handler { if reqSec <= 0 { return passThroughHandler } return func(h http.Handler) http.Handler { lmt := tollbooth.NewLimiter(float64(reqSec), nil) fn := func(w http.ResponseWriter, r *http.Request) { if httpError := tollbooth.LimitByKeys(lmt, []string{"system"}); httpError != nil { http.Error(w, http.StatusText(http.StatusTooManyRequests), http.StatusTooManyRequests) return } h.ServeHTTP(w, r) } return http.HandlerFunc(fn) } } // limiterUserHandler throttles per user activity. In case if match found the limit is per destination // otherwise global (per user in any case). 0 means disabled func limiterUserHandler(reqSec int) func(next http.Handler) http.Handler { if reqSec <= 0 { return passThroughHandler } return func(h http.Handler) http.Handler { lmt := tollbooth.NewLimiter(float64(reqSec), nil) fn := func(w http.ResponseWriter, r *http.Request) { keys := []string{libstring.RemoteIP(lmt.GetIPLookups(), lmt.GetForwardedForIndexFromBehind(), r)} // add dst proxy if matched if r.Context().Value(ctxMatch) != nil { // route match detected by matchHandler match := r.Context().Value(ctxMatch).(discovery.MatchedRoute) matchType := r.Context().Value(ctxMatchType).(discovery.MatchType) if matchType == discovery.MTProxy { keys = append(keys, match.Mapper.Dst) } } if httpError := tollbooth.LimitByKeys(lmt, keys); httpError != nil { http.Error(w, http.StatusText(http.StatusTooManyRequests), http.StatusTooManyRequests) return } h.ServeHTTP(w, r) } return http.HandlerFunc(fn) } } // basicAuthHandler is a middleware that authenticates via basic auth, if enabled // allowed is a list of user:bcrypt(passwd) strings generated by `htpasswd -nbB user passwd` func basicAuthHandler(enabled bool, allowed []string) func(next http.Handler) http.Handler { if !enabled { return passThroughHandler } unauthorized := func(w http.ResponseWriter) { w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`) w.WriteHeader(http.StatusUnauthorized) } return func(h http.Handler) http.Handler { fn := func(w http.ResponseWriter, r *http.Request) { username, password, ok := r.BasicAuth() if !ok { unauthorized(w) return } passed := false for _, a := range allowed { alwElems := strings.Split(strings.TrimSpace(a), ":") if len(alwElems) != 2 { continue } // hash to ensure constant time comparison not affected by username length usernameHash := sha256.Sum256([]byte(username)) expectedUsernameHash := sha256.Sum256([]byte(alwElems[0])) expectedPasswordHash := alwElems[1] userMatched := subtle.ConstantTimeCompare(usernameHash[:], expectedUsernameHash[:]) passMatchErr := bcrypt.CompareHashAndPassword([]byte(expectedPasswordHash), []byte(password)) if userMatched == 1 && passMatchErr == nil { passed = true // don't stop here, check all allowed to keep the overall time consistent } } if !passed { unauthorized(w) return } h.ServeHTTP(w, r) } return http.HandlerFunc(fn) } } func passThroughHandler(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { next.ServeHTTP(w, r) }) }