Add a tarball handler (#654)

Co-authored-by: Azeem Shaikh <azeems@google.com>
This commit is contained in:
Azeem Shaikh 2021-07-04 17:35:53 -07:00 committed by GitHub
parent aab6c217cc
commit 581e170db1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 243 additions and 202 deletions

View File

@ -15,66 +15,22 @@
package githubrepo
import (
"archive/tar"
"compress/gzip"
"context"
"fmt"
"io"
"io/ioutil"
"log"
"net/http"
"os"
"path/filepath"
"strings"
"github.com/google/go-github/v32/github"
"github.com/pkg/errors"
"github.com/ossf/scorecard/clients"
)
const (
repoDir = "repo*"
repoFilename = "githubrepo*.tar.gz"
)
var errZipSlip = errors.New("ZipSlip path detected")
func extractAndValidateArchivePath(path, dest string) (string, error) {
const splitLength = 2
// The tarball will have a top-level directory which contains all the repository files.
// Discard the directory and only keep the actual files.
names := strings.SplitN(path, "/", splitLength)
if len(names) < splitLength {
log.Printf("Unable to split path: %s", path)
return dest, nil
}
if names[1] == "" {
return dest, nil
}
// Check for ZipSlip: https://snyk.io/research/zip-slip-vulnerability
cleanpath := filepath.Join(dest, names[1])
if !strings.HasPrefix(cleanpath, filepath.Clean(dest)+string(os.PathSeparator)) {
return "", fmt.Errorf("%w: %s", errZipSlip, names[1])
}
return cleanpath, nil
}
type Client struct {
repo *github.Repository
repoClient *github.Client
ctx context.Context
tempDir string
tempTarFile string
files []string
repo *github.Repository
repoClient *github.Client
ctx context.Context
tarball tarballHandler
}
func (client *Client) InitRepo(owner, repoName string) error {
// Cleanup any previous state.
if err := client.cleanup(); err != nil {
return fmt.Errorf("error during githubrepo cleanup: %w", err)
}
// Sanity check
repo, _, err := client.repoClient.Repositories.Get(client.ctx, owner, repoName)
if err != nil {
@ -83,158 +39,24 @@ func (client *Client) InitRepo(owner, repoName string) error {
}
client.repo = repo
// Setup temp dir/files and download repo tarball.
if err := client.getTarball(); err != nil {
return fmt.Errorf("error getting githurepo tarball: %w", err)
// Init tarballHandler.
if err := client.tarball.init(client.ctx, client.repo); err != nil {
return fmt.Errorf("error during tarballHandler.init: %w", err)
}
// Extract file names and content from tarball.
if err := client.extractTarball(); err != nil {
return fmt.Errorf("error extracting githubrepo tarball: %w", err)
}
return nil
}
func (client *Client) getTarball() error {
tempDir, err := ioutil.TempDir("", repoDir)
if err != nil {
return fmt.Errorf("error creating TempDir in githubrepo: %w", err)
}
client.tempDir = tempDir
url := client.repo.GetArchiveURL()
url = strings.Replace(url, "{archive_format}", "tarball/", 1)
url = strings.Replace(url, "{/ref}", client.repo.GetDefaultBranch(), 1)
req, err := http.NewRequestWithContext(client.ctx, http.MethodGet, url, nil)
if err != nil {
return fmt.Errorf("http.NewRequestWithContext: %w", err)
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
return fmt.Errorf("http.DefaultClient.Do: %w", err)
}
defer resp.Body.Close()
// Create a temp file. This automaticlly appends a random number to the name.
repoFile, err := ioutil.TempFile(client.tempDir, repoFilename)
if err != nil {
return fmt.Errorf("error during ioutil.TempFile in githubrepo: %w", err)
}
defer repoFile.Close()
if _, err := io.Copy(repoFile, resp.Body); err != nil {
return fmt.Errorf("error during io.Copy in githubrepo tarball: %w", err)
}
client.tempTarFile = repoFile.Name()
return nil
}
// nolint: gocognit
func (client *Client) extractTarball() error {
// nolint: gomnd
in, err := os.OpenFile(client.tempTarFile, os.O_RDONLY, 0o644)
if err != nil {
return fmt.Errorf("error opening %s: %w", client.tempTarFile, err)
}
gz, err := gzip.NewReader(in)
if err != nil {
return fmt.Errorf("error reading %s: %w", client.tempTarFile, err)
}
tr := tar.NewReader(gz)
for {
header, err := tr.Next()
if errors.Is(err, io.EOF) {
break
}
if err != nil {
return fmt.Errorf("error in tarReader.Next(): %w", err)
}
switch header.Typeflag {
case tar.TypeDir:
dirpath, err := extractAndValidateArchivePath(header.Name, client.tempDir)
if err != nil {
return fmt.Errorf("error extracting dirpath: %w", err)
}
if dirpath == filepath.Clean(client.tempDir) {
continue
}
// nolint: gomnd
if err := os.Mkdir(dirpath, 0o755); err != nil {
return fmt.Errorf("error during os.Mkdir: %w", err)
}
case tar.TypeReg:
if header.Size <= 0 {
continue
}
filenamepath, err := extractAndValidateArchivePath(header.Name, client.tempDir)
if err != nil {
return fmt.Errorf("error extracting file path: %w", err)
}
if _, err := os.Stat(filepath.Dir(filenamepath)); os.IsNotExist(err) {
// nolint: gomnd
if err := os.Mkdir(filepath.Dir(filenamepath), 0o755); err != nil {
return fmt.Errorf("error during os.Mkdir: %w", err)
}
}
outFile, err := os.Create(filenamepath)
if err != nil {
return fmt.Errorf("error during os.Create: %w", err)
}
// nolint: gosec
// Potential for DoS vulnerability via decompression bomb.
// Since such an attack will only impact a single shard, ignoring this for now.
if _, err := io.Copy(outFile, tr); err != nil {
return fmt.Errorf("error during io.Copy: %w", err)
}
outFile.Close()
client.files = append(client.files,
strings.TrimPrefix(filenamepath, filepath.Clean(client.tempDir)+string(os.PathSeparator)))
case tar.TypeXGlobalHeader:
continue
case tar.TypeSymlink:
continue
default:
log.Printf("Unknown file type %s: '%s'", header.Name, string(header.Typeflag))
continue
}
}
return nil
}
func (client *Client) ListFiles(predicate func(string) bool) []string {
ret := make([]string, 0)
for _, file := range client.files {
if predicate(file) {
ret = append(ret, file)
}
}
return ret
return client.tarball.listFiles(predicate)
}
func (client *Client) GetFileContent(filename string) ([]byte, error) {
content, err := ioutil.ReadFile(filepath.Join(client.tempDir, filename))
if err != nil {
return content, fmt.Errorf("error trying to ReadFile: %w", err)
}
return content, nil
}
func (client *Client) cleanup() error {
if err := os.RemoveAll(client.tempDir); err != nil && !os.IsNotExist(err) {
return fmt.Errorf("os.Remove: %w", err)
}
// Remove old files so we don't iterate through them.
client.files = nil
return nil
return client.tarball.getFileContent(filename)
}
func (client *Client) Close() error {
return client.cleanup()
return client.tarball.cleanup()
}
func CreateGithubRepoClient(ctx context.Context, client *github.Client) clients.RepoClient {

View File

@ -0,0 +1,219 @@
// Copyright 2021 Security Scorecard Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package githubrepo
import (
"archive/tar"
"compress/gzip"
"context"
"errors"
"fmt"
"io"
"io/ioutil"
"log"
"net/http"
"os"
"path/filepath"
"strings"
"github.com/google/go-github/v32/github"
)
const (
repoDir = "repo*"
repoFilename = "githubrepo*.tar.gz"
)
var errZipSlip = errors.New("ZipSlip path detected")
func extractAndValidateArchivePath(path, dest string) (string, error) {
const splitLength = 2
// The tarball will have a top-level directory which contains all the repository files.
// Discard the directory and only keep the actual files.
names := strings.SplitN(path, "/", splitLength)
if len(names) < splitLength {
log.Printf("Unable to split path: %s", path)
return dest, nil
}
if names[1] == "" {
return dest, nil
}
// Check for ZipSlip: https://snyk.io/research/zip-slip-vulnerability
cleanpath := filepath.Join(dest, names[1])
if !strings.HasPrefix(cleanpath, filepath.Clean(dest)+string(os.PathSeparator)) {
return "", fmt.Errorf("%w: %s", errZipSlip, names[1])
}
return cleanpath, nil
}
type tarballHandler struct {
tempDir string
tempTarFile string
files []string
}
func (handler *tarballHandler) init(ctx context.Context, repo *github.Repository) error {
// Cleanup any previous state.
if err := handler.cleanup(); err != nil {
return fmt.Errorf("error during githubrepo cleanup: %w", err)
}
// Setup temp dir/files and download repo tarball.
if err := handler.getTarball(ctx, repo); err != nil {
return fmt.Errorf("error getting githurepo tarball: %w", err)
}
// Extract file names and content from tarball.
if err := handler.extractTarball(); err != nil {
return fmt.Errorf("error extracting githubrepo tarball: %w", err)
}
return nil
}
func (handler *tarballHandler) getTarball(ctx context.Context, repo *github.Repository) error {
tempDir, err := ioutil.TempDir("", repoDir)
if err != nil {
return fmt.Errorf("error creating TempDir in githubrepo: %w", err)
}
handler.tempDir = tempDir
url := repo.GetArchiveURL()
url = strings.Replace(url, "{archive_format}", "tarball/", 1)
url = strings.Replace(url, "{/ref}", repo.GetDefaultBranch(), 1)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return fmt.Errorf("http.NewRequestWithContext: %w", err)
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
return fmt.Errorf("http.DefaultClient.Do: %w", err)
}
defer resp.Body.Close()
// Create a temp file. This automaticlly appends a random number to the name.
repoFile, err := ioutil.TempFile(handler.tempDir, repoFilename)
if err != nil {
return fmt.Errorf("error during ioutil.TempFile in githubrepo: %w", err)
}
defer repoFile.Close()
if _, err := io.Copy(repoFile, resp.Body); err != nil {
return fmt.Errorf("error during io.Copy in githubrepo tarball: %w", err)
}
handler.tempTarFile = repoFile.Name()
return nil
}
// nolint: gocognit
func (handler *tarballHandler) extractTarball() error {
// nolint: gomnd
in, err := os.OpenFile(handler.tempTarFile, os.O_RDONLY, 0o644)
if err != nil {
return fmt.Errorf("error opening %s: %w", handler.tempTarFile, err)
}
gz, err := gzip.NewReader(in)
if err != nil {
return fmt.Errorf("error reading %s: %w", handler.tempTarFile, err)
}
tr := tar.NewReader(gz)
for {
header, err := tr.Next()
if errors.Is(err, io.EOF) {
break
}
if err != nil {
return fmt.Errorf("error in tarReader.Next(): %w", err)
}
switch header.Typeflag {
case tar.TypeDir:
dirpath, err := extractAndValidateArchivePath(header.Name, handler.tempDir)
if err != nil {
return fmt.Errorf("error extracting dirpath: %w", err)
}
if dirpath == filepath.Clean(handler.tempDir) {
continue
}
// nolint: gomnd
if err := os.Mkdir(dirpath, 0o755); err != nil {
return fmt.Errorf("error during os.Mkdir: %w", err)
}
case tar.TypeReg:
if header.Size <= 0 {
continue
}
filenamepath, err := extractAndValidateArchivePath(header.Name, handler.tempDir)
if err != nil {
return fmt.Errorf("error extracting file path: %w", err)
}
if _, err := os.Stat(filepath.Dir(filenamepath)); os.IsNotExist(err) {
// nolint: gomnd
if err := os.Mkdir(filepath.Dir(filenamepath), 0o755); err != nil {
return fmt.Errorf("error during os.Mkdir: %w", err)
}
}
outFile, err := os.Create(filenamepath)
if err != nil {
return fmt.Errorf("error during os.Create: %w", err)
}
// nolint: gosec
// Potential for DoS vulnerability via decompression bomb.
// Since such an attack will only impact a single shard, ignoring this for now.
if _, err := io.Copy(outFile, tr); err != nil {
return fmt.Errorf("error during io.Copy: %w", err)
}
outFile.Close()
handler.files = append(handler.files,
strings.TrimPrefix(filenamepath, filepath.Clean(handler.tempDir)+string(os.PathSeparator)))
case tar.TypeXGlobalHeader, tar.TypeSymlink:
continue
default:
log.Printf("Unknown file type %s: '%s'", header.Name, string(header.Typeflag))
continue
}
}
return nil
}
func (handler *tarballHandler) listFiles(predicate func(string) bool) []string {
ret := make([]string, 0)
for _, file := range handler.files {
if predicate(file) {
ret = append(ret, file)
}
}
return ret
}
func (handler *tarballHandler) getFileContent(filename string) ([]byte, error) {
content, err := ioutil.ReadFile(filepath.Join(handler.tempDir, filename))
if err != nil {
return content, fmt.Errorf("error trying to ReadFile: %w", err)
}
return content, nil
}
func (handler *tarballHandler) cleanup() error {
if err := os.RemoveAll(handler.tempDir); err != nil && !os.IsNotExist(err) {
return fmt.Errorf("os.Remove: %w", err)
}
// Remove old files so we don't iterate through them.
handler.files = nil
return nil
}

View File

@ -42,23 +42,23 @@ func isSortedString(x, y string) bool {
return x < y
}
func setup(inputFile string) (Client, error) {
func setup(inputFile string) (tarballHandler, error) {
tempDir, err := ioutil.TempDir("", repoDir)
if err != nil {
return Client{}, fmt.Errorf("test failed to create TempDir: %w", err)
return tarballHandler{}, fmt.Errorf("test failed to create TempDir: %w", err)
}
tempFile, err := ioutil.TempFile(tempDir, repoFilename)
if err != nil {
return Client{}, fmt.Errorf("test failed to create TempFile: %w", err)
return tarballHandler{}, fmt.Errorf("test failed to create TempFile: %w", err)
}
testFile, err := os.OpenFile(inputFile, os.O_RDONLY, 0o644)
if err != nil {
return Client{}, fmt.Errorf("unable to open testfile: %w", err)
return tarballHandler{}, fmt.Errorf("unable to open testfile: %w", err)
}
if _, err := io.Copy(tempFile, testFile); err != nil {
return Client{}, fmt.Errorf("unable to do io.Copy: %w", err)
return tarballHandler{}, fmt.Errorf("unable to do io.Copy: %w", err)
}
return Client{
return tarballHandler{
tempDir: tempDir,
tempTarFile: tempFile.Name(),
}, nil
@ -120,28 +120,28 @@ func TestExtractTarball(t *testing.T) {
t.Parallel()
// Setup
client, err := setup(testcase.inputFile)
handler, err := setup(testcase.inputFile)
if err != nil {
t.Fatalf("test setup failed: %v", err)
}
// Extract tarball.
if err := client.extractTarball(); err != nil {
if err := handler.extractTarball(); err != nil {
t.Fatalf("test failed: %v", err)
}
// Test ListFiles API.
for _, listfiletest := range testcase.listfileTests {
if !cmp.Equal(listfiletest.outcome,
client.ListFiles(listfiletest.predicate),
handler.listFiles(listfiletest.predicate),
cmpopts.SortSlices(isSortedString)) {
t.Errorf("test failed: expected - %q, got - %q", listfiletest.outcome, client.ListFiles(listfiletest.predicate))
t.Errorf("test failed: expected - %q, got - %q", listfiletest.outcome, handler.listFiles(listfiletest.predicate))
}
}
// Test GetFileContent API.
for _, getcontenttest := range testcase.getcontentTests {
content, err := client.GetFileContent(getcontenttest.filename)
content, err := handler.getFileContent(getcontenttest.filename)
if getcontenttest.err != nil && !errors.As(err, &getcontenttest.err) {
t.Errorf("test failed: expected - %v, got - %v", getcontenttest.err, err)
}
@ -151,13 +151,13 @@ func TestExtractTarball(t *testing.T) {
}
// Test that files get deleted.
if err := client.cleanup(); err != nil {
if err := handler.cleanup(); err != nil {
t.Errorf("test failed: %v", err)
}
if _, err := os.Stat(client.tempDir); !os.IsNotExist(err) {
if _, err := os.Stat(handler.tempDir); !os.IsNotExist(err) {
t.Errorf("%v", err)
}
if len(client.files) != 0 {
if len(handler.files) != 0 {
t.Error("client.files not cleaned up!")
}
})