package downloader

import (
	"bufio"
	"bytes"
	"errors"
	"log/slog"
	"net/http"
	"strings"
	"time"

	"github.com/neilotoole/sq/libsq/core/lg/lga"
)

// errNoDateHeader indicates that the HTTP headers contained no Date header.
var errNoDateHeader = errors.New("no Date header")

// getDate parses and returns the value of the Date header.
func getDate(respHeaders http.Header) (date time.Time, err error) {
	dateHeader := respHeaders.Get("date")
	if dateHeader == "" {
		err = errNoDateHeader
		return date, err
	}

	return time.Parse(time.RFC1123, dateHeader)
}

type realClock struct{}

func (c *realClock) since(d time.Time) time.Duration {
	return time.Since(d)
}

type timer interface {
	since(d time.Time) time.Duration
}

var clock timer = &realClock{}

// getFreshness will return one of Fresh/Stale/Transparent based on the cache-control
// values of the request and the response.
//
// Fresh indicates the response can be returned.
// Stale indicates that the response needs validating before it is returned.
// Transparent indicates the response should not be used to fulfil the request.
//
// Because this is only a private cache, 'public' and 'private' in cache-control aren't
// significant. Similarly, s-maxage isn't used.
func getFreshness(respHeaders, reqHeaders http.Header) (freshness State) {
	respCacheControl := parseCacheControl(respHeaders)
	reqCacheControl := parseCacheControl(reqHeaders)
	if _, ok := reqCacheControl["no-cache"]; ok {
		return Transparent
	}
	if _, ok := respCacheControl["no-cache"]; ok {
		return Stale
	}
	if _, ok := reqCacheControl["only-if-cached"]; ok {
		return Fresh
	}

	date, err := getDate(respHeaders)
	if err != nil {
		return Stale
	}
	currentAge := clock.since(date)

	var lifetime time.Duration
	var zeroDuration time.Duration

	// If a response includes both an Expires header and a max-age directive,
	// the max-age directive overrides the Expires header, even if the Expires header is more restrictive.
	if maxAge, ok := respCacheControl["max-age"]; ok {
		lifetime, err = time.ParseDuration(maxAge + "s")
		if err != nil {
			lifetime = zeroDuration
		}
	} else {
		expiresHeader := respHeaders.Get("Expires")
		if expiresHeader != "" {
			var expires time.Time
			expires, err = time.Parse(time.RFC1123, expiresHeader)
			if err != nil {
				lifetime = zeroDuration
			} else {
				lifetime = expires.Sub(date)
			}
		}
	}

	if maxAge, ok := reqCacheControl["max-age"]; ok {
		// The client is willing to accept a response whose age is no greater
		// than the specified time in seconds.
		lifetime, err = time.ParseDuration(maxAge + "s")
		if err != nil {
			lifetime = zeroDuration
		}
	}
	if minFresh, ok := reqCacheControl["min-fresh"]; ok {
		// The client wants a response that will still be fresh for at least
		// the specified number of seconds.
		var minFreshDuration time.Duration
		minFreshDuration, err = time.ParseDuration(minFresh + "s")
		if err == nil {
			currentAge += minFreshDuration
		}
	}

	if maxStale, ok := reqCacheControl["max-stale"]; ok {
		// Indicates that the client is willing to accept a response that has exceeded its expiration time.
		// If max-stale is assigned a value, then the client is willing to accept a response that has exceeded
		// its expiration time by no more than the specified number of seconds.
		// If no value is assigned to max-stale, then the client is willing to accept a stale response of any age.
		//
		// Responses served only because of a max-stale value are supposed to have a Warning header added to them,
		// but that seems like a hassle, and is it actually useful? If so, then there needs to be a different
		// return-value available here.
		if maxStale == "" {
			return Fresh
		}
		maxStaleDuration, err := time.ParseDuration(maxStale + "s")
		if err == nil {
			currentAge -= maxStaleDuration
		}
	}

	if lifetime > currentAge {
		return Fresh
	}

	return Stale
}

// canStaleOnError returns true if either the request or the response includes
// the stale-if-error cache control extension: https://tools.ietf.org/html/rfc5861
func canStaleOnError(respHeaders, reqHeaders http.Header) bool {
	respCacheControl := parseCacheControl(respHeaders)
	reqCacheControl := parseCacheControl(reqHeaders)

	var err error
	lifetime := time.Duration(-1)

	if staleMaxAge, ok := respCacheControl["stale-if-error"]; ok {
		if staleMaxAge == "" {
			return true
		}
		lifetime, err = time.ParseDuration(staleMaxAge + "s")
		if err != nil {
			return false
		}
	}
	if staleMaxAge, ok := reqCacheControl["stale-if-error"]; ok {
		if staleMaxAge == "" {
			return true
		}
		lifetime, err = time.ParseDuration(staleMaxAge + "s")
		if err != nil {
			return false
		}
	}

	if lifetime >= 0 {
		date, err := getDate(respHeaders)
		if err != nil {
			return false
		}
		currentAge := clock.since(date)
		if lifetime > currentAge {
			return true
		}
	}

	return false
}

func getEndToEndHeaders(respHeaders http.Header) []string {
	// These headers are always hop-by-hop.
	hopByHopHeaders := map[string]struct{}{
		"Connection":          {},
		"Keep-Alive":          {},
		"Proxy-Authenticate":  {},
		"Proxy-Authorization": {},
		"Te":                  {},
		"Trailers":            {},
		"Transfer-Encoding":   {},
		"Upgrade":             {},
	}

	for _, extra := range strings.Split(respHeaders.Get("connection"), ",") {
		// Any header listed in Connection, if present, is also considered hop-by-hop.
		if trimmed := strings.TrimSpace(extra); trimmed != "" {
			hopByHopHeaders[http.CanonicalHeaderKey(trimmed)] = struct{}{}
		}
	}

	endToEndHeaders := []string{}
	for respHeader := range respHeaders {
		if _, ok := hopByHopHeaders[respHeader]; !ok {
			endToEndHeaders = append(endToEndHeaders, respHeader)
		}
	}
	return endToEndHeaders
}

func canStore(reqCacheControl, respCacheControl cacheControl) (canStore bool) {
	if _, ok := respCacheControl["no-store"]; ok {
		return false
	}
	if _, ok := reqCacheControl["no-store"]; ok {
		return false
	}
	return true
}

func newGatewayTimeoutResponse(req *http.Request) *http.Response {
	var braw bytes.Buffer
	braw.WriteString("HTTP/1.1 504 Gateway Timeout\r\n\r\n")
	resp, err := http.ReadResponse(bufio.NewReader(&braw), req)
	if err != nil {
		panic(err)
	}
	return resp
}

// cloneRequest returns a clone of the provided *http.Request.
// The clone is a shallow copy of the struct and its Header map.
// (This function copyright goauth2 authors: https://code.google.com/p/goauth2).
func cloneRequest(r *http.Request) *http.Request {
	// Shallow copy of the struct.
	r2 := new(http.Request)
	*r2 = *r
	if ctx := r.Context(); ctx != nil {
		r2 = r2.WithContext(ctx)
	}
	// Copy of the Header map.
	r2.Header = make(http.Header)
	for k, s := range r.Header {
		r2.Header[k] = s
	}
	return r2
}

type cacheControl map[string]string

func parseCacheControl(headers http.Header) cacheControl {
	cc := cacheControl{}
	ccHeader := headers.Get("Cache-Control")
	for _, part := range strings.Split(ccHeader, ",") {
		part = strings.Trim(part, " ")
		if part == "" {
			continue
		}
		if strings.ContainsRune(part, '=') {
			keyval := strings.Split(part, "=")
			cc[strings.Trim(keyval[0], " ")] = strings.Trim(keyval[1], ",")
		} else {
			cc[part] = ""
		}
	}
	return cc
}

// headerAllCommaSepValues returns all comma-separated values (each
// with whitespace trimmed) for header name in headers.
// According to Section 4.2 of the HTTP/1.1 spec
// (http://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.2),
// values from multiple occurrences of a header should be concatenated, if
// the header's value is a comma-separated list.
func headerAllCommaSepValues(headers http.Header, name string) []string {
	var vals []string
	for _, val := range headers[http.CanonicalHeaderKey(name)] {
		fields := strings.Split(val, ",")
		for i, f := range fields {
			fields[i] = strings.TrimSpace(f)
		}
		vals = append(vals, fields...)
	}
	return vals
}

// varyMatches will return false unless all the cached values for the
// headers listed in Vary match the new request.
func varyMatches(cachedResp *http.Response, req *http.Request) bool {
	for _, header := range headerAllCommaSepValues(cachedResp.Header, "vary") {
		header = http.CanonicalHeaderKey(header)
		if header != "" && req.Header.Get(header) != cachedResp.Header.Get("X-Varied-"+header) {
			return false
		}
	}
	return true
}

// logResp logs the outcome of an HTTP request, logging a warning if err is
// non-nil, and an info-level completion message otherwise.
func logResp(log *slog.Logger, req *http.Request, resp *http.Response, elapsed time.Duration, err error) {
	if req != nil {
		log = log.With("response_time", elapsed, lga.Method, req.Method, lga.URL, req.URL.String())
	} else {
		log = log.With("response_time", elapsed, "no_request", true)
	}
	if err != nil {
		log.Warn("HTTP request error", lga.Err, err)
		return
	}
	log.Info("HTTP request completed", lga.Resp, resp)
}
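
// Note: the helpers above are consumed by the cache transport elsewhere in
// this package. As an illustrative sketch only (the cachedResp and req
// variables are assumed, not defined in this file), a cache lookup combines
// them roughly like this:
//
//	if varyMatches(cachedResp, req) {
//		switch getFreshness(cachedResp.Header, req.Header) {
//		case Fresh:
//			// Serve cachedResp directly.
//		case Stale:
//			// Revalidate with the origin; if that fails and
//			// canStaleOnError(cachedResp.Header, req.Header) is true,
//			// cachedResp may still be served.
//		case Transparent:
//			// Ignore the cache and forward the request unchanged.
//		}
//	}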