Enforce non-concurrent token usage (#1048)

Co-authored-by: Azeem Shaikh <azeems@google.com>
This commit is contained in:
Azeem Shaikh 2021-09-21 17:52:13 -07:00 committed by GitHub
parent 5fb87cb0de
commit 14dc32f946
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 32 additions and 12 deletions

View File

@ -149,5 +149,4 @@ linters-settings:
- paramTypeCombine
- ptrToRefParam
- typeUnparen
- unnamedResult
- unnecessaryBlock

View File

@ -19,6 +19,7 @@ import (
"net/http"
"strconv"
"sync/atomic"
"time"
"go.opencensus.io/stats"
"go.opencensus.io/tag"
@ -26,6 +27,8 @@ import (
githubstats "github.com/ossf/scorecard/v2/clients/githubrepo/stats"
)
const expiryTimeInSec = 30
// MakeGitHubTransport wraps input RoundTripper with GitHub authorization logic.
func MakeGitHubTransport(innerTransport http.RoundTripper, accessTokens []string) http.RoundTripper {
return &githubTransport{
@ -41,20 +44,27 @@ type githubTransport struct {
}
type tokenAccessor interface {
next(r *http.Request) (string, error)
next() (uint64, string)
release(uint64)
}
func (gt *githubTransport) RoundTrip(r *http.Request) (*http.Response, error) {
token, err := gt.tokens.next(r)
index, token := gt.tokens.next()
defer gt.tokens.release(index)
ctx, err := tag.New(r.Context(), tag.Upsert(githubstats.TokenIndex, fmt.Sprint(index)))
if err != nil {
return nil, fmt.Errorf("error getting Github token: %w", err)
return nil, fmt.Errorf("error updating context: %w", err)
}
*r = *r.WithContext(ctx)
r.Header.Add("Authorization", fmt.Sprintf("Bearer %s", token))
resp, err := gt.innerTransport.RoundTrip(r)
if err != nil {
return nil, fmt.Errorf("error in HTTP: %w", err)
}
ctx, err := tag.New(r.Context(), tag.Upsert(githubstats.ResourceType, resp.Header.Get("X-RateLimit-Resource")))
ctx, err = tag.New(r.Context(), tag.Upsert(githubstats.ResourceType, resp.Header.Get("X-RateLimit-Resource")))
if err != nil {
return nil, fmt.Errorf("error updating context: %w", err)
}
@ -68,24 +78,35 @@ func (gt *githubTransport) RoundTrip(r *http.Request) (*http.Response, error) {
func makeTokenAccessor(accessTokens []string) tokenAccessor {
return &roundRobinAccessor{
accessTokens: accessTokens,
accessState: make([]int64, len(accessTokens)),
}
}
type roundRobinAccessor struct {
accessTokens []string
accessState []int64
counter uint64
}
func (roundRobin *roundRobinAccessor) next(r *http.Request) (string, error) {
func (roundRobin *roundRobinAccessor) next() (uint64, string) {
c := atomic.AddUint64(&roundRobin.counter, 1)
l := len(roundRobin.accessTokens)
index := c % uint64(l)
ctx, err := tag.New(r.Context(), tag.Upsert(githubstats.TokenIndex, fmt.Sprint(index)))
if err != nil {
return "", fmt.Errorf("error updating context: %w", err)
// If selected accessToken is unavailable, wait.
for !atomic.CompareAndSwapInt64(&roundRobin.accessState[index], 0, time.Now().Unix()) {
currVal := roundRobin.accessState[index]
expired := time.Now().After(time.Unix(currVal, 0).Add(expiryTimeInSec * time.Second))
if !expired {
continue
}
if atomic.CompareAndSwapInt64(&roundRobin.accessState[index], currVal, time.Now().Unix()) {
break
}
}
*r = *r.WithContext(ctx)
return roundRobin.accessTokens[index], nil
return index, roundRobin.accessTokens[index]
}
func (roundRobin *roundRobinAccessor) release(index uint64) {
atomic.SwapInt64(&roundRobin.accessState[index], 0)
}