package ast import ( "testing" "github.com/neilotoole/slogt" "github.com/antlr/antlr4/runtime/Go/antlr/v4" "github.com/stretchr/testify/require" "github.com/neilotoole/sq/libsq/ast/internal/slq" ) const ( fixtRowRange1 = `@mydb1 | .user | .uid, .username | .[]` fixtRowRange2 = `@mydb1 | .user | .uid, .username | .[2]` fixtRowRange3 = `@mydb1 | .user | .uid, .username | .[1:3]` fixtRowRange4 = `@mydb1 | .user | .uid, .username | .[0:3]` fixtRowRange5 = `@mydb1 | .user | .uid, .username | .[:3]` fixtRowRange6 = `@mydb1 | .user | .uid, .username | .[2:]` fixtJoinQuery1 = `@mydb1 | .user, .address | join(.user.uid == .address.uid) | .uid, .username, .country` fixtSelect1 = `@mydb1 | .user | .uid, .username` ) var slqInputs = map[string]string{ "rr1": `@mydb1 | .user | .uid, .username | .[]`, "rr2": `@mydb1 | .user | .uid, .username | .[2]`, "rr3": `@mydb1 | .user | .uid, .username | .[1:3]`, "rr4": `@mydb1 | .user | .uid, .username | .[0:3]`, "rr5": `@mydb1 | .user | .uid, .username | .[:3]`, "rr6": `@mydb1 | .user | .uid, .username | .[2:]`, "join with row range": `@my1 |.user, .address | join(.uid) | .[0:4] | .user.uid, .username, .country`, "join1": `@mydb1 | .user, .address | join(.user.uid == .address.uid) | .uid, .username, .country`, "select1": `@mydb1 | .user | .uid, .username`, "tbl datasource": `@mydb1.user | .uid, .username`, "count1": `@mydb1.user | count`, } // getSLQParser returns a parser for the given SQL input. func getSLQParser(input string) *slq.SLQParser { is := antlr.NewInputStream(input) lex := slq.NewSLQLexer(is) ts := antlr.NewCommonTokenStream(lex, 0) p := slq.NewSLQParser(ts) return p } // buildInitialAST returns a new AST created by parseTreeVisitor. The AST has not // yet been processed. func buildInitialAST(t *testing.T, input string) (*AST, error) { log := slogt.New(t) p := getSLQParser(input) q, _ := p.Query().(*slq.QueryContext) v := &parseTreeVisitor{log: log} err := q.Accept(v) if err != nil { return nil, err.(error) } return v.ast, nil } // mustParse builds a full AST from the input SLQ, or fails on any error. func mustParse(t *testing.T, input string) *AST { log := slogt.New(t) ast, err := Parse(log, input) require.NoError(t, err) return ast } func TestSimpleQuery(t *testing.T) { log := slogt.New(t) const input = fixtSelect1 ptree, err := parseSLQ(log, input) require.Nil(t, err) require.NotNil(t, ptree) ast, err := buildAST(log, ptree) require.Nil(t, err) require.NotNil(t, ast) } func TestParseBuild(t *testing.T) { for test, input := range slqInputs { test, input := test, input t.Run(test, func(t *testing.T) { t.Logf(input) log := slogt.New(t) ptree, err := parseSLQ(log, input) require.Nil(t, err) require.NotNil(t, ptree) ast, err := buildAST(log, ptree) require.Nil(t, err) require.NotNil(t, ast) }) } } func TestInspector_FindWhereClauses(t *testing.T) { log := slogt.New(t) // Verify that ".uid > 4" becomes a WHERE clause. const input = "@my1 | .actor | .uid > 4 | .uid, .username" ptree, err := parseSLQ(log, input) require.Nil(t, err) require.NotNil(t, ptree) nRoot, err := buildAST(log, ptree) require.Nil(t, err) insp := NewInspector(nRoot) whereNodes, err := insp.FindWhereClauses() require.NoError(t, err) require.Len(t, whereNodes, 1) }