Only pull required branch names (#1965)

Co-authored-by: Azeem Shaikh <azeems@google.com>
This commit is contained in:
Azeem Shaikh 2022-05-27 15:25:24 -07:00 committed by GitHub
parent 1471c807da
commit 70d045b9ef
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 359 additions and 282 deletions

View File

@ -33,11 +33,14 @@ func getBranchName(branch *clients.BranchRef) string {
return *branch.Name return *branch.Name
} }
func getBranch(branches []*clients.BranchRef, name string) *clients.BranchRef { func getBranch(branches []*clients.BranchRef, name string, isNonAdmin bool) *clients.BranchRef {
for _, branch := range branches { for _, branch := range branches {
branchName := getBranchName(branch) branchName := getBranchName(branch)
if branchName == name { if branchName == name {
return branch if !isNonAdmin {
return branch
}
return scrubBranch(branch)
} }
} }
return nil return nil
@ -49,14 +52,6 @@ func scrubBranch(branch *clients.BranchRef) *clients.BranchRef {
return ret return ret
} }
func scrubBranches(branches []*clients.BranchRef) []*clients.BranchRef {
ret := make([]*clients.BranchRef, len(branches))
for i, branch := range branches {
ret[i] = scrubBranch(branch)
}
return ret
}
func TestReleaseAndDevBranchProtected(t *testing.T) { func TestReleaseAndDevBranchProtected(t *testing.T) {
t.Parallel() t.Parallel()
@ -399,10 +394,7 @@ func TestReleaseAndDevBranchProtected(t *testing.T) {
mockRepoClient := mockrepo.NewMockRepoClient(ctrl) mockRepoClient := mockrepo.NewMockRepoClient(ctrl)
mockRepoClient.EXPECT().GetDefaultBranch(). mockRepoClient.EXPECT().GetDefaultBranch().
DoAndReturn(func() (*clients.BranchRef, error) { DoAndReturn(func() (*clients.BranchRef, error) {
defaultBranch := getBranch(tt.branches, tt.defaultBranch) defaultBranch := getBranch(tt.branches, tt.defaultBranch, tt.nonadmin)
if defaultBranch != nil && tt.nonadmin {
return scrubBranch(defaultBranch), nil
}
return defaultBranch, nil return defaultBranch, nil
}).AnyTimes() }).AnyTimes()
mockRepoClient.EXPECT().ListReleases(). mockRepoClient.EXPECT().ListReleases().
@ -415,12 +407,9 @@ func TestReleaseAndDevBranchProtected(t *testing.T) {
} }
return ret, nil return ret, nil
}).AnyTimes() }).AnyTimes()
mockRepoClient.EXPECT().ListBranches(). mockRepoClient.EXPECT().GetBranch(gomock.Any()).
DoAndReturn(func() ([]*clients.BranchRef, error) { DoAndReturn(func(b string) (*clients.BranchRef, error) {
if tt.nonadmin { return getBranch(tt.branches, b, tt.nonadmin), nil
return scrubBranches(tt.branches), nil
}
return tt.branches, nil
}).AnyTimes() }).AnyTimes()
dl := scut.TestDetailLogger{} dl := scut.TestDetailLogger{}
req := checker.CheckRequest{ req := checker.CheckRequest{

View File

@ -40,7 +40,6 @@ func SecurityPolicy(name string, dl checker.DetailLogger, r *checker.SecurityPol
} }
if msg.Type == checker.FileTypeURL { if msg.Type == checker.FileTypeURL {
msg.Text = "security policy detected in org repo" msg.Text = "security policy detected in org repo"
} else { } else {
msg.Text = "security policy detected in current repo" msg.Text = "security policy detected in current repo"
} }

View File

@ -15,37 +15,56 @@
package raw package raw
import ( import (
"errors"
"fmt" "fmt"
"regexp" "regexp"
"github.com/ossf/scorecard/v4/checker" "github.com/ossf/scorecard/v4/checker"
"github.com/ossf/scorecard/v4/clients" "github.com/ossf/scorecard/v4/clients"
sce "github.com/ossf/scorecard/v4/errors"
) )
const master = "master" const master = "master"
type branchMap map[string]*clients.BranchRef var commit = regexp.MustCompile("^[a-f0-9]{40}$")
type branchSet struct {
exists map[string]bool
set []clients.BranchRef
}
func (set *branchSet) add(branch *clients.BranchRef) bool {
if branch != nil &&
branch.Name != nil &&
*branch.Name != "" &&
!set.exists[*branch.Name] {
set.set = append(set.set, *branch)
set.exists[*branch.Name] = true
return true
}
return false
}
func (set branchSet) contains(branch string) bool {
_, contains := set.exists[branch]
return contains
}
// BranchProtection retrieves the raw data for the Branch-Protection check. // BranchProtection retrieves the raw data for the Branch-Protection check.
func BranchProtection(c clients.RepoClient) (checker.BranchProtectionsData, error) { func BranchProtection(c clients.RepoClient) (checker.BranchProtectionsData, error) {
// Checks branch protection on both release and development branch. branches := branchSet{
// Get all branches. This will include information on whether they are protected. exists: make(map[string]bool),
branches, err := c.ListBranches() }
// Add default branch.
defaultBranch, err := c.GetDefaultBranch()
if err != nil { if err != nil {
return checker.BranchProtectionsData{}, fmt.Errorf("%w", err) return checker.BranchProtectionsData{}, fmt.Errorf("%w", err)
} }
branchesMap := getBranchMapFrom(branches) branches.add(defaultBranch)
// Get release branches. // Get release branches.
releases, err := c.ListReleases() releases, err := c.ListReleases()
if err != nil { if err != nil {
return checker.BranchProtectionsData{}, fmt.Errorf("%w", err) return checker.BranchProtectionsData{}, fmt.Errorf("%w", err)
} }
commit := regexp.MustCompile("^[a-f0-9]{40}$")
checkBranches := make(map[string]bool)
for _, release := range releases { for _, release := range releases {
if release.TargetCommitish == "" { if release.TargetCommitish == "" {
// Log with a named error if target_commitish is nil. // Log with a named error if target_commitish is nil.
@ -57,78 +76,47 @@ func BranchProtection(c clients.RepoClient) (checker.BranchProtectionsData, erro
continue continue
} }
// Try to resolve the branch name. if branches.contains(release.TargetCommitish) ||
b, err := branchesMap.getBranchByName(release.TargetCommitish) branches.contains(branchRedirect(release.TargetCommitish)) {
if err != nil { continue
// If the commitish branch is still not found, fail.
return checker.BranchProtectionsData{}, err
} }
// Branch is valid, add to list of branches to check. // Get the associated release branch.
checkBranches[*b.Name] = true branchRef, err := c.GetBranch(release.TargetCommitish)
}
// Add default branch.
defaultBranch, err := c.GetDefaultBranch()
if err != nil {
return checker.BranchProtectionsData{}, fmt.Errorf("%w", err)
}
defaultBranchName := getBranchName(defaultBranch)
if defaultBranchName != "" {
checkBranches[defaultBranchName] = true
}
rawData := checker.BranchProtectionsData{}
// Check protections on all the branches.
for b := range checkBranches {
branch, err := branchesMap.getBranchByName(b)
if err != nil { if err != nil {
if errors.Is(err, errInternalBranchNotFound) { return checker.BranchProtectionsData{},
continue fmt.Errorf("error during GetBranch(%s): %w", release.TargetCommitish, err)
} }
return checker.BranchProtectionsData{}, err if branches.add(branchRef) {
continue
} }
rawData.Branches = append(rawData.Branches, *branch) // Couldn't find the branch check for redirects.
redirectBranch := branchRedirect(release.TargetCommitish)
if redirectBranch == "" {
continue
}
branchRef, err = c.GetBranch(redirectBranch)
if err != nil {
return checker.BranchProtectionsData{},
fmt.Errorf("error during GetBranch(%s) %w", redirectBranch, err)
}
branches.add(branchRef)
// Branch doesn't exist or was deleted. Continue.
} }
// No error, return the data. // No error, return the data.
return rawData, nil return checker.BranchProtectionsData{
Branches: branches.set,
}, nil
} }
func (b branchMap) getBranchByName(name string) (*clients.BranchRef, error) { func branchRedirect(name string) string {
val, exists := b[name]
if exists {
return val, nil
}
// Ideally, we should check using repositories.GetBranch if there was a branch redirect. // Ideally, we should check using repositories.GetBranch if there was a branch redirect.
// See https://github.com/google/go-github/issues/1895 // See https://github.com/google/go-github/issues/1895
// For now, handle the common master -> main redirect. // For now, handle the common master -> main redirect.
if name == master { if name == master {
val, exists := b["main"] return "main"
if exists {
return val, nil
}
} }
return nil, sce.WithMessage(sce.ErrScorecardInternal, return ""
fmt.Sprintf("could not find branch name %s: %v", name, errInternalBranchNotFound))
}
func getBranchMapFrom(branches []*clients.BranchRef) branchMap {
ret := make(branchMap)
for _, branch := range branches {
branchName := getBranchName(branch)
if branchName != "" {
ret[branchName] = branch
}
}
return ret
}
func getBranchName(branch *clients.BranchRef) string {
if branch == nil || branch.Name == nil {
return ""
}
return *branch.Name
} }

View File

@ -15,160 +15,255 @@
package raw package raw
import ( import (
"errors"
"testing" "testing"
"github.com/golang/mock/gomock"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"github.com/ossf/scorecard/v4/checker"
"github.com/ossf/scorecard/v4/clients" "github.com/ossf/scorecard/v4/clients"
mockrepo "github.com/ossf/scorecard/v4/clients/mockclients"
) )
var branch = "master" var (
errBPTest = errors.New("test error")
defaultBranchName = "default"
releaseBranchName = "release-branch"
mainBranchName = "main"
)
func Test_getBranchName(t *testing.T) { // nolint: govet
t.Parallel() type branchArg struct {
type args struct { err error
branch *clients.BranchRef name string
} branchRef *clients.BranchRef
tests := []struct { defaultBranch bool
name string
args args
want string
}{
{
name: "simple",
args: args{
branch: &clients.BranchRef{
Name: &branch,
},
},
want: master,
},
{
name: "empty name",
args: args{
branch: &clients.BranchRef{},
},
want: "",
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
if got := getBranchName(tt.args.branch); got != tt.want {
t.Errorf("getBranchName() = %v, want %v", got, tt.want)
}
})
}
} }
func Test_getBranchMapFrom(t *testing.T) { type branchesArg []branchArg
t.Parallel()
type args struct { func (ba branchesArg) getDefaultBranch() (*clients.BranchRef, error) {
branches []*clients.BranchRef for _, branch := range ba {
if branch.defaultBranch {
return branch.branchRef, branch.err
}
} }
//nolint return nil, nil
}
func (ba branchesArg) getBranch(b string) (*clients.BranchRef, error) {
for _, branch := range ba {
if branch.name == b {
return branch.branchRef, branch.err
}
}
return nil, nil
}
func TestBranchProtection(t *testing.T) {
t.Parallel()
// nolint: govet
tests := []struct { tests := []struct {
name string name string
args args branches branchesArg
want branchMap releases []clients.Release
releasesErr error
want checker.BranchProtectionsData
wantErr error
}{ }{
{ {
name: "simple", name: "default-branch-err",
args: args{ branches: branchesArg{
branches: []*clients.BranchRef{ {
{ name: defaultBranchName,
Name: &branch, err: errBPTest,
},
},
},
{
name: "null-default-branch-only",
branches: branchesArg{
{
name: defaultBranchName,
defaultBranch: true,
branchRef: nil,
},
},
},
{
name: "default-branch-only",
branches: branchesArg{
{
name: defaultBranchName,
defaultBranch: true,
branchRef: &clients.BranchRef{
Name: &defaultBranchName,
}, },
}, },
}, },
want: branchMap{ want: checker.BranchProtectionsData{
master: &clients.BranchRef{ Branches: []clients.BranchRef{
Name: &branch, {
Name: &defaultBranchName,
},
}, },
}, },
}, },
{
name: "list-releases-error",
releasesErr: errBPTest,
wantErr: errBPTest,
},
{
name: "no-releases",
},
{
name: "empty-targetcommitish",
releases: []clients.Release{
{
TargetCommitish: "",
},
},
wantErr: errInternalCommitishNil,
},
{
name: "release-branch-err",
releases: []clients.Release{
{
TargetCommitish: releaseBranchName,
},
},
branches: branchesArg{
{
name: releaseBranchName,
err: errBPTest,
},
},
wantErr: errBPTest,
},
{
name: "nil-release-branch",
releases: []clients.Release{
{
TargetCommitish: releaseBranchName,
},
},
branches: branchesArg{
{
name: releaseBranchName,
branchRef: nil,
},
},
},
{
name: "add-release-branch",
releases: []clients.Release{
{
TargetCommitish: releaseBranchName,
},
},
branches: branchesArg{
{
name: releaseBranchName,
branchRef: &clients.BranchRef{
Name: &releaseBranchName,
},
},
},
want: checker.BranchProtectionsData{
Branches: []clients.BranchRef{
{
Name: &releaseBranchName,
},
},
},
},
{
name: "master-to-main-redirect",
releases: []clients.Release{
{
TargetCommitish: "master",
},
},
branches: branchesArg{
{
name: mainBranchName,
branchRef: &clients.BranchRef{
Name: &mainBranchName,
},
},
},
want: checker.BranchProtectionsData{
Branches: []clients.BranchRef{
{
Name: &mainBranchName,
},
},
},
},
{
name: "default-and-release-branches",
releases: []clients.Release{
{
TargetCommitish: releaseBranchName,
},
},
branches: branchesArg{
{
name: defaultBranchName,
defaultBranch: true,
branchRef: &clients.BranchRef{
Name: &defaultBranchName,
},
},
{
name: releaseBranchName,
branchRef: &clients.BranchRef{
Name: &releaseBranchName,
},
},
},
want: checker.BranchProtectionsData{
Branches: []clients.BranchRef{
{
Name: &defaultBranchName,
},
{
Name: &releaseBranchName,
},
},
},
},
// TODO: Add tests for commitSHA regex matching.
} }
for _, tt := range tests { for _, tt := range tests {
tt := tt tt := tt
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
t.Parallel() t.Parallel()
if got := getBranchMapFrom(tt.args.branches); !cmp.Equal(got, tt.want) { ctrl := gomock.NewController(t)
t.Errorf("getBranchMapFrom() = %v, want %v", got, tt.want) mockRepoClient := mockrepo.NewMockRepoClient(ctrl)
} mockRepoClient.EXPECT().GetDefaultBranch().
}) AnyTimes().DoAndReturn(func() (*clients.BranchRef, error) {
} return tt.branches.getDefaultBranch()
} })
mockRepoClient.EXPECT().GetBranch(gomock.Any()).AnyTimes().
DoAndReturn(func(branch string) (*clients.BranchRef, error) {
return tt.branches.getBranch(branch)
})
mockRepoClient.EXPECT().ListReleases().AnyTimes().
DoAndReturn(func() ([]clients.Release, error) {
return tt.releases, tt.releasesErr
})
func Test_branchMap_getBranchByName(t *testing.T) { rawData, err := BranchProtection(mockRepoClient)
main := "main" if !errors.Is(err, tt.wantErr) {
t.Parallel() t.Errorf("failed. expected: %v, got: %v", tt.wantErr, err)
type args struct { t.Fail()
name string
}
//nolint
tests := []struct {
name string
b branchMap
args args
want *clients.BranchRef
wantErr bool
}{
{
name: "simple",
b: branchMap{
master: &clients.BranchRef{
Name: &branch,
},
},
args: args{
name: master,
},
want: &clients.BranchRef{
Name: &branch,
},
},
{
name: "main",
b: branchMap{
master: &clients.BranchRef{
Name: &main,
},
main: &clients.BranchRef{
Name: &main,
},
},
args: args{
name: "main",
},
want: &clients.BranchRef{
Name: &main,
},
},
{
name: "not found",
b: branchMap{
master: &clients.BranchRef{
Name: &branch,
},
},
args: args{
name: "not-found",
},
wantErr: true,
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got, err := tt.b.getBranchByName(tt.args.name)
if (err != nil) != tt.wantErr {
t.Errorf("branchMap.getBranchByName() error = %v, wantErr %v", err, tt.wantErr)
return
} }
if !cmp.Equal(got, tt.want) { if !cmp.Equal(rawData, tt.want) {
t.Errorf("branchMap.getBranchByName() = %v, want %v", got, tt.want) t.Errorf("failed. expected: %v, got: %v", tt.want, rawData)
t.Fail()
} }
}) })
} }

View File

@ -19,9 +19,8 @@ import (
) )
var ( var (
errInternalCommitishNil = errors.New("commitish is nil") errInternalCommitishNil = errors.New("commitish is nil")
errInternalBranchNotFound = errors.New("branch not found") errInvalidArgType = errors.New("invalid arg type")
errInvalidArgType = errors.New("invalid arg type") errInvalidArgLength = errors.New("invalid arg length")
errInvalidArgLength = errors.New("invalid arg length") errInvalidGitHubWorkflow = errors.New("invalid GitHub workflow")
errInvalidGitHubWorkflow = errors.New("invalid GitHub workflow")
) )

View File

@ -28,8 +28,7 @@ import (
) )
const ( const (
refsToAnalyze = 30 refPrefix = "refs/heads/"
refPrefix = "refs/heads/"
) )
// See https://github.community/t/graphql-api-protected-branch/14380 // See https://github.community/t/graphql-api-protected-branch/14380
@ -97,30 +96,30 @@ type branch struct {
RefUpdateRule *refUpdateRule RefUpdateRule *refUpdateRule
BranchProtectionRule *branchProtectionRule BranchProtectionRule *branchProtectionRule
} }
type defaultBranchData struct {
// nolint:govet // internal structure, ignore.
type branchesData struct {
Repository struct { Repository struct {
DefaultBranchRef branch DefaultBranchRef *branch
Refs struct {
Nodes []branch
} `graphql:"refs(first: $refsToAnalyze, refPrefix: $refPrefix)"`
} `graphql:"repository(owner: $owner, name: $name)"` } `graphql:"repository(owner: $owner, name: $name)"`
RateLimit struct { RateLimit struct {
Cost *int Cost *int
} }
} }
type branchData struct {
Repository struct {
Ref *branch `graphql:"ref(qualifiedName: $branchRefName)"`
} `graphql:"repository(owner: $owner, name: $name)"`
}
type branchesHandler struct { type branchesHandler struct {
ghClient *github.Client ghClient *github.Client
graphClient *githubv4.Client graphClient *githubv4.Client
data *branchesData data *defaultBranchData
once *sync.Once once *sync.Once
ctx context.Context ctx context.Context
errSetup error errSetup error
repourl *repoURL repourl *repoURL
defaultBranchRef *clients.BranchRef defaultBranchRef *clients.BranchRef
branches []*clients.BranchRef
} }
func (handler *branchesHandler) init(ctx context.Context, repourl *repoURL) { func (handler *branchesHandler) init(ctx context.Context, repourl *repoURL) {
@ -137,22 +136,35 @@ func (handler *branchesHandler) setup() error {
return return
} }
vars := map[string]interface{}{ vars := map[string]interface{}{
"owner": githubv4.String(handler.repourl.owner), "owner": githubv4.String(handler.repourl.owner),
"name": githubv4.String(handler.repourl.repo), "name": githubv4.String(handler.repourl.repo),
"refsToAnalyze": githubv4.Int(refsToAnalyze),
"refPrefix": githubv4.String(refPrefix),
} }
handler.data = new(branchesData) handler.data = new(defaultBranchData)
if err := handler.graphClient.Query(handler.ctx, handler.data, vars); err != nil { if err := handler.graphClient.Query(handler.ctx, handler.data, vars); err != nil {
handler.errSetup = sce.WithMessage(sce.ErrScorecardInternal, fmt.Sprintf("githubv4.Query: %v", err)) handler.errSetup = sce.WithMessage(sce.ErrScorecardInternal, fmt.Sprintf("githubv4.Query: %v", err))
return return
} }
handler.defaultBranchRef = getBranchRefFrom(handler.data.Repository.DefaultBranchRef) handler.defaultBranchRef = getBranchRefFrom(handler.data.Repository.DefaultBranchRef)
handler.branches = getBranchRefsFrom(handler.data.Repository.Refs.Nodes, handler.defaultBranchRef)
}) })
return handler.errSetup return handler.errSetup
} }
func (handler *branchesHandler) query(branchName string) (*clients.BranchRef, error) {
if !strings.EqualFold(handler.repourl.commitSHA, clients.HeadSHA) {
return nil, fmt.Errorf("%w: branches only supported for HEAD queries", clients.ErrUnsupportedFeature)
}
vars := map[string]interface{}{
"owner": githubv4.String(handler.repourl.owner),
"name": githubv4.String(handler.repourl.repo),
"branchRefName": githubv4.String(refPrefix + branchName),
}
queryData := new(branchData)
if err := handler.graphClient.Query(handler.ctx, queryData, vars); err != nil {
return nil, sce.WithMessage(sce.ErrScorecardInternal, fmt.Sprintf("githubv4.Query: %v", err))
}
return getBranchRefFrom(queryData.Repository.Ref), nil
}
func (handler *branchesHandler) getDefaultBranch() (*clients.BranchRef, error) { func (handler *branchesHandler) getDefaultBranch() (*clients.BranchRef, error) {
if err := handler.setup(); err != nil { if err := handler.setup(); err != nil {
return nil, fmt.Errorf("error during branchesHandler.setup: %w", err) return nil, fmt.Errorf("error during branchesHandler.setup: %w", err)
@ -160,11 +172,12 @@ func (handler *branchesHandler) getDefaultBranch() (*clients.BranchRef, error) {
return handler.defaultBranchRef, nil return handler.defaultBranchRef, nil
} }
func (handler *branchesHandler) listBranches() ([]*clients.BranchRef, error) { func (handler *branchesHandler) getBranch(branch string) (*clients.BranchRef, error) {
if err := handler.setup(); err != nil { branchRef, err := handler.query(branch)
return nil, fmt.Errorf("error during branchesHandler.setup: %w", err) if err != nil {
return nil, fmt.Errorf("error during branchesHandler.query: %w", err)
} }
return handler.branches, nil return branchRef, nil
} }
func copyAdminSettings(src *branchProtectionRule, dst *clients.BranchProtectionRule) { func copyAdminSettings(src *branchProtectionRule, dst *clients.BranchProtectionRule) {
@ -197,7 +210,10 @@ func copyNonAdminSettings(src interface{}, dst *clients.BranchProtectionRule) {
} }
} }
func getBranchRefFrom(data branch) *clients.BranchRef { func getBranchRefFrom(data *branch) *clients.BranchRef {
if data == nil {
return nil
}
branchRef := new(clients.BranchRef) branchRef := new(clients.BranchRef)
if data.Name != nil { if data.Name != nil {
branchRef.Name = data.Name branchRef.Name = data.Name
@ -238,18 +254,3 @@ func getBranchRefFrom(data branch) *clients.BranchRef {
return branchRef return branchRef
} }
func getBranchRefsFrom(data []branch, defaultBranch *clients.BranchRef) []*clients.BranchRef {
var branchRefs []*clients.BranchRef
var defaultFound bool
for i, b := range data {
branchRefs = append(branchRefs, getBranchRefFrom(b))
if defaultBranch != nil && branchRefs[i].Name == defaultBranch.Name {
defaultFound = true
}
}
if !defaultFound {
branchRefs = append(branchRefs, defaultBranch)
}
return branchRefs
}

View File

@ -30,7 +30,10 @@ import (
"github.com/ossf/scorecard/v4/log" "github.com/ossf/scorecard/v4/log"
) )
var errInputRepoType = errors.New("input repo should be of type repoURL") var (
_ clients.RepoClient = &Client{}
errInputRepoType = errors.New("input repo should be of type repoURL")
)
// Client is GitHub-specific implementation of RepoClient. // Client is GitHub-specific implementation of RepoClient.
type Client struct { type Client struct {
@ -149,9 +152,9 @@ func (client *Client) GetDefaultBranch() (*clients.BranchRef, error) {
return client.branches.getDefaultBranch() return client.branches.getDefaultBranch()
} }
// ListBranches implements RepoClient.ListBranches. // GetBranch implements RepoClient.GetBranch.
func (client *Client) ListBranches() ([]*clients.BranchRef, error) { func (client *Client) GetBranch(branch string) (*clients.BranchRef, error) {
return client.branches.listBranches() return client.branches.getBranch(branch)
} }
// ListWebhooks implements RepoClient.ListWebhooks. // ListWebhooks implements RepoClient.ListWebhooks.

View File

@ -31,7 +31,10 @@ import (
"github.com/ossf/scorecard/v4/log" "github.com/ossf/scorecard/v4/log"
) )
var errInputRepoType = errors.New("input repo should be of type repoLocal") var (
_ clients.RepoClient = &localDirClient{}
errInputRepoType = errors.New("input repo should be of type repoLocal")
)
//nolint:govet //nolint:govet
type localDirClient struct { type localDirClient struct {
@ -156,8 +159,8 @@ func (client *localDirClient) GetFileContent(filename string) ([]byte, error) {
return getFileContent(client.path, filename) return getFileContent(client.path, filename)
} }
// ListBranches implements RepoClient.ListBranches. // GetBranch implements RepoClient.GetBranch.
func (client *localDirClient) ListBranches() ([]*clients.BranchRef, error) { func (client *localDirClient) GetBranch(branch string) (*clients.BranchRef, error) {
return nil, fmt.Errorf("ListBranches: %w", clients.ErrUnsupportedFeature) return nil, fmt.Errorf("ListBranches: %w", clients.ErrUnsupportedFeature)
} }

View File

@ -63,6 +63,21 @@ func (mr *MockRepoClientMockRecorder) Close() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockRepoClient)(nil).Close)) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockRepoClient)(nil).Close))
} }
// GetBranch mocks base method.
func (m *MockRepoClient) GetBranch(branch string) (*clients.BranchRef, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetBranch", branch)
ret0, _ := ret[0].(*clients.BranchRef)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetBranch indicates an expected call of GetBranch.
func (mr *MockRepoClientMockRecorder) GetBranch(branch interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetBranch", reflect.TypeOf((*MockRepoClient)(nil).GetBranch), branch)
}
// GetDefaultBranch mocks base method. // GetDefaultBranch mocks base method.
func (m *MockRepoClient) GetDefaultBranch() (*clients.BranchRef, error) { func (m *MockRepoClient) GetDefaultBranch() (*clients.BranchRef, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
@ -122,21 +137,6 @@ func (mr *MockRepoClientMockRecorder) IsArchived() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsArchived", reflect.TypeOf((*MockRepoClient)(nil).IsArchived)) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsArchived", reflect.TypeOf((*MockRepoClient)(nil).IsArchived))
} }
// ListBranches mocks base method.
func (m *MockRepoClient) ListBranches() ([]*clients.BranchRef, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ListBranches")
ret0, _ := ret[0].([]*clients.BranchRef)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ListBranches indicates an expected call of ListBranches.
func (mr *MockRepoClientMockRecorder) ListBranches() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListBranches", reflect.TypeOf((*MockRepoClient)(nil).ListBranches))
}
// ListCheckRunsForRef mocks base method. // ListCheckRunsForRef mocks base method.
func (m *MockRepoClient) ListCheckRunsForRef(ref string) ([]clients.CheckRun, error) { func (m *MockRepoClient) ListCheckRunsForRef(ref string) ([]clients.CheckRun, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()

View File

@ -32,7 +32,7 @@ type RepoClient interface {
IsArchived() (bool, error) IsArchived() (bool, error)
ListFiles(predicate func(string) (bool, error)) ([]string, error) ListFiles(predicate func(string) (bool, error)) ([]string, error)
GetFileContent(filename string) ([]byte, error) GetFileContent(filename string) ([]byte, error)
ListBranches() ([]*BranchRef, error) GetBranch(branch string) (*BranchRef, error)
GetDefaultBranch() (*BranchRef, error) GetDefaultBranch() (*BranchRef, error)
ListCommits() ([]Commit, error) ListCommits() ([]Commit, error)
ListIssues() ([]Issue, error) ListIssues() ([]Issue, error)