sq/libsq/core/ioz/httpz/opts.go

225 lines
6.7 KiB
Go
Raw Normal View History

package httpz
import (
"context"
"crypto/tls"
"errors"
"net/http"
"time"
"github.com/neilotoole/sq/cli/buildinfo"
"github.com/neilotoole/sq/libsq/core/errz"
"github.com/neilotoole/sq/libsq/core/ioz"
"github.com/neilotoole/sq/libsq/core/lg"
"github.com/neilotoole/sq/libsq/core/lg/lga"
"github.com/neilotoole/sq/libsq/core/loz"
)
// Opt is a transport-level option accepted by [NewClient]. Each
// implementation configures the client by adjusting the supplied
// [http.Transport] in its apply method.
type Opt interface {
	// apply configures tr according to this option.
	apply(tr *http.Transport)
}
// Compile-time assertion that *OptInsecureSkipVerify implements Opt.
var _ Opt = (*OptInsecureSkipVerify)(nil)
// OptInsecureSkipVerify is an Opt that can be passed to NewClient that,
// when true, disables TLS verification.
type OptInsecureSkipVerify bool

// apply sets InsecureSkipVerify on the transport's TLS client config.
//
// If tr has no TLS config yet, a fresh one is created, so that apply does
// not dereference a nil pointer. If tr already has a config, it is cloned
// before being modified, so that a *tls.Config shared with other
// transports is not mutated.
func (b OptInsecureSkipVerify) apply(tr *http.Transport) {
	if tr.TLSClientConfig == nil {
		// MinVersion is deliberately left unset here; it is not this
		// option's concern.
		tr.TLSClientConfig = &tls.Config{} //nolint:gosec
	} else {
		tr.TLSClientConfig = tr.TLSClientConfig.Clone()
	}

	tr.TLSClientConfig.InsecureSkipVerify = bool(b)
}
// Compile-time assertion that *minTLSVersion implements Opt.
var _ Opt = (*minTLSVersion)(nil)
// minTLSVersion is an Opt that sets the minimum TLS version that the
// transport will negotiate. The value is a TLS version constant from
// crypto/tls, such as tls.VersionTLS12.
type minTLSVersion uint16

// apply installs a TLS client config on tr whose MinVersion is v. An
// existing config is cloned first, so the config originally attached to
// tr is left unmodified; when there is none, a new config is created.
func (v minTLSVersion) apply(tr *http.Transport) {
	var cfg *tls.Config
	if existing := tr.TLSClientConfig; existing != nil {
		cfg = existing.Clone()
	} else {
		cfg = new(tls.Config)
	}

	// Note that tls.VersionTLS10 is permitted, even though it's no longer
	// considered secure. Ultimately this could become a config option.
	cfg.MinVersion = uint16(v) //nolint:gosec
	tr.TLSClientConfig = cfg
}
// DefaultTLSVersion is the default minimum TLS version,
// as used by [NewDefaultClient]. It is set to TLS 1.0
// (tls.VersionTLS10), which is permitted even though it
// is not considered secure these days.
var DefaultTLSVersion = minTLSVersion(tls.VersionTLS10)
// OptUserAgent is passed to [NewClient] to set the User-Agent header
// of each outgoing request to ua, replacing any existing value.
//
// The header is set on a clone of the request, not on req itself: the
// [http.RoundTripper] contract states that RoundTrip should not modify
// the caller's request, and a request that is shared or retried could
// otherwise have its headers changed underneath it.
func OptUserAgent(ua string) TripFunc {
	return func(next http.RoundTripper, req *http.Request) (*http.Response, error) {
		clone := req.Clone(req.Context())
		if clone.Header == nil {
			// Clone preserves a nil Header, and Header.Set would panic
			// writing to a nil map.
			clone.Header = make(http.Header)
		}

		clone.Header.Set("User-Agent", ua)
		return next.RoundTrip(clone)
	}
}
// DefaultUserAgent is the TripFunc that sets the default User-Agent
// header value, as used by [NewDefaultClient]. The header value is
// obtained from buildinfo.Get().UserAgent(), evaluated once when the
// package is initialized.
var DefaultUserAgent = OptUserAgent(buildinfo.Get().UserAgent())
// OptResponseTimeout is passed to [NewClient] to set the total request timeout,
// including reading the body. This is basically the same as a traditional
// request timeout via context.WithTimeout. If timeout is zero or negative,
// this is no-op (NopTripFunc is returned).
//
// When the timeout expires, the request's context is canceled with a cause
// that wraps context.DeadlineExceeded, and a warning is logged.
//
// Contrast with [OptRequestTimeout].
func OptResponseTimeout(timeout time.Duration) TripFunc {
	if timeout <= 0 {
		return NopTripFunc
	}

	return func(next http.RoundTripper, req *http.Request) (*http.Response, error) {
		timeoutErr := errz.Wrapf(context.DeadlineExceeded,
			"http request not completed within %s timeout", timeout)
		ctx, cancelFn := context.WithTimeoutCause(req.Context(), timeout, timeoutErr)

		// finish logs a warning if ctx was terminated by our timeout, and
		// then releases ctx's resources. The cause must be inspected before
		// cancelFn is invoked, because cancelFn may itself set the cause.
		finish := func() {
			if errors.Is(context.Cause(ctx), timeoutErr) {
				lg.FromContext(ctx).Warn("HTTP request not completed within timeout",
					lga.Timeout, timeout, lga.URL, req.URL.String())
			}
			cancelFn()
		}

		resp, err := next.RoundTrip(req.WithContext(ctx))
		if err != nil {
			// We've got an error. It may or may not be our timeout error.
			// Either which way, we need to cancel the context.
			finish()
			return resp, err
		}

		if resp == nil || resp.Body == nil {
			// Shouldn't happen, but just in case: a misbehaving RoundTripper
			// could return a nil response (or nil body) alongside a nil error.
			// There's no body whose Close could trigger the cancel, so we
			// cancel right away rather than dereference nil or leak ctx.
			cancelFn()
			return resp, nil
		}

		// Wrap resp.Body with a ReadCloserNotifier, so that the context
		// is canceled when the body is closed.
		resp.Body = ioz.ReadCloserNotifier(resp.Body, func(error) { finish() })
		return resp, nil
	}
}
// OptRequestTimeout is passed to [NewClient] to limit how long the client
// waits for the initial response, i.e. for the underlying RoundTrip call to
// return. The time taken to subsequently read the body is not limited. This
// suits the case where a response is expected within, say, 2 seconds, but
// the body is expected to take longer to read. If timeout is zero or
// negative, NopTripFunc is returned.
//
// Contrast with [OptResponseTimeout].
func OptRequestTimeout(timeout time.Duration) TripFunc {
	if timeout <= 0 {
		return NopTripFunc
	}

	return func(next http.RoundTripper, req *http.Request) (*http.Response, error) {
		ctx, cancel := context.WithCancelCause(req.Context())
		roundTripDone := make(chan struct{})
		timer := time.NewTimer(timeout)

		// Watchdog goroutine: if the timer fires before RoundTrip returns
		// (or before ctx is otherwise done), cancel ctx with a timeout cause.
		go func() {
			defer timer.Stop()

			select {
			case <-timer.C:
				cause := errz.Wrapf(context.DeadlineExceeded,
					"http response header not received within %s timeout", timeout)

				lg.FromContext(ctx).Warn("HTTP header response not received within timeout",
					lga.Timeout, timeout, lga.URL, req.URL.String())

				cancel(cause)
			case <-roundTripDone:
				// RoundTrip has returned; the watchdog is no longer needed.
			case <-ctx.Done():
			}
		}()

		resp, err := errz.Return(next.RoundTrip(req.WithContext(ctx)))

		// The downstream RoundTripper probably reports ctx.Err() instead of
		// context.Cause(), so if ctx is done, substitute the cause.
		isCtxErr := errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded)
		if isCtxErr && loz.Take(ctx.Done()) {
			if cause := context.Cause(ctx); cause != nil {
				err = cause
			}
		}

		// Tell the watchdog to exit (it may have done so already).
		close(roundTripDone)

		// To avoid leaking resources, cancel must eventually be invoked.
		if err == nil && resp != nil && resp.Body != nil {
			// Defer the cancel until the body is closed, by wrapping
			// resp.Body with a ReadCloserNotifier.
			resp.Body = ioz.ReadCloserNotifier(resp.Body,
				func(error) { cancel(context.DeadlineExceeded) })
			return resp, err
		}

		// Either an error occurred (in which case the watchdog has likely
		// canceled already, but we make sure), or there's no body to wrap.
		cancel(context.DeadlineExceeded)
		return resp, err
	}
}
// DefaultHeaderTimeout is the default header timeout as used
// by [NewDefaultClient]: an [OptRequestTimeout] of 5 seconds.
var DefaultHeaderTimeout = OptRequestTimeout(time.Second * 5)
// OptRequestDelay is passed to [NewClient] to hold each request back for
// the given duration before it is sent onwards. This is useful for testing.
// If delay is zero or negative, NopTripFunc is returned.
//
// If the request's context is done before the delay elapses, the request
// is not sent, and the context's cause is returned as the error.
func OptRequestDelay(delay time.Duration) TripFunc {
	if delay <= 0 {
		return NopTripFunc
	}

	return func(next http.RoundTripper, req *http.Request) (*http.Response, error) {
		ctx := req.Context()
		logger := lg.FromContext(ctx)
		logger.Debug("HTTP request delay: started", lga.Val, delay, lga.URL, req.URL.String())

		timer := time.NewTimer(delay)
		defer timer.Stop()

		select {
		case <-timer.C:
			// The delay has elapsed; pass the request along.
			logger.Debug("HTTP request delay: done", lga.Val, delay, lga.URL, req.URL.String())
			return next.RoundTrip(req)
		case <-ctx.Done():
			return nil, context.Cause(ctx)
		}
	}
}
// contextCause returns a TripFunc that, when the downstream RoundTrip
// fails, replaces the returned error with the context.Cause of the
// request's context, provided that cause is non-nil. Successful round
// trips, and failures where the context has no cause, pass through
// unchanged.
func contextCause() TripFunc {
	return func(next http.RoundTripper, req *http.Request) (*http.Response, error) {
		resp, err := next.RoundTrip(req)
		if err == nil {
			return resp, nil
		}

		if cause := context.Cause(req.Context()); cause != nil {
			return resp, cause
		}

		return resp, err
	}
}