Merge pull request #313 from numtide/fix/stdin

fix: --stdin flag
This commit is contained in:
Brian McGee 2024-06-05 15:13:52 +01:00 committed by GitHub
commit ab2b373094
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 130 additions and 82 deletions

View File

@ -1,6 +1,8 @@
package cli
import (
"os"
"git.numtide.com/numtide/treefmt/walk"
"github.com/alecthomas/kong"
"github.com/charmbracelet/log"
@ -33,6 +35,7 @@ type Format struct {
func configureLogging() {
log.SetReportTimestamp(false)
log.SetOutput(os.Stderr)
if Cli.Verbosity == 0 {
log.SetLevel(log.WarnLevel)

View File

@ -1,16 +1,15 @@
package cli
import (
"bufio"
"context"
"errors"
"fmt"
"io"
"os"
"os/signal"
"path/filepath"
"runtime"
"runtime/pprof"
"strings"
"syscall"
"git.numtide.com/numtide/treefmt/format"
@ -173,6 +172,10 @@ func updateCache(ctx context.Context) func() error {
// apply a batch
processBatch := func() error {
if Cli.Stdin {
// do nothing
return nil
}
if err := cache.Update(batch); err != nil {
return err
}
@ -192,6 +195,24 @@ func updateCache(ctx context.Context) func() error {
// channel has been closed, no further files to process
break LOOP
}
if Cli.Stdin {
// dump file into stdout
f, err := os.Open(file.Path)
if err != nil {
return fmt.Errorf("failed to open %s: %w", file.Path, err)
}
if _, err = io.Copy(os.Stdout, f); err != nil {
return fmt.Errorf("failed to copy %s to stdout: %w", file.Path, err)
}
if err = os.Remove(f.Name()); err != nil {
return fmt.Errorf("failed to remove temp file %s: %w", file.Path, err)
}
stats.Add(stats.Formatted, 1)
continue
}
// append to batch and process if we have enough
batch = append(batch, file)
if len(batch) == BatchSize {
@ -212,8 +233,10 @@ func updateCache(ctx context.Context) func() error {
return ErrFailOnChange
}
// print stats to stdout
stats.Print()
// print stats to stdout unless we are processing stdin and printing the results to stdout
if !Cli.Stdin {
stats.Print()
}
return nil
}
@ -224,6 +247,32 @@ func walkFilesystem(ctx context.Context) func() error {
eg, ctx := errgroup.WithContext(ctx)
pathsCh := make(chan string, BatchSize)
// By default, we use the cli arg, but if the stdin flag has been set we force a filesystem walk
// since we will only be processing one file from a temp directory
walkerType := Cli.Walk
if Cli.Stdin {
walkerType = walk.Filesystem
// check we have only received one path arg which we use for the file extension / matching to formatters
if len(Cli.Paths) != 1 {
return fmt.Errorf("only one path should be specified when using the --stdin flag")
}
// read stdin into a temporary file with the same file extension
pattern := fmt.Sprintf("*%s", filepath.Ext(Cli.Paths[0]))
file, err := os.CreateTemp("", pattern)
if err != nil {
return fmt.Errorf("failed to create a temporary file for processing stdin: %w", err)
}
if _, err = io.Copy(file, os.Stdin); err != nil {
return fmt.Errorf("failed to copy stdin into a temporary file")
}
Cli.Paths[0] = file.Name()
}
walkPaths := func() error {
defer close(pathsCh)
@ -241,38 +290,8 @@ func walkFilesystem(ctx context.Context) func() error {
return nil
}
walkStdin := func() error {
defer close(pathsCh)
// determine the current working directory
cwd, err := os.Getwd()
if err != nil {
return fmt.Errorf("failed to determine current working directory: %w", err)
}
// read in all the paths
scanner := bufio.NewScanner(os.Stdin)
for scanner.Scan() {
select {
case <-ctx.Done():
return ctx.Err()
default:
path := scanner.Text()
if !strings.HasPrefix(path, "/") {
// append the cwd
path = filepath.Join(cwd, path)
}
pathsCh <- path
}
}
return nil
}
if len(Cli.Paths) > 0 {
eg.Go(walkPaths)
} else if Cli.Stdin {
eg.Go(walkStdin)
} else {
// no explicit paths to process, so we only need to process root
pathsCh <- Cli.TreeRoot
@ -280,7 +299,7 @@ func walkFilesystem(ctx context.Context) func() error {
}
// create a filesystem walker
walker, err := walk.New(Cli.Walk, Cli.TreeRoot, pathsCh)
walker, err := walk.New(walkerType, Cli.TreeRoot, pathsCh)
if err != nil {
return fmt.Errorf("failed to create walker: %w", err)
}
@ -288,8 +307,8 @@ func walkFilesystem(ctx context.Context) func() error {
// close the files channel when we're done walking the file system
defer close(filesCh)
// if no cache has been configured, we invoke the walker directly
if Cli.NoCache {
// if no cache has been configured, or we are processing from stdin, we invoke the walker directly
if Cli.NoCache || Cli.Stdin {
return walker.Walk(ctx, func(file *walk.File, err error) error {
select {
case <-ctx.Done():

View File

@ -573,54 +573,49 @@ func TestStdIn(t *testing.T) {
// capture current cwd, so we can replace it after the test is finished
cwd, err := os.Getwd()
as.NoError(err)
t.Cleanup(func() {
// return to the previous working directory
as.NoError(os.Chdir(cwd))
})
tempDir := test.TempExamples(t)
configPath := filepath.Join(tempDir, "/treefmt.toml")
// change working directory to temp root
as.NoError(os.Chdir(tempDir))
// basic config
cfg := config.Config{
Formatters: map[string]*config.Formatter{
"echo": {
Command: "echo",
Includes: []string{"*"},
},
},
}
test.WriteConfig(t, configPath, cfg)
// swap out stdin
// capture current stdin and replace it on test cleanup
prevStdIn := os.Stdin
stdin, err := os.CreateTemp("", "stdin")
as.NoError(err)
os.Stdin = stdin
t.Cleanup(func() {
os.Stdin = prevStdIn
_ = os.Remove(stdin.Name())
})
go func() {
_, err := stdin.WriteString(`treefmt.toml
elm/elm.json
go/main.go
`)
as.NoError(err, "failed to write to stdin")
as.NoError(stdin.Sync())
_, _ = stdin.Seek(0, 0)
}()
//
contents := `{ foo, ... }: "hello"`
os.Stdin = test.TempFile(t, "", "stdin", &contents)
_, err = cmd(t, "-C", tempDir, "--stdin")
out, err := cmd(t, "-C", tempDir, "--allow-missing-formatter", "--stdin", "test.nix")
as.NoError(err)
assertStats(t, as, 3, 3, 3, 0)
assertStats(t, as, 1, 1, 1, 1)
// the nix formatters should have reduced the example to the following
as.Equal(`{ ...}: "hello"
`, string(out))
// try some markdown instead
contents = `
| col1 | col2 |
| ---- | ---- |
| nice | fits |
| oh no! | it's ugly |
`
os.Stdin = test.TempFile(t, "", "stdin", &contents)
out, err = cmd(t, "-C", tempDir, "--allow-missing-formatter", "--stdin", "test.md")
as.NoError(err)
assertStats(t, as, 1, 1, 1, 1)
as.Equal(`| col1 | col2 |
| ------ | --------- |
| nice | fits |
| oh no! | it's ugly |
`, string(out))
}
func TestDeterministicOrderingInPipeline(t *testing.T) {

View File

@ -4,7 +4,6 @@ import (
"fmt"
"io"
"os"
"path/filepath"
"testing"
"github.com/charmbracelet/log"
@ -42,7 +41,7 @@ func cmd(t *testing.T, args ...string) ([]byte, error) {
}
tempDir := t.TempDir()
tempOut := test.TempFile(t, filepath.Join(tempDir, "combined_output"))
tempOut := test.TempFile(t, tempDir, "combined_output", nil)
// capture standard outputs before swapping them
stdout := os.Stdout

View File

@ -65,7 +65,7 @@ func TestReadConfigFile(t *testing.T) {
deadnix, ok := cfg.Formatters["deadnix"]
as.True(ok, "deadnix formatter not found")
as.Equal("deadnix", deadnix.Command)
as.Nil(deadnix.Options)
as.Equal([]string{"-e"}, deadnix.Options)
as.Equal([]string{"*.nix"}, deadnix.Includes)
as.Nil(deadnix.Excludes)
as.Equal(2, deadnix.Priority)

View File

@ -35,6 +35,7 @@ priority = 1
[formatter.deadnix]
command = "deadnix"
options = ["-e"]
includes = ["*.nix"]
priority = 2

View File

@ -1,6 +1,7 @@
package test
import (
"io"
"os"
"testing"
@ -29,15 +30,34 @@ func TempExamples(t *testing.T) string {
return tempDir
}
func TempFile(t *testing.T, path string) *os.File {
func TempFile(t *testing.T, dir string, pattern string, contents *string) *os.File {
t.Helper()
file, err := os.Create(path)
if err != nil {
t.Fatalf("failed to create temporary file: %v", err)
file, err := os.CreateTemp(dir, pattern)
require.NoError(t, err, "failed to create temp file")
if contents == nil {
return file
}
_, err = file.WriteString(*contents)
require.NoError(t, err, "failed to write contents to temp file")
require.NoError(t, file.Close(), "failed to close temp file")
file, err = os.Open(file.Name())
require.NoError(t, err, "failed to open temp file")
return file
}
func ReadStdout(t *testing.T) string {
_, err := os.Stdout.Seek(0, 0)
require.NoError(t, err, "failed to seek to 0")
bytes, err := io.ReadAll(os.Stdout)
require.NoError(t, err, "failed to read")
return string(bytes)
}
func RecreateSymlink(t *testing.T, path string) error {
t.Helper()
src, err := os.Readlink(path)

View File

@ -2,6 +2,7 @@ package walk
import (
"context"
"fmt"
"io/fs"
"path/filepath"
)
@ -18,17 +19,27 @@ func (f filesystemWalker) Root() string {
func (f filesystemWalker) Walk(_ context.Context, fn WalkFunc) error {
relPathOffset := len(f.root) + 1
relPathFn := func(path string) (relPath string) {
relPathFn := func(path string) (string, error) {
// quick optimisation for the majority of use cases
// todo check that root is a prefix in path?
if len(path) >= relPathOffset {
relPath = path[relPathOffset:]
return path[relPathOffset:], nil
}
return
return filepath.Rel(f.root, path)
}
walkFn := func(path string, info fs.FileInfo, err error) error {
if info == nil {
return fmt.Errorf("no such file or directory '%s'", path)
}
relPath, err := relPathFn(path)
if err != nil {
return fmt.Errorf("failed to determine a relative path for %s: %w", path, err)
}
file := File{
Path: path,
RelPath: relPathFn(path),
RelPath: relPath,
Info: info,
}
return fn(&file, err)