scorecard/roundtripper/roundtripper.go

128 lines
2.8 KiB
Go
Raw Normal View History

2020-10-09 17:47:59 +03:00
package roundtripper
import (
"bytes"
"context"
"io/ioutil"
"net/http"
"net/url"
"os"
"strconv"
"sync"
2020-10-09 17:47:59 +03:00
"time"
2020-10-13 19:29:29 +03:00
"go.uber.org/zap"
2020-10-09 17:47:59 +03:00
"golang.org/x/oauth2"
)
// GITHUB_AUTH_TOKEN is the name of the environment variable from which
// NewTransport reads the GitHub OAuth access token. When the variable is unset
// or empty, NewTransport uses http.DefaultTransport without a token.
const GITHUB_AUTH_TOKEN = "GITHUB_AUTH_TOKEN"
// RateLimitRoundTripper is a rate-limit aware http.Transport for Github.
type RateLimitRoundTripper struct {
2020-10-13 19:29:29 +03:00
Logger *zap.SugaredLogger
2020-10-09 17:47:59 +03:00
InnerTransport http.RoundTripper
}
// NewTransport builds the http.RoundTripper used for GitHub API calls. The
// returned chain, from outermost to innermost, is:
//
//	CachingRoundTripper -> RateLimitRoundTripper -> base transport
//
// The base transport is http.DefaultTransport, or an OAuth2 transport that
// carries the token from the GITHUB_AUTH_TOKEN environment variable when that
// variable is non-empty. ctx is passed to oauth2.NewClient when building the
// OAuth2 transport.
func NewTransport(ctx context.Context, logger *zap.SugaredLogger) http.RoundTripper {
	var base http.RoundTripper = http.DefaultTransport
	if token := os.Getenv(GITHUB_AUTH_TOKEN); token != "" {
		tokenSource := oauth2.StaticTokenSource(&oauth2.Token{AccessToken: token})
		base = oauth2.NewClient(ctx, tokenSource).Transport
	}

	limiter := &RateLimitRoundTripper{
		Logger:         logger,
		InnerTransport: base,
	}

	return &CachingRoundTripper{
		Logger:         logger,
		innerTransport: limiter,
		respCache:      make(map[url.URL]*http.Response),
		bodyCache:      make(map[url.URL][]byte),
	}
}
// Roundtrip handles caching and ratelimiting of responses from GitHub.
func (gh *RateLimitRoundTripper) RoundTrip(r *http.Request) (*http.Response, error) {
resp, err := gh.InnerTransport.RoundTrip(r)
if err != nil {
return nil, err
}
rateLimit := resp.Header.Get("X-RateLimit-Remaining")
remaining, err := strconv.Atoi(rateLimit)
if err != nil {
return resp, nil
}
if remaining <= 0 {
reset, err := strconv.Atoi(resp.Header.Get("X-RateLimit-Reset"))
if err != nil {
return resp, nil
}
duration := time.Until(time.Unix(int64(reset), 0))
2020-10-13 19:29:29 +03:00
gh.Logger.Debugf("Rate limit exceeded. Waiting %s to retry...", duration)
2020-10-09 17:47:59 +03:00
// Retry
time.Sleep(duration)
2020-10-13 19:29:29 +03:00
gh.Logger.Warnf("Rate limit exceeded. Retrying...")
2020-10-09 17:47:59 +03:00
return gh.RoundTrip(r)
}
return resp, err
}
type CachingRoundTripper struct {
innerTransport http.RoundTripper
respCache map[url.URL]*http.Response
bodyCache map[url.URL][]byte
mutex sync.Mutex
2020-10-13 19:29:29 +03:00
Logger *zap.SugaredLogger
2020-10-09 17:47:59 +03:00
}
func (rt *CachingRoundTripper) RoundTrip(r *http.Request) (*http.Response, error) {
// Check the cache
rt.mutex.Lock()
defer rt.mutex.Unlock()
2020-10-09 17:47:59 +03:00
resp, ok := rt.respCache[*r.URL]
2020-10-09 17:47:59 +03:00
if ok {
2020-10-13 19:29:29 +03:00
rt.Logger.Debugf("Cache hit on %s", r.URL.String())
2020-10-09 17:47:59 +03:00
resp.Body = ioutil.NopCloser(bytes.NewReader(rt.bodyCache[*r.URL]))
return resp, nil
}
// Get the real value
resp, err := rt.innerTransport.RoundTrip(r)
if err != nil {
return nil, err
}
// Add to cache
if resp.StatusCode == http.StatusOK {
defer resp.Body.Close()
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, err
}
rt.respCache[*r.URL] = resp
rt.bodyCache[*r.URL] = body
resp.Body = ioutil.NopCloser(bytes.NewReader(body))
}
return resp, err
}