sq/libsq/ast/parser_test.go
Neil O'Toole a1a89ee9dd
Support table and column names with spaces. (#156)
* sakila: initial test data

* sakila: more test data

* sakila: yet more test data setup

* whitespace cols: now working for sqlite

* grammar cleanup

* whitespace cols: now working inside count() func for sqlite

* whitespace cols: tests mostly passing; beginning refactoring

* grammar: refactor handle

* grammar: more refactoring

* grammar: rename selElement to selector

* wip

* all tests passing

* all tests passing

* linting

* driver: implement CurrentSchema for all driver.SQLDriver impls

* driver: tests for AlterTableRename and AlterTableRenameColumn

* undo reformat of SQL

* undo reformat of SQL

* undo reformat of SQL

* undo reformat of SQL
2023-03-22 00:17:34 -06:00

123 lines
3.5 KiB
Go

package ast
import (
"testing"
"github.com/antlr/antlr4/runtime/Go/antlr/v4"
"github.com/stretchr/testify/require"
"github.com/neilotoole/lg/testlg"
"github.com/neilotoole/sq/libsq/ast/internal/slq"
)
// Fixture SLQ queries shared by tests in this package.
const (
	// Row-range (slice) variants of a simple two-column select.
	fixtRowRange1 = `@mydb1 | .user | .uid, .username | .[]`
	fixtRowRange2 = `@mydb1 | .user | .uid, .username | .[2]`
	fixtRowRange3 = `@mydb1 | .user | .uid, .username | .[1:3]`
	fixtRowRange4 = `@mydb1 | .user | .uid, .username | .[0:3]`
	fixtRowRange5 = `@mydb1 | .user | .uid, .username | .[:3]`
	fixtRowRange6 = `@mydb1 | .user | .uid, .username | .[2:]`

	// Two-table join on an explicit equality expression.
	fixtJoinQuery1 = `@mydb1 | .user, .address | join(.user.uid == .address.uid) | .uid, .username, .country`

	// Plain two-column select from a single table.
	fixtSelect1 = `@mydb1 | .user | .uid, .username`
)

// slqInputs maps a descriptive test name to an SLQ query. Entries that
// duplicate a fixture constant above reference that constant directly, so
// the two can never drift apart.
var slqInputs = map[string]string{
	"rr1":                 fixtRowRange1,
	"rr2":                 fixtRowRange2,
	"rr3":                 fixtRowRange3,
	"rr4":                 fixtRowRange4,
	"rr5":                 fixtRowRange5,
	"rr6":                 fixtRowRange6,
	"join with row range": `@my1 |.user, .address | join(.uid) | .[0:4] | .user.uid, .username, .country`,
	"join1":               fixtJoinQuery1,
	"select1":             fixtSelect1,
	"tbl datasource":      `@mydb1.user | .uid, .username`,
	"count1":              `@mydb1.user | count(*)`,
}
// getSLQParser returns a parser for the given SLQ input.
//
// A fresh input stream, lexer and token stream (reading channel 0) are
// constructed over input; the returned parser has not yet parsed anything.
func getSLQParser(input string) *slq.SLQParser {
	lexer := slq.NewSLQLexer(antlr.NewInputStream(input))
	return slq.NewSLQParser(antlr.NewCommonTokenStream(lexer, 0))
}
// buildInitialAST returns a new AST created by parseTreeVisitor. The AST has not
// yet been processed.
//
// The input is parsed via getSLQParser, and the resulting query parse tree is
// walked by a parseTreeVisitor using a strict test logger. If the visitor
// returns an error, that error is returned along with a nil AST.
//
// Rather than panicking, the test fails with a descriptive message if the
// parse tree root is not a *slq.QueryContext, or if the visitor returns a
// non-nil value that is not an error.
func buildInitialAST(t *testing.T, input string) (*AST, error) {
	t.Helper()

	log := testlg.New(t).Strict(true)
	p := getSLQParser(input)

	tree := p.Query()
	q, ok := tree.(*slq.QueryContext)
	require.Truef(t, ok, "expected parse tree root of type *slq.QueryContext but got %T", tree)

	v := &parseTreeVisitor{log: log}
	res := q.Accept(v)
	if res == nil {
		return v.ast, nil
	}

	// Accept returns an interface value. The original code treated any
	// non-nil result as an error, so anything else is an unexpected visitor
	// result and should fail loudly instead of panicking.
	err, ok := res.(error)
	require.Truef(t, ok, "expected parseTreeVisitor to return nil or error but got %T", res)
	return nil, err
}
// mustParse builds a full AST from the input SLQ, or fails the test on any
// error. It is marked as a test helper so that a failure is reported at the
// caller's line rather than inside this function.
func mustParse(t *testing.T, input string) *AST {
	t.Helper()

	log := testlg.New(t).Strict(true)
	ast, err := Parse(log, input)
	require.NoError(t, err)
	return ast
}
// TestSimpleQuery verifies that a basic select query (fixtSelect1) parses
// into a non-nil parse tree, and that a non-nil AST can be built from it.
func TestSimpleQuery(t *testing.T) {
	log := testlg.New(t).Strict(true)
	const input = fixtSelect1

	ptree, err := parseSLQ(log, input)
	require.NoError(t, err)
	require.NotNil(t, ptree)

	ast, err := buildAST(log, ptree)
	require.NoError(t, err)
	require.NotNil(t, ast)
}
// TestParseBuild verifies that every query in slqInputs parses into a
// non-nil parse tree and builds into a non-nil AST. Each entry runs as its
// own subtest, named by its map key.
func TestParseBuild(t *testing.T) {
	for test, input := range slqInputs {
		test, input := test, input
		t.Run(test, func(t *testing.T) {
			// Log rather than Logf: input is data, not a format string, and
			// a '%' in a query would otherwise be misinterpreted as a verb.
			t.Log(input)

			log := testlg.New(t)
			ptree, err := parseSLQ(log, input)
			require.NoError(t, err)
			require.NotNil(t, ptree)

			ast, err := buildAST(log, ptree)
			require.NoError(t, err)
			require.NotNil(t, ast)
		})
	}
}
// TestInspector_FindWhereClauses verifies that the Inspector finds exactly
// one WHERE clause in a query containing a single filter expression.
func TestInspector_FindWhereClauses(t *testing.T) {
	log := testlg.New(t)

	// Verify that ".uid > 4" becomes a WHERE clause.
	const input = "@my1 | .tbluser | .uid > 4 | .uid, .username"

	ptree, err := parseSLQ(log, input)
	require.NoError(t, err)
	require.NotNil(t, ptree)

	nRoot, err := buildAST(log, ptree)
	require.NoError(t, err)
	require.NotNil(t, nRoot)

	insp := NewInspector(log, nRoot)
	whereNodes, err := insp.FindWhereClauses()
	require.NoError(t, err)
	require.Len(t, whereNodes, 1)
}