From 70d045b9ef00e7171ce3950aca38eef6ea4d7308 Mon Sep 17 00:00:00 2001 From: Azeem Shaikh Date: Fri, 27 May 2022 15:25:24 -0700 Subject: [PATCH] Only pull required branch names (#1965) Co-authored-by: Azeem Shaikh --- checks/branch_protection_test.go | 29 +-- checks/evaluation/security_policy.go | 1 - checks/raw/branch_protection.go | 126 +++++----- checks/raw/branch_protection_test.go | 349 +++++++++++++++++---------- checks/raw/errors.go | 9 +- clients/githubrepo/branches.go | 75 +++--- clients/githubrepo/client.go | 11 +- clients/localdir/client.go | 9 +- clients/mockclients/repo_client.go | 30 +-- clients/repo_client.go | 2 +- 10 files changed, 359 insertions(+), 282 deletions(-) diff --git a/checks/branch_protection_test.go b/checks/branch_protection_test.go index 53d0dc7c..5a7fa98a 100644 --- a/checks/branch_protection_test.go +++ b/checks/branch_protection_test.go @@ -33,11 +33,14 @@ func getBranchName(branch *clients.BranchRef) string { return *branch.Name } -func getBranch(branches []*clients.BranchRef, name string) *clients.BranchRef { +func getBranch(branches []*clients.BranchRef, name string, isNonAdmin bool) *clients.BranchRef { for _, branch := range branches { branchName := getBranchName(branch) if branchName == name { - return branch + if !isNonAdmin { + return branch + } + return scrubBranch(branch) } } return nil @@ -49,14 +52,6 @@ func scrubBranch(branch *clients.BranchRef) *clients.BranchRef { return ret } -func scrubBranches(branches []*clients.BranchRef) []*clients.BranchRef { - ret := make([]*clients.BranchRef, len(branches)) - for i, branch := range branches { - ret[i] = scrubBranch(branch) - } - return ret -} - func TestReleaseAndDevBranchProtected(t *testing.T) { t.Parallel() @@ -399,10 +394,7 @@ func TestReleaseAndDevBranchProtected(t *testing.T) { mockRepoClient := mockrepo.NewMockRepoClient(ctrl) mockRepoClient.EXPECT().GetDefaultBranch(). DoAndReturn(func() (*clients.BranchRef, error) { - defaultBranch := getBranch(tt.branches, tt.defaultBranch) - if defaultBranch != nil && tt.nonadmin { - return scrubBranch(defaultBranch), nil - } + defaultBranch := getBranch(tt.branches, tt.defaultBranch, tt.nonadmin) return defaultBranch, nil }).AnyTimes() mockRepoClient.EXPECT().ListReleases(). @@ -415,12 +407,9 @@ func TestReleaseAndDevBranchProtected(t *testing.T) { } return ret, nil }).AnyTimes() - mockRepoClient.EXPECT().ListBranches(). - DoAndReturn(func() ([]*clients.BranchRef, error) { - if tt.nonadmin { - return scrubBranches(tt.branches), nil - } - return tt.branches, nil + mockRepoClient.EXPECT().GetBranch(gomock.Any()). + DoAndReturn(func(b string) (*clients.BranchRef, error) { + return getBranch(tt.branches, b, tt.nonadmin), nil }).AnyTimes() dl := scut.TestDetailLogger{} req := checker.CheckRequest{ diff --git a/checks/evaluation/security_policy.go b/checks/evaluation/security_policy.go index 6180b4a8..aa80aded 100644 --- a/checks/evaluation/security_policy.go +++ b/checks/evaluation/security_policy.go @@ -40,7 +40,6 @@ func SecurityPolicy(name string, dl checker.DetailLogger, r *checker.SecurityPol } if msg.Type == checker.FileTypeURL { msg.Text = "security policy detected in org repo" - } else { msg.Text = "security policy detected in current repo" } diff --git a/checks/raw/branch_protection.go b/checks/raw/branch_protection.go index dee511e6..9a8b0838 100644 --- a/checks/raw/branch_protection.go +++ b/checks/raw/branch_protection.go @@ -15,37 +15,56 @@ package raw import ( - "errors" "fmt" "regexp" "github.com/ossf/scorecard/v4/checker" "github.com/ossf/scorecard/v4/clients" - sce "github.com/ossf/scorecard/v4/errors" ) const master = "master" -type branchMap map[string]*clients.BranchRef +var commit = regexp.MustCompile("^[a-f0-9]{40}$") + +type branchSet struct { + exists map[string]bool + set []clients.BranchRef +} + +func (set *branchSet) add(branch *clients.BranchRef) bool { + if branch != nil && + branch.Name != nil && + *branch.Name != "" && + !set.exists[*branch.Name] { + set.set = append(set.set, *branch) + set.exists[*branch.Name] = true + return true + } + return false +} + +func (set branchSet) contains(branch string) bool { + _, contains := set.exists[branch] + return contains +} // BranchProtection retrieves the raw data for the Branch-Protection check. func BranchProtection(c clients.RepoClient) (checker.BranchProtectionsData, error) { - // Checks branch protection on both release and development branch. - // Get all branches. This will include information on whether they are protected. - branches, err := c.ListBranches() + branches := branchSet{ + exists: make(map[string]bool), + } + // Add default branch. + defaultBranch, err := c.GetDefaultBranch() if err != nil { return checker.BranchProtectionsData{}, fmt.Errorf("%w", err) } - branchesMap := getBranchMapFrom(branches) + branches.add(defaultBranch) // Get release branches. releases, err := c.ListReleases() if err != nil { return checker.BranchProtectionsData{}, fmt.Errorf("%w", err) } - - commit := regexp.MustCompile("^[a-f0-9]{40}$") - checkBranches := make(map[string]bool) for _, release := range releases { if release.TargetCommitish == "" { // Log with a named error if target_commitish is nil. @@ -57,78 +76,47 @@ func BranchProtection(c clients.RepoClient) (checker.BranchProtectionsData, erro continue } - // Try to resolve the branch name. - b, err := branchesMap.getBranchByName(release.TargetCommitish) - if err != nil { - // If the commitish branch is still not found, fail. - return checker.BranchProtectionsData{}, err + if branches.contains(release.TargetCommitish) || + branches.contains(branchRedirect(release.TargetCommitish)) { + continue } - // Branch is valid, add to list of branches to check. - checkBranches[*b.Name] = true - } - - // Add default branch. - defaultBranch, err := c.GetDefaultBranch() - if err != nil { - return checker.BranchProtectionsData{}, fmt.Errorf("%w", err) - } - defaultBranchName := getBranchName(defaultBranch) - if defaultBranchName != "" { - checkBranches[defaultBranchName] = true - } - - rawData := checker.BranchProtectionsData{} - // Check protections on all the branches. - for b := range checkBranches { - branch, err := branchesMap.getBranchByName(b) + // Get the associated release branch. + branchRef, err := c.GetBranch(release.TargetCommitish) if err != nil { - if errors.Is(err, errInternalBranchNotFound) { - continue - } - return checker.BranchProtectionsData{}, err + return checker.BranchProtectionsData{}, + fmt.Errorf("error during GetBranch(%s): %w", release.TargetCommitish, err) + } + if branches.add(branchRef) { + continue } - rawData.Branches = append(rawData.Branches, *branch) + // Couldn't find the branch check for redirects. + redirectBranch := branchRedirect(release.TargetCommitish) + if redirectBranch == "" { + continue + } + branchRef, err = c.GetBranch(redirectBranch) + if err != nil { + return checker.BranchProtectionsData{}, + fmt.Errorf("error during GetBranch(%s) %w", redirectBranch, err) + } + branches.add(branchRef) + // Branch doesn't exist or was deleted. Continue. } // No error, return the data. - return rawData, nil + return checker.BranchProtectionsData{ + Branches: branches.set, + }, nil } -func (b branchMap) getBranchByName(name string) (*clients.BranchRef, error) { - val, exists := b[name] - if exists { - return val, nil - } - +func branchRedirect(name string) string { // Ideally, we should check using repositories.GetBranch if there was a branch redirect. // See https://github.com/google/go-github/issues/1895 // For now, handle the common master -> main redirect. if name == master { - val, exists := b["main"] - if exists { - return val, nil - } + return "main" } - return nil, sce.WithMessage(sce.ErrScorecardInternal, - fmt.Sprintf("could not find branch name %s: %v", name, errInternalBranchNotFound)) -} - -func getBranchMapFrom(branches []*clients.BranchRef) branchMap { - ret := make(branchMap) - for _, branch := range branches { - branchName := getBranchName(branch) - if branchName != "" { - ret[branchName] = branch - } - } - return ret -} - -func getBranchName(branch *clients.BranchRef) string { - if branch == nil || branch.Name == nil { - return "" - } - return *branch.Name + return "" } diff --git a/checks/raw/branch_protection_test.go b/checks/raw/branch_protection_test.go index e3085377..b59fd8ad 100644 --- a/checks/raw/branch_protection_test.go +++ b/checks/raw/branch_protection_test.go @@ -15,160 +15,255 @@ package raw import ( + "errors" "testing" + "github.com/golang/mock/gomock" "github.com/google/go-cmp/cmp" + "github.com/ossf/scorecard/v4/checker" "github.com/ossf/scorecard/v4/clients" + mockrepo "github.com/ossf/scorecard/v4/clients/mockclients" ) -var branch = "master" +var ( + errBPTest = errors.New("test error") + defaultBranchName = "default" + releaseBranchName = "release-branch" + mainBranchName = "main" +) -func Test_getBranchName(t *testing.T) { - t.Parallel() - type args struct { - branch *clients.BranchRef - } - tests := []struct { - name string - args args - want string - }{ - { - name: "simple", - args: args{ - branch: &clients.BranchRef{ - Name: &branch, - }, - }, - want: master, - }, - { - name: "empty name", - args: args{ - branch: &clients.BranchRef{}, - }, - want: "", - }, - } - for _, tt := range tests { - tt := tt - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - if got := getBranchName(tt.args.branch); got != tt.want { - t.Errorf("getBranchName() = %v, want %v", got, tt.want) - } - }) - } +// nolint: govet +type branchArg struct { + err error + name string + branchRef *clients.BranchRef + defaultBranch bool } -func Test_getBranchMapFrom(t *testing.T) { - t.Parallel() - type args struct { - branches []*clients.BranchRef +type branchesArg []branchArg + +func (ba branchesArg) getDefaultBranch() (*clients.BranchRef, error) { + for _, branch := range ba { + if branch.defaultBranch { + return branch.branchRef, branch.err + } } - //nolint + return nil, nil +} + +func (ba branchesArg) getBranch(b string) (*clients.BranchRef, error) { + for _, branch := range ba { + if branch.name == b { + return branch.branchRef, branch.err + } + } + return nil, nil +} + +func TestBranchProtection(t *testing.T) { + t.Parallel() + // nolint: govet tests := []struct { - name string - args args - want branchMap + name string + branches branchesArg + releases []clients.Release + releasesErr error + want checker.BranchProtectionsData + wantErr error }{ { - name: "simple", - args: args{ - branches: []*clients.BranchRef{ - { - Name: &branch, + name: "default-branch-err", + branches: branchesArg{ + { + name: defaultBranchName, + err: errBPTest, + }, + }, + }, + { + name: "null-default-branch-only", + branches: branchesArg{ + { + name: defaultBranchName, + defaultBranch: true, + branchRef: nil, + }, + }, + }, + { + name: "default-branch-only", + branches: branchesArg{ + { + name: defaultBranchName, + defaultBranch: true, + branchRef: &clients.BranchRef{ + Name: &defaultBranchName, }, }, }, - want: branchMap{ - master: &clients.BranchRef{ - Name: &branch, + want: checker.BranchProtectionsData{ + Branches: []clients.BranchRef{ + { + Name: &defaultBranchName, + }, }, }, }, + { + name: "list-releases-error", + releasesErr: errBPTest, + wantErr: errBPTest, + }, + { + name: "no-releases", + }, + { + name: "empty-targetcommitish", + releases: []clients.Release{ + { + TargetCommitish: "", + }, + }, + wantErr: errInternalCommitishNil, + }, + { + name: "release-branch-err", + releases: []clients.Release{ + { + TargetCommitish: releaseBranchName, + }, + }, + branches: branchesArg{ + { + name: releaseBranchName, + err: errBPTest, + }, + }, + wantErr: errBPTest, + }, + { + name: "nil-release-branch", + releases: []clients.Release{ + { + TargetCommitish: releaseBranchName, + }, + }, + branches: branchesArg{ + { + name: releaseBranchName, + branchRef: nil, + }, + }, + }, + { + name: "add-release-branch", + releases: []clients.Release{ + { + TargetCommitish: releaseBranchName, + }, + }, + branches: branchesArg{ + { + name: releaseBranchName, + branchRef: &clients.BranchRef{ + Name: &releaseBranchName, + }, + }, + }, + want: checker.BranchProtectionsData{ + Branches: []clients.BranchRef{ + { + Name: &releaseBranchName, + }, + }, + }, + }, + { + name: "master-to-main-redirect", + releases: []clients.Release{ + { + TargetCommitish: "master", + }, + }, + branches: branchesArg{ + { + name: mainBranchName, + branchRef: &clients.BranchRef{ + Name: &mainBranchName, + }, + }, + }, + want: checker.BranchProtectionsData{ + Branches: []clients.BranchRef{ + { + Name: &mainBranchName, + }, + }, + }, + }, + { + name: "default-and-release-branches", + releases: []clients.Release{ + { + TargetCommitish: releaseBranchName, + }, + }, + branches: branchesArg{ + { + name: defaultBranchName, + defaultBranch: true, + branchRef: &clients.BranchRef{ + Name: &defaultBranchName, + }, + }, + { + name: releaseBranchName, + branchRef: &clients.BranchRef{ + Name: &releaseBranchName, + }, + }, + }, + want: checker.BranchProtectionsData{ + Branches: []clients.BranchRef{ + { + Name: &defaultBranchName, + }, + { + Name: &releaseBranchName, + }, + }, + }, + }, + // TODO: Add tests for commitSHA regex matching. } for _, tt := range tests { tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() - if got := getBranchMapFrom(tt.args.branches); !cmp.Equal(got, tt.want) { - t.Errorf("getBranchMapFrom() = %v, want %v", got, tt.want) - } - }) - } -} + ctrl := gomock.NewController(t) + mockRepoClient := mockrepo.NewMockRepoClient(ctrl) + mockRepoClient.EXPECT().GetDefaultBranch(). + AnyTimes().DoAndReturn(func() (*clients.BranchRef, error) { + return tt.branches.getDefaultBranch() + }) + mockRepoClient.EXPECT().GetBranch(gomock.Any()).AnyTimes(). + DoAndReturn(func(branch string) (*clients.BranchRef, error) { + return tt.branches.getBranch(branch) + }) + mockRepoClient.EXPECT().ListReleases().AnyTimes(). + DoAndReturn(func() ([]clients.Release, error) { + return tt.releases, tt.releasesErr + }) -func Test_branchMap_getBranchByName(t *testing.T) { - main := "main" - t.Parallel() - type args struct { - name string - } - //nolint - tests := []struct { - name string - b branchMap - args args - want *clients.BranchRef - wantErr bool - }{ - { - name: "simple", - b: branchMap{ - master: &clients.BranchRef{ - Name: &branch, - }, - }, - args: args{ - name: master, - }, - want: &clients.BranchRef{ - Name: &branch, - }, - }, - { - name: "main", - b: branchMap{ - master: &clients.BranchRef{ - Name: &main, - }, - main: &clients.BranchRef{ - Name: &main, - }, - }, - args: args{ - name: "main", - }, - want: &clients.BranchRef{ - Name: &main, - }, - }, - { - name: "not found", - b: branchMap{ - master: &clients.BranchRef{ - Name: &branch, - }, - }, - args: args{ - name: "not-found", - }, - wantErr: true, - }, - } - for _, tt := range tests { - tt := tt - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - got, err := tt.b.getBranchByName(tt.args.name) - if (err != nil) != tt.wantErr { - t.Errorf("branchMap.getBranchByName() error = %v, wantErr %v", err, tt.wantErr) - return + rawData, err := BranchProtection(mockRepoClient) + if !errors.Is(err, tt.wantErr) { + t.Errorf("failed. expected: %v, got: %v", tt.wantErr, err) + t.Fail() } - if !cmp.Equal(got, tt.want) { - t.Errorf("branchMap.getBranchByName() = %v, want %v", got, tt.want) + if !cmp.Equal(rawData, tt.want) { + t.Errorf("failed. expected: %v, got: %v", tt.want, rawData) + t.Fail() } }) } diff --git a/checks/raw/errors.go b/checks/raw/errors.go index 605e908c..0de31a67 100644 --- a/checks/raw/errors.go +++ b/checks/raw/errors.go @@ -19,9 +19,8 @@ import ( ) var ( - errInternalCommitishNil = errors.New("commitish is nil") - errInternalBranchNotFound = errors.New("branch not found") - errInvalidArgType = errors.New("invalid arg type") - errInvalidArgLength = errors.New("invalid arg length") - errInvalidGitHubWorkflow = errors.New("invalid GitHub workflow") + errInternalCommitishNil = errors.New("commitish is nil") + errInvalidArgType = errors.New("invalid arg type") + errInvalidArgLength = errors.New("invalid arg length") + errInvalidGitHubWorkflow = errors.New("invalid GitHub workflow") ) diff --git a/clients/githubrepo/branches.go b/clients/githubrepo/branches.go index efc814ef..659f1f0e 100644 --- a/clients/githubrepo/branches.go +++ b/clients/githubrepo/branches.go @@ -28,8 +28,7 @@ import ( ) const ( - refsToAnalyze = 30 - refPrefix = "refs/heads/" + refPrefix = "refs/heads/" ) // See https://github.community/t/graphql-api-protected-branch/14380 @@ -97,30 +96,30 @@ type branch struct { RefUpdateRule *refUpdateRule BranchProtectionRule *branchProtectionRule } - -// nolint:govet // internal structure, ignore. -type branchesData struct { +type defaultBranchData struct { Repository struct { - DefaultBranchRef branch - Refs struct { - Nodes []branch - } `graphql:"refs(first: $refsToAnalyze, refPrefix: $refPrefix)"` + DefaultBranchRef *branch } `graphql:"repository(owner: $owner, name: $name)"` RateLimit struct { Cost *int } } +type branchData struct { + Repository struct { + Ref *branch `graphql:"ref(qualifiedName: $branchRefName)"` + } `graphql:"repository(owner: $owner, name: $name)"` +} + type branchesHandler struct { ghClient *github.Client graphClient *githubv4.Client - data *branchesData + data *defaultBranchData once *sync.Once ctx context.Context errSetup error repourl *repoURL defaultBranchRef *clients.BranchRef - branches []*clients.BranchRef } func (handler *branchesHandler) init(ctx context.Context, repourl *repoURL) { @@ -137,22 +136,35 @@ func (handler *branchesHandler) setup() error { return } vars := map[string]interface{}{ - "owner": githubv4.String(handler.repourl.owner), - "name": githubv4.String(handler.repourl.repo), - "refsToAnalyze": githubv4.Int(refsToAnalyze), - "refPrefix": githubv4.String(refPrefix), + "owner": githubv4.String(handler.repourl.owner), + "name": githubv4.String(handler.repourl.repo), } - handler.data = new(branchesData) + handler.data = new(defaultBranchData) if err := handler.graphClient.Query(handler.ctx, handler.data, vars); err != nil { handler.errSetup = sce.WithMessage(sce.ErrScorecardInternal, fmt.Sprintf("githubv4.Query: %v", err)) return } handler.defaultBranchRef = getBranchRefFrom(handler.data.Repository.DefaultBranchRef) - handler.branches = getBranchRefsFrom(handler.data.Repository.Refs.Nodes, handler.defaultBranchRef) }) return handler.errSetup } +func (handler *branchesHandler) query(branchName string) (*clients.BranchRef, error) { + if !strings.EqualFold(handler.repourl.commitSHA, clients.HeadSHA) { + return nil, fmt.Errorf("%w: branches only supported for HEAD queries", clients.ErrUnsupportedFeature) + } + vars := map[string]interface{}{ + "owner": githubv4.String(handler.repourl.owner), + "name": githubv4.String(handler.repourl.repo), + "branchRefName": githubv4.String(refPrefix + branchName), + } + queryData := new(branchData) + if err := handler.graphClient.Query(handler.ctx, queryData, vars); err != nil { + return nil, sce.WithMessage(sce.ErrScorecardInternal, fmt.Sprintf("githubv4.Query: %v", err)) + } + return getBranchRefFrom(queryData.Repository.Ref), nil +} + func (handler *branchesHandler) getDefaultBranch() (*clients.BranchRef, error) { if err := handler.setup(); err != nil { return nil, fmt.Errorf("error during branchesHandler.setup: %w", err) @@ -160,11 +172,12 @@ func (handler *branchesHandler) getDefaultBranch() (*clients.BranchRef, error) { return handler.defaultBranchRef, nil } -func (handler *branchesHandler) listBranches() ([]*clients.BranchRef, error) { - if err := handler.setup(); err != nil { - return nil, fmt.Errorf("error during branchesHandler.setup: %w", err) +func (handler *branchesHandler) getBranch(branch string) (*clients.BranchRef, error) { + branchRef, err := handler.query(branch) + if err != nil { + return nil, fmt.Errorf("error during branchesHandler.query: %w", err) } - return handler.branches, nil + return branchRef, nil } func copyAdminSettings(src *branchProtectionRule, dst *clients.BranchProtectionRule) { @@ -197,7 +210,10 @@ func copyNonAdminSettings(src interface{}, dst *clients.BranchProtectionRule) { } } -func getBranchRefFrom(data branch) *clients.BranchRef { +func getBranchRefFrom(data *branch) *clients.BranchRef { + if data == nil { + return nil + } branchRef := new(clients.BranchRef) if data.Name != nil { branchRef.Name = data.Name @@ -238,18 +254,3 @@ func getBranchRefFrom(data branch) *clients.BranchRef { return branchRef } - -func getBranchRefsFrom(data []branch, defaultBranch *clients.BranchRef) []*clients.BranchRef { - var branchRefs []*clients.BranchRef - var defaultFound bool - for i, b := range data { - branchRefs = append(branchRefs, getBranchRefFrom(b)) - if defaultBranch != nil && branchRefs[i].Name == defaultBranch.Name { - defaultFound = true - } - } - if !defaultFound { - branchRefs = append(branchRefs, defaultBranch) - } - return branchRefs -} diff --git a/clients/githubrepo/client.go b/clients/githubrepo/client.go index c380017e..3394674d 100644 --- a/clients/githubrepo/client.go +++ b/clients/githubrepo/client.go @@ -30,7 +30,10 @@ import ( "github.com/ossf/scorecard/v4/log" ) -var errInputRepoType = errors.New("input repo should be of type repoURL") +var ( + _ clients.RepoClient = &Client{} + errInputRepoType = errors.New("input repo should be of type repoURL") +) // Client is GitHub-specific implementation of RepoClient. type Client struct { @@ -149,9 +152,9 @@ func (client *Client) GetDefaultBranch() (*clients.BranchRef, error) { return client.branches.getDefaultBranch() } -// ListBranches implements RepoClient.ListBranches. -func (client *Client) ListBranches() ([]*clients.BranchRef, error) { - return client.branches.listBranches() +// GetBranch implements RepoClient.GetBranch. +func (client *Client) GetBranch(branch string) (*clients.BranchRef, error) { + return client.branches.getBranch(branch) } // ListWebhooks implements RepoClient.ListWebhooks. diff --git a/clients/localdir/client.go b/clients/localdir/client.go index e34a9f8d..0a4c5a61 100644 --- a/clients/localdir/client.go +++ b/clients/localdir/client.go @@ -31,7 +31,10 @@ import ( "github.com/ossf/scorecard/v4/log" ) -var errInputRepoType = errors.New("input repo should be of type repoLocal") +var ( + _ clients.RepoClient = &localDirClient{} + errInputRepoType = errors.New("input repo should be of type repoLocal") +) //nolint:govet type localDirClient struct { @@ -156,8 +159,8 @@ func (client *localDirClient) GetFileContent(filename string) ([]byte, error) { return getFileContent(client.path, filename) } -// ListBranches implements RepoClient.ListBranches. -func (client *localDirClient) ListBranches() ([]*clients.BranchRef, error) { +// GetBranch implements RepoClient.GetBranch. +func (client *localDirClient) GetBranch(branch string) (*clients.BranchRef, error) { return nil, fmt.Errorf("ListBranches: %w", clients.ErrUnsupportedFeature) } diff --git a/clients/mockclients/repo_client.go b/clients/mockclients/repo_client.go index 1691c239..8c57778b 100644 --- a/clients/mockclients/repo_client.go +++ b/clients/mockclients/repo_client.go @@ -63,6 +63,21 @@ func (mr *MockRepoClientMockRecorder) Close() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockRepoClient)(nil).Close)) } +// GetBranch mocks base method. +func (m *MockRepoClient) GetBranch(branch string) (*clients.BranchRef, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetBranch", branch) + ret0, _ := ret[0].(*clients.BranchRef) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetBranch indicates an expected call of GetBranch. +func (mr *MockRepoClientMockRecorder) GetBranch(branch interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetBranch", reflect.TypeOf((*MockRepoClient)(nil).GetBranch), branch) +} + // GetDefaultBranch mocks base method. func (m *MockRepoClient) GetDefaultBranch() (*clients.BranchRef, error) { m.ctrl.T.Helper() @@ -122,21 +137,6 @@ func (mr *MockRepoClientMockRecorder) IsArchived() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsArchived", reflect.TypeOf((*MockRepoClient)(nil).IsArchived)) } -// ListBranches mocks base method. -func (m *MockRepoClient) ListBranches() ([]*clients.BranchRef, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ListBranches") - ret0, _ := ret[0].([]*clients.BranchRef) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// ListBranches indicates an expected call of ListBranches. -func (mr *MockRepoClientMockRecorder) ListBranches() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListBranches", reflect.TypeOf((*MockRepoClient)(nil).ListBranches)) -} - // ListCheckRunsForRef mocks base method. func (m *MockRepoClient) ListCheckRunsForRef(ref string) ([]clients.CheckRun, error) { m.ctrl.T.Helper() diff --git a/clients/repo_client.go b/clients/repo_client.go index 0e8e2ff7..8957e883 100644 --- a/clients/repo_client.go +++ b/clients/repo_client.go @@ -32,7 +32,7 @@ type RepoClient interface { IsArchived() (bool, error) ListFiles(predicate func(string) (bool, error)) ([]string, error) GetFileContent(filename string) ([]byte, error) - ListBranches() ([]*BranchRef, error) + GetBranch(branch string) (*BranchRef, error) GetDefaultBranch() (*BranchRef, error) ListCommits() ([]Commit, error) ListIssues() ([]Issue, error)