sq/libsq/ast/parser_test.go

118 lines
3.0 KiB
Go
Raw Normal View History

2020-08-06 20:58:47 +03:00
package ast
import (
	"fmt"
	"testing"

	"github.com/antlr4-go/antlr/v4"
	"github.com/neilotoole/slogt"
	"github.com/stretchr/testify/require"

	"github.com/neilotoole/sq/libsq/ast/internal/slq"
	"github.com/neilotoole/sq/testh/tutil"
)
// getSLQParser returns a parser for the given SLQ input. The input string is
// wrapped in an ANTLR input stream, tokenized by the generated SLQ lexer, and
// the resulting token stream (channel 0) is handed to the SLQ parser.
func getSLQParser(input string) *slq.SLQParser {
	lexer := slq.NewSLQLexer(antlr.NewInputStream(input))
	tokens := antlr.NewCommonTokenStream(lexer, 0)
	return slq.NewSLQParser(tokens)
}
// buildInitialAST returns a new AST created by parseTreeVisitor. The AST has not
// yet been processed.
func buildInitialAST(t *testing.T, input string) (*AST, error) {
log := slogt.New(t)
2020-08-06 20:58:47 +03:00
p := getSLQParser(input)
q, _ := p.Query().(*slq.QueryContext)
2020-08-06 20:58:47 +03:00
v := &parseTreeVisitor{log: log}
err := q.Accept(v)
if err != nil {
return nil, err.(error)
}
return v.ast, nil
2020-08-06 20:58:47 +03:00
}
// mustParse builds a full AST from the input SLQ, or fails on any error.
func mustParse(t *testing.T, input string) *AST {
log := slogt.New(t)
2020-08-06 20:58:47 +03:00
ast, err := Parse(log, input)
require.NoError(t, err)
return ast
}
func TestSimpleQuery(t *testing.T) {
const q1 = `@mydb1 | .user | .uid, .username`
log := slogt.New(t)
ptree, err := parseSLQ(log, q1)
2020-08-06 20:58:47 +03:00
require.Nil(t, err)
require.NotNil(t, ptree)
ast, err := buildAST(log, ptree)
require.Nil(t, err)
require.NotNil(t, ast)
}
// TestParseBuild performs some basic testing of the parser.
// These tests are largely duplicates of other tests, and
// probably should be consolidated.
2020-08-06 20:58:47 +03:00
func TestParseBuild(t *testing.T) {
testCases := []struct {
name string
in string
}{
{"rr1", `@mydb1 | .user | .uid, .username | .[]`},
{"rr2", `@mydb1 | .user | .uid, .username | .[2]`},
{"rr3", `@mydb1 | .user | .uid, .username | .[1:3]`},
{"rr4", `@mydb1 | .user | .uid, .username | .[0:3]`},
{"rr5", `@mydb1 | .user | .uid, .username | .[:3]`},
{"rr6", `@mydb1 | .user | .uid, .username | .[2:]`},
{"join with row range", `@my1 |.user | join(.address, .uid) | .[0:4] | .user.uid, .username, .country`},
{"join1", `@mydb1 | .user | join(.address, .user.uid == .address.uid) | .uid, .username, .country`},
{"select1", `@mydb1 | .user | .uid, .username`},
{"tbl datasource", `@mydb1.user | .uid, .username`},
{"count1", `@mydb1.user | count`},
}
for i, tc := range testCases {
t.Run(tutil.Name(i, tc.name), func(t *testing.T) {
t.Logf(tc.in)
log := slogt.New(t)
ptree, err := parseSLQ(log, tc.in)
require.Nil(t, err)
require.NotNil(t, ptree)
2020-08-06 20:58:47 +03:00
ast, err := buildAST(log, ptree)
require.Nil(t, err)
require.NotNil(t, ast)
})
2020-08-06 20:58:47 +03:00
}
}
func TestInspector_FindWhereClauses(t *testing.T) {
log := slogt.New(t)
2020-08-06 20:58:47 +03:00
// Verify that "where(.uid > 4)" becomes a WHERE clause.
const input = "@my1 | .actor | where(.uid > 4) | .uid, .username"
2020-08-06 20:58:47 +03:00
ptree, err := parseSLQ(log, input)
require.Nil(t, err)
require.NotNil(t, ptree)
nRoot, err := buildAST(log, ptree)
require.Nil(t, err)
insp := NewInspector(nRoot)
2020-08-06 20:58:47 +03:00
whereNodes, err := insp.FindWhereClauses()
require.NoError(t, err)
require.Len(t, whereNodes, 1)
}