mirror of
https://github.com/neilotoole/sq.git
synced 2024-12-30 11:46:08 +03:00
225 lines
6.7 KiB
Go
225 lines
6.7 KiB
Go
|
package httpz
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"crypto/tls"
|
||
|
"errors"
|
||
|
"net/http"
|
||
|
"time"
|
||
|
|
||
|
"github.com/neilotoole/sq/cli/buildinfo"
|
||
|
"github.com/neilotoole/sq/libsq/core/errz"
|
||
|
"github.com/neilotoole/sq/libsq/core/ioz"
|
||
|
"github.com/neilotoole/sq/libsq/core/lg"
|
||
|
"github.com/neilotoole/sq/libsq/core/lg/lga"
|
||
|
"github.com/neilotoole/sq/libsq/core/loz"
|
||
|
)
|
||
|
|
||
|
// Opt is an option that can be passed to [NewClient] to
// configure the client. Each Opt is applied to an *http.Transport.
type Opt interface {
	apply(*http.Transport)
}

var _ Opt = (*OptInsecureSkipVerify)(nil)

// OptInsecureSkipVerify is an Opt that can be passed to NewClient that,
// when true, disables TLS verification.
type OptInsecureSkipVerify bool

// apply sets InsecureSkipVerify on tr's TLS client config.
//
// If tr has no TLS client config, a new one is created; a zero MinVersion
// means the crypto/tls defaults apply. Otherwise the existing config is
// cloned before being modified, so the *tls.Config that tr originally
// pointed to is left untouched.
func (b OptInsecureSkipVerify) apply(tr *http.Transport) {
	if tr.TLSClientConfig == nil {
		tr.TLSClientConfig = &tls.Config{} //nolint:gosec // zero MinVersion: crypto/tls defaults apply.
	} else {
		tr.TLSClientConfig = tr.TLSClientConfig.Clone()
	}
	// Deliberate, caller-requested opt-out of certificate verification.
	tr.TLSClientConfig.InsecureSkipVerify = bool(b) //nolint:gosec
}
|
||
|
|
||
|
var _ Opt = (*minTLSVersion)(nil)
|
||
|
|
||
|
type minTLSVersion uint16
|
||
|
|
||
|
func (v minTLSVersion) apply(tr *http.Transport) {
|
||
|
if tr.TLSClientConfig == nil {
|
||
|
// We allow tls.VersionTLS10, even though it's not considered
|
||
|
// secure these days. Ultimately this could become a config
|
||
|
// option.
|
||
|
tr.TLSClientConfig = &tls.Config{MinVersion: uint16(v)} //nolint:gosec
|
||
|
} else {
|
||
|
tr.TLSClientConfig = tr.TLSClientConfig.Clone()
|
||
|
tr.TLSClientConfig.MinVersion = uint16(v)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// DefaultTLSVersion is the default minimum TLS version,
|
||
|
// as used by [NewDefaultClient].
|
||
|
var DefaultTLSVersion = minTLSVersion(tls.VersionTLS10)
|
||
|
|
||
|
// OptUserAgent is passed to [NewClient] to set the User-Agent header.
|
||
|
func OptUserAgent(ua string) TripFunc {
|
||
|
return func(next http.RoundTripper, req *http.Request) (*http.Response, error) {
|
||
|
req.Header.Set("User-Agent", ua)
|
||
|
return next.RoundTrip(req)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// DefaultUserAgent is the default User-Agent header value,
|
||
|
// as used by [NewDefaultClient].
|
||
|
var DefaultUserAgent = OptUserAgent(buildinfo.Get().UserAgent())
|
||
|
|
||
|
// OptResponseTimeout is passed to [NewClient] to set the total request timeout,
|
||
|
// including reading the body. This is basically the same as a traditional
|
||
|
// request timeout via context.WithTimeout. If timeout is zero, this is no-op.
|
||
|
//
|
||
|
// Contrast with [OptRequestTimeout].
|
||
|
func OptResponseTimeout(timeout time.Duration) TripFunc {
|
||
|
if timeout <= 0 {
|
||
|
return NopTripFunc
|
||
|
}
|
||
|
|
||
|
return func(next http.RoundTripper, req *http.Request) (*http.Response, error) {
|
||
|
timeoutErr := errz.Wrapf(context.DeadlineExceeded,
|
||
|
"http request not completed within %s timeout", timeout)
|
||
|
ctx, cancelFn := context.WithTimeoutCause(req.Context(), timeout, timeoutErr)
|
||
|
|
||
|
resp, err := next.RoundTrip(req.WithContext(ctx))
|
||
|
if err == nil {
|
||
|
if resp.Body == nil {
|
||
|
// Shouldn't happen, but just in case.
|
||
|
cancelFn()
|
||
|
} else {
|
||
|
// Wrap resp.Body with a ReadCloserNotifier, so that cancelFn
|
||
|
// is called when the body is closed.
|
||
|
resp.Body = ioz.ReadCloserNotifier(resp.Body, func(err error) {
|
||
|
if errors.Is(context.Cause(ctx), timeoutErr) {
|
||
|
lg.FromContext(ctx).Warn("HTTP request not completed within timeout",
|
||
|
lga.Timeout, timeout, lga.URL, req.URL.String())
|
||
|
}
|
||
|
|
||
|
cancelFn()
|
||
|
})
|
||
|
}
|
||
|
return resp, nil
|
||
|
}
|
||
|
|
||
|
// We've got an error. It may or may not be our timeout error.
|
||
|
// Either which way, we need to cancel the context.
|
||
|
defer cancelFn()
|
||
|
|
||
|
if errors.Is(context.Cause(ctx), timeoutErr) {
|
||
|
// If it is our timeout error, we log it.
|
||
|
|
||
|
lg.FromContext(ctx).Warn("HTTP request not completed within timeout",
|
||
|
lga.Timeout, timeout, lga.URL, req.URL.String())
|
||
|
}
|
||
|
|
||
|
return resp, err
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// OptRequestTimeout is passed to [NewClient] to set a timeout for just
|
||
|
// getting the initial response headers. This is useful if you expect
|
||
|
// a response within, say, 2 seconds, but you expect the body to take longer
|
||
|
// to read.
|
||
|
//
|
||
|
// Contrast with [OptResponseTimeout].
|
||
|
func OptRequestTimeout(timeout time.Duration) TripFunc {
|
||
|
if timeout <= 0 {
|
||
|
return NopTripFunc
|
||
|
}
|
||
|
return func(next http.RoundTripper, req *http.Request) (*http.Response, error) {
|
||
|
timerCancelCh := make(chan struct{})
|
||
|
|
||
|
ctx, cancelFn := context.WithCancelCause(req.Context())
|
||
|
t := time.NewTimer(timeout)
|
||
|
go func() {
|
||
|
defer t.Stop()
|
||
|
select {
|
||
|
case <-ctx.Done():
|
||
|
case <-t.C:
|
||
|
cancelErr := errz.Wrapf(context.DeadlineExceeded,
|
||
|
"http response header not received within %s timeout", timeout)
|
||
|
|
||
|
lg.FromContext(ctx).Warn("HTTP header response not received within timeout",
|
||
|
lga.Timeout, timeout, lga.URL, req.URL.String())
|
||
|
|
||
|
cancelFn(cancelErr)
|
||
|
case <-timerCancelCh:
|
||
|
// Stop the timer goroutine.
|
||
|
}
|
||
|
}()
|
||
|
|
||
|
resp, err := errz.Return(next.RoundTrip(req.WithContext(ctx)))
|
||
|
|
||
|
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||
|
if loz.Take(ctx.Done()) {
|
||
|
// The lower-down RoundTripper probably returned ctx.Err(),
|
||
|
// not context.Cause(), so we swap it around here.
|
||
|
if cause := context.Cause(ctx); cause != nil {
|
||
|
err = cause
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Signal completion of the timer goroutine (it may have already completed).
|
||
|
close(timerCancelCh)
|
||
|
|
||
|
// Don't leak resources; ensure that cancelFn is eventually called.
|
||
|
switch {
|
||
|
case err != nil:
|
||
|
// An error has occurred. It's probable that cancelFn has already been
|
||
|
// called by the timer goroutine, but we call it again just in case.
|
||
|
cancelFn(context.DeadlineExceeded)
|
||
|
case resp != nil && resp.Body != nil:
|
||
|
// Wrap resp.Body with a ReadCloserNotifier, so that cancelFn
|
||
|
// is called when the body is closed.
|
||
|
resp.Body = ioz.ReadCloserNotifier(resp.Body,
|
||
|
func(error) { cancelFn(context.DeadlineExceeded) })
|
||
|
default:
|
||
|
// Not sure if this can actually be reached, but just in case.
|
||
|
cancelFn(context.DeadlineExceeded)
|
||
|
}
|
||
|
|
||
|
return resp, err
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// DefaultHeaderTimeout is the default header timeout as used
|
||
|
// by [NewDefaultClient].
|
||
|
var DefaultHeaderTimeout = OptRequestTimeout(time.Second * 5)
|
||
|
|
||
|
// OptRequestDelay is passed to [NewClient] to delay the request by the
|
||
|
// specified duration. This is useful for testing.
|
||
|
func OptRequestDelay(delay time.Duration) TripFunc {
|
||
|
if delay <= 0 {
|
||
|
return NopTripFunc
|
||
|
}
|
||
|
|
||
|
return func(next http.RoundTripper, req *http.Request) (*http.Response, error) {
|
||
|
ctx := req.Context()
|
||
|
log := lg.FromContext(ctx)
|
||
|
log.Debug("HTTP request delay: started", lga.Val, delay, lga.URL, req.URL.String())
|
||
|
t := time.NewTimer(delay)
|
||
|
defer t.Stop()
|
||
|
select {
|
||
|
case <-ctx.Done():
|
||
|
return nil, context.Cause(ctx)
|
||
|
case <-t.C:
|
||
|
// Continue below
|
||
|
}
|
||
|
|
||
|
log.Debug("HTTP request delay: done", lga.Val, delay, lga.URL, req.URL.String())
|
||
|
return next.RoundTrip(req)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// contextCause returns a TripFunc that extracts the context.Cause error
|
||
|
// from the request context, if any, and returns it as the error.
|
||
|
func contextCause() TripFunc {
|
||
|
return func(next http.RoundTripper, req *http.Request) (*http.Response, error) {
|
||
|
resp, err := next.RoundTrip(req)
|
||
|
if err != nil {
|
||
|
if cause := context.Cause(req.Context()); cause != nil {
|
||
|
err = cause
|
||
|
}
|
||
|
}
|
||
|
return resp, err
|
||
|
}
|
||
|
}
|