SLQ support for column aliases (#150)

* alias: more early work

* alias: test cases working for sqlite

* alias: SQL builder tests

* alias: func (col expr) aliases now working for SQLite

* linting

* CHANGELOG update

* Docs update

* Docs update

* Rename buildAst() -> buildAST()

* CHANGELOG typo
This commit is contained in:
Neil O'Toole 2023-03-18 22:58:00 -06:00 committed by GitHub
parent 62f067f633
commit d3e6f89829
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
31 changed files with 1261 additions and 388 deletions

View File

@ -5,6 +5,13 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [v0.25.0] - 2023-03-18
### Added
- [#15] Column Aliases. You can now change specify an alias for a column (or column expression
such as a function). For example: `sq '.actor | .first_name:given_name`, or `sq .actor | count(*):quantity`.
## [v0.24.4] - 2023-03-15 ## [v0.24.4] - 2023-03-15
### Fixed ### Fixed
@ -159,6 +166,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- [#89]: Bug with SQL generated for joins. - [#89]: Bug with SQL generated for joins.
[v0.25.0]: https://github.com/neilotoole/sq/compare/v0.24.4...v0.25.0
[v0.24.4]: https://github.com/neilotoole/sq/compare/v0.24.3...v0.24.4 [v0.24.4]: https://github.com/neilotoole/sq/compare/v0.24.3...v0.24.4
[v0.24.3]: https://github.com/neilotoole/sq/compare/v0.24.2...v0.24.3 [v0.24.3]: https://github.com/neilotoole/sq/compare/v0.24.2...v0.24.3
[v0.24.2]: https://github.com/neilotoole/sq/compare/v0.24.1...v0.24.2 [v0.24.2]: https://github.com/neilotoole/sq/compare/v0.24.1...v0.24.2
@ -184,3 +192,4 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
[#95]: https://github.com/neilotoole/sq/issues/93 [#95]: https://github.com/neilotoole/sq/issues/93
[#91]: https://github.com/neilotoole/sq/pull/91 [#91]: https://github.com/neilotoole/sq/pull/91
[#89]: https://github.com/neilotoole/sq/pull/89 [#89]: https://github.com/neilotoole/sq/pull/89
[#15]: https://github.com/neilotoole/sq/issues/15

View File

@ -32,8 +32,8 @@ When adding a data source, LOCATION is the only required arg.
# Add a postgres source with handle "@sakila_pg" # Add a postgres source with handle "@sakila_pg"
$ sq add -h @sakila_pg 'postgres://user:pass@localhost/sakila' $ sq add -h @sakila_pg 'postgres://user:pass@localhost/sakila'
The format of LOCATION varies, but is generally a DB connection string, a The format of LOCATION is driver-specific,but is generally a DB connection
file path, or a URL. string, a file path, or a URL.
DRIVER://USER:PASS@HOST:PORT/DBNAME DRIVER://USER:PASS@HOST:PORT/DBNAME
/path/to/local/file.ext /path/to/local/file.ext
@ -74,7 +74,7 @@ is ambiguous, explicitly specify the driver type.
$ sq add --driver=tsv ./mystery.data $ sq add --driver=tsv ./mystery.data
Available source driver types can be listed via "sq drivers". At a Available source driver types can be listed via "sq driver ls". At a
minimum, the following drivers are bundled: minimum, the following drivers are bundled:
sqlite3 SQLite sqlite3 SQLite
@ -88,6 +88,9 @@ minimum, the following drivers are bundled:
jsonl JSON Lines: LF-delimited JSON objects jsonl JSON Lines: LF-delimited JSON objects
xlsx Microsoft Excel XLSX xlsx Microsoft Excel XLSX
If there isn't already an active source, the newly added source becomes the
active source.
More examples: More examples:
# Add a source, but prompt user for password # Add a source, but prompt user for password

View File

@ -82,7 +82,7 @@ const (
flagTableUsage = "Output text table" flagTableUsage = "Output text table"
flagTblData = "data" flagTblData = "data"
flagTblDataUsage = "Copy table data (default true)" flagTblDataUsage = "Copy table data"
flagPingTimeout = "timeout" flagPingTimeout = "timeout"
flagPingTimeoutUsage = "Max time to wait for ping" flagPingTimeoutUsage = "Max time to wait for ping"

View File

@ -0,0 +1,85 @@
package mysql_test
import (
"testing"
"github.com/neilotoole/sq/libsq"
"github.com/stretchr/testify/require"
_ "github.com/mattn/go-sqlite3"
"github.com/neilotoole/sq/testh"
"github.com/neilotoole/sq/testh/sakila"
)
func TestSLQ2SQL(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
handles []string
slq string
wantSQL string
wantErr bool
}{
{
name: "join",
handles: []string{sakila.My8},
slq: `@sakila_my8 | .actor, .film_actor | join(.film_actor.actor_id == .actor.actor_id)`,
wantSQL: "SELECT * FROM `actor` INNER JOIN `film_actor` ON `film_actor`.`actor_id` = `actor`.`actor_id`",
},
{
name: "select-cols",
handles: []string{sakila.My8},
slq: `@sakila_my8 | .actor | .first_name, .last_name`,
wantSQL: "SELECT `first_name`, `last_name` FROM `actor`",
},
{
name: "select-cols-aliases",
handles: []string{sakila.My8},
slq: `@sakila_my8 | .actor | .first_name:given_name, .last_name:family_name`,
wantSQL: "SELECT `first_name` AS `given_name`, `last_name` AS `family_name` FROM `actor`",
},
{
name: "select-count-star",
handles: []string{sakila.My8},
slq: `@sakila_my8 | .actor | count(*)`,
wantSQL: "SELECT COUNT(*) FROM `actor`",
},
{
name: "select-count",
handles: []string{sakila.My8},
slq: `@sakila_my8 | .actor | count()`,
wantSQL: "SELECT COUNT(*) FROM `actor`",
},
{
name: "select-count-alias",
handles: []string{sakila.My8},
slq: `@sakila_my8 | .actor | count(*):quantity`,
wantSQL: "SELECT COUNT(*) AS `quantity` FROM `actor`",
},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
th := testh.New(t)
srcs := th.NewSourceSet(tc.handles...)
_, err := srcs.SetActive(tc.handles[0])
require.NoError(t, err)
dbases := th.Databases()
gotSQL, gotErr := libsq.SLQ2SQL(th.Context, th.Log, dbases, dbases, srcs, tc.slq)
if tc.wantErr {
require.Error(t, gotErr)
return
}
require.NoError(t, gotErr)
require.Equal(t, tc.wantSQL, gotSQL)
})
}
}

View File

@ -0,0 +1,85 @@
package postgres_test
import (
"testing"
"github.com/neilotoole/sq/libsq"
"github.com/stretchr/testify/require"
_ "github.com/mattn/go-sqlite3"
"github.com/neilotoole/sq/testh"
"github.com/neilotoole/sq/testh/sakila"
)
func TestSLQ2SQL(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
handles []string
slq string
wantSQL string
wantErr bool
}{
{
name: "join",
handles: []string{sakila.Pg12},
slq: `@sakila_pg12 | .actor, .film_actor | join(.film_actor.actor_id == .actor.actor_id)`,
wantSQL: `SELECT * FROM "actor" INNER JOIN "film_actor" ON "film_actor"."actor_id" = "actor"."actor_id"`,
},
{
name: "select-cols",
handles: []string{sakila.Pg12},
slq: `@sakila_pg12 | .actor | .first_name, .last_name`,
wantSQL: `SELECT "first_name", "last_name" FROM "actor"`,
},
{
name: "select-cols-aliases",
handles: []string{sakila.Pg12},
slq: `@sakila_pg12 | .actor | .first_name:given_name, .last_name:family_name`,
wantSQL: `SELECT "first_name" AS "given_name", "last_name" AS "family_name" FROM "actor"`,
},
{
name: "select-count-star",
handles: []string{sakila.Pg12},
slq: `@sakila_pg12 | .actor | count(*)`,
wantSQL: `SELECT COUNT(*) FROM "actor"`,
},
{
name: "select-count",
handles: []string{sakila.Pg12},
slq: `@sakila_pg12 | .actor | count()`,
wantSQL: `SELECT COUNT(*) FROM "actor"`,
},
{
name: "select-count-alias",
handles: []string{sakila.Pg12},
slq: `@sakila_pg12 | .actor | count(*):quantity`,
wantSQL: `SELECT COUNT(*) AS "quantity" FROM "actor"`,
},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
th := testh.New(t)
srcs := th.NewSourceSet(tc.handles...)
_, err := srcs.SetActive(tc.handles[0])
require.NoError(t, err)
dbases := th.Databases()
gotSQL, gotErr := libsq.SLQ2SQL(th.Context, th.Log, dbases, dbases, srcs, tc.slq)
if tc.wantErr {
require.Error(t, gotErr)
return
}
require.NoError(t, gotErr)
require.Equal(t, tc.wantSQL, gotSQL)
})
}
}

View File

@ -0,0 +1,85 @@
package sqlite3_test
import (
"testing"
"github.com/neilotoole/sq/libsq"
"github.com/stretchr/testify/require"
_ "github.com/mattn/go-sqlite3"
"github.com/neilotoole/sq/testh"
"github.com/neilotoole/sq/testh/sakila"
)
func TestSLQ2SQL(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
handles []string
slq string
wantSQL string
wantErr bool
}{
{
name: "join",
handles: []string{sakila.SL3},
slq: `@sakila_sl3 | .actor, .film_actor | join(.film_actor.actor_id == .actor.actor_id)`,
wantSQL: `SELECT * FROM "actor" INNER JOIN "film_actor" ON "film_actor"."actor_id" = "actor"."actor_id"`,
},
{
name: "select-cols",
handles: []string{sakila.SL3},
slq: `@sakila_sl3 | .actor | .first_name, .last_name`,
wantSQL: `SELECT "first_name", "last_name" FROM "actor"`,
},
{
name: "select-cols-aliases",
handles: []string{sakila.SL3},
slq: `@sakila_sl3 | .actor | .first_name:given_name, .last_name:family_name`,
wantSQL: `SELECT "first_name" AS "given_name", "last_name" AS "family_name" FROM "actor"`,
},
{
name: "select-count-star",
handles: []string{sakila.SL3},
slq: `@sakila_sl3 | .actor | count(*)`,
wantSQL: `SELECT COUNT(*) FROM "actor"`,
},
{
name: "select-count",
handles: []string{sakila.SL3},
slq: `@sakila_sl3 | .actor | count()`,
wantSQL: `SELECT COUNT(*) FROM "actor"`,
},
{
name: "select-count-alias",
handles: []string{sakila.SL3},
slq: `@sakila_sl3 | .actor | count(*):quantity`,
wantSQL: `SELECT COUNT(*) AS "quantity" FROM "actor"`,
},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
th := testh.New(t)
srcs := th.NewSourceSet(tc.handles...)
_, err := srcs.SetActive(tc.handles[0])
require.NoError(t, err)
dbases := th.Databases()
gotSQL, gotErr := libsq.SLQ2SQL(th.Context, th.Log, dbases, dbases, srcs, tc.slq)
if tc.wantErr {
require.Error(t, gotErr)
return
}
require.NoError(t, gotErr)
require.Equal(t, tc.wantSQL, gotSQL)
})
}
}

View File

@ -0,0 +1,85 @@
package sqlserver_test
import (
"testing"
"github.com/neilotoole/sq/libsq"
"github.com/stretchr/testify/require"
_ "github.com/mattn/go-sqlite3"
"github.com/neilotoole/sq/testh"
"github.com/neilotoole/sq/testh/sakila"
)
func TestSLQ2SQL(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
handles []string
slq string
wantSQL string
wantErr bool
}{
{
name: "join",
handles: []string{sakila.MS17},
slq: `@sakila_ms17 | .actor, .film_actor | join(.film_actor.actor_id == .actor.actor_id)`,
wantSQL: `SELECT * FROM "actor" INNER JOIN "film_actor" ON "film_actor"."actor_id" = "actor"."actor_id"`,
},
{
name: "select-cols",
handles: []string{sakila.MS17},
slq: `@sakila_ms17 | .actor | .first_name, .last_name`,
wantSQL: `SELECT "first_name", "last_name" FROM "actor"`,
},
{
name: "select-cols-aliases",
handles: []string{sakila.MS17},
slq: `@sakila_ms17 | .actor | .first_name:given_name, .last_name:family_name`,
wantSQL: `SELECT "first_name" AS "given_name", "last_name" AS "family_name" FROM "actor"`,
},
{
name: "select-count-star",
handles: []string{sakila.MS17},
slq: `@sakila_ms17 | .actor | count(*)`,
wantSQL: `SELECT COUNT(*) FROM "actor"`,
},
{
name: "select-count",
handles: []string{sakila.MS17},
slq: `@sakila_ms17 | .actor | count()`,
wantSQL: `SELECT COUNT(*) FROM "actor"`,
},
{
name: "select-count-alias",
handles: []string{sakila.MS17},
slq: `@sakila_ms17 | .actor | count(*):quantity`,
wantSQL: `SELECT COUNT(*) AS "quantity" FROM "actor"`,
},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
th := testh.New(t)
srcs := th.NewSourceSet(tc.handles...)
_, err := srcs.SetActive(tc.handles[0])
require.NoError(t, err)
dbases := th.Databases()
gotSQL, gotErr := libsq.SLQ2SQL(th.Context, th.Log, dbases, dbases, srcs, tc.slq)
if tc.wantErr {
require.Error(t, gotErr)
return
}
require.NoError(t, gotErr)
require.Equal(t, tc.wantSQL, gotSQL)
})
}
}

View File

@ -1,7 +1,8 @@
// This is the grammar for SLQ, the query language used by sq (https://sq.io).
// The grammar is not yet finalized; it is subject to change in any new sq release.
grammar SLQ; grammar SLQ;
// "@mysql_db1 | .user, .address | join(.user.uid == .address.uid) | .[0:3] | .uid, .username, .country" // "@mysql_db1 | .user, .address | join(.user.uid == .address.uid) | .[0:3] | .uid, .username, .country"
stmtList: ';'* query ( ';'+ query)* ';'*; stmtList: ';'* query ( ';'+ query)* ';'*;
query: segment ('|' segment)*; query: segment ('|' segment)*;
@ -15,16 +16,16 @@ element:
| join | join
| group | group
| rowRange | rowRange
| fn | fnElement
| expr; | expr;
cmpr: LT_EQ | LT | GT_EQ | GT | EQ | NEQ; cmpr: LT_EQ | LT | GT_EQ | GT | EQ | NEQ;
//whereExpr
// : expr ;
fn: fnName '(' ( expr ( ',' expr)* | '*')? ')'; fn: fnName '(' ( expr ( ',' expr)* | '*')? ')';
fnElement: fn (alias)?;
join: ('join' | 'JOIN' | 'j') '(' joinConstraint ')'; join: ('join' | 'JOIN' | 'j') '(' joinConstraint ')';
joinConstraint: joinConstraint:
@ -33,12 +34,21 @@ joinConstraint:
group: ('group' | 'GROUP' | 'g') '(' SEL (',' SEL)* ')'; group: ('group' | 'GROUP' | 'g') '(' SEL (',' SEL)* ')';
selElement: SEL; // alias, for columns, implements "col AS alias".
// For example: ".first_name:given_name" : "given_name" is the alias.
alias: ':' ID;
selElement: SEL (alias)?;
dsTblElement: dsTblElement:
DATASOURCE SEL; // data source table element, e.g. @my1.user // dsTblElement is a data source table element. This is a data
// source with followed by a table.
// - @my1.user
DATASOURCE SEL;
dsElement: DATASOURCE; // data source element, e.g. @my1 dsElement:
// dsElement is a data source element, e.g. @my1
DATASOURCE;
// [] select all rows [10] select row 10 [10:15] select rows 10 thru 15 [0:15] select rows 0 thru 15 // [] select all rows [10] select row 10 [10:15] select rows 10 thru 15 [0:15] select rows 0 thru 15
// [:15] same as above (0 thru 15) [10:] select all rows from 10 onwards // [:15] same as above (0 thru 15) [10:] select all rows from 10 onwards
@ -107,6 +117,8 @@ GT: '>';
NEQ: '!='; NEQ: '!=';
EQ: '=='; EQ: '==';
SEL: SEL:
'.' ID ('.' ID)*; // SEL can be .THING or .THING.OTHERTHING etc. '.' ID ('.' ID)*; // SEL can be .THING or .THING.OTHERTHING etc.
DATASOURCE: DATASOURCE:

View File

@ -0,0 +1 @@
@sakila | .actor | .first_name:given_name, .last_name:family_name

View File

@ -14,18 +14,14 @@ import (
) )
// Parse parses the SLQ input string and builds the AST. // Parse parses the SLQ input string and builds the AST.
func Parse(log lg.Log, input string) (*AST, error) { func Parse(log lg.Log, input string) (*AST, error) { //nolint:staticcheck
log = lg.Discard() //nolint:staticcheck // Disable parser logging.
ptree, err := parseSLQ(log, input) ptree, err := parseSLQ(log, input)
if err != nil { if err != nil {
return nil, err return nil, err
} }
atree, err := buildAST(log, ptree) return buildAST(log, ptree)
if err != nil {
return nil, err
}
return atree, nil
} }
// buildAST constructs sq's AST from a parse tree. // buildAST constructs sq's AST from a parse tree.
@ -39,7 +35,9 @@ func buildAST(log lg.Log, query slq.IQueryContext) (*AST, error) {
return nil, errorf("unable to convert %T to *parser.QueryContext", query) return nil, errorf("unable to convert %T to *parser.QueryContext", query)
} }
v := &parseTreeVisitor{log: lg.Discard()} v := &parseTreeVisitor{log: log}
// Accept returns an interface{} instead of error (but it's always an error?)
er := q.Accept(v) er := q.Accept(v)
if er != nil { if er != nil {
return nil, er.(error) return nil, er.(error)

View File

@ -9,6 +9,7 @@ var (
type Func struct { type Func struct {
baseNode baseNode
fnName string fnName string
alias string
} }
// FuncName returns the function name. // FuncName returns the function name.
@ -16,8 +17,13 @@ func (fn *Func) FuncName() string {
return fn.fnName return fn.fnName
} }
// String returns a log/debug-friendly representation.
func (fn *Func) String() string { func (fn *Func) String() string {
return nodeString(fn) str := nodeString(fn)
if fn.alias != "" {
str += ":" + fn.alias
}
return str
} }
// ColExpr implements ColExpr. // ColExpr implements ColExpr.
@ -25,6 +31,11 @@ func (fn *Func) ColExpr() (string, error) {
return fn.ctx.GetText(), nil return fn.ctx.GetText(), nil
} }
// Alias implements ColExpr.
func (fn *Func) Alias() string {
return fn.alias
}
// SetChildren implements Node. // SetChildren implements Node.
func (fn *Func) SetChildren(children []Node) error { func (fn *Func) SetChildren(children []Node) error {
fn.setChildren(children) fn.setChildren(children)

File diff suppressed because one or more lines are too long

View File

@ -56,6 +56,12 @@ func (s *BaseSLQListener) EnterFn(ctx *FnContext) {}
// ExitFn is called when production fn is exited. // ExitFn is called when production fn is exited.
func (s *BaseSLQListener) ExitFn(ctx *FnContext) {} func (s *BaseSLQListener) ExitFn(ctx *FnContext) {}
// EnterFnElement is called when production fnElement is entered.
func (s *BaseSLQListener) EnterFnElement(ctx *FnElementContext) {}
// ExitFnElement is called when production fnElement is exited.
func (s *BaseSLQListener) ExitFnElement(ctx *FnElementContext) {}
// EnterJoin is called when production join is entered. // EnterJoin is called when production join is entered.
func (s *BaseSLQListener) EnterJoin(ctx *JoinContext) {} func (s *BaseSLQListener) EnterJoin(ctx *JoinContext) {}
@ -74,6 +80,12 @@ func (s *BaseSLQListener) EnterGroup(ctx *GroupContext) {}
// ExitGroup is called when production group is exited. // ExitGroup is called when production group is exited.
func (s *BaseSLQListener) ExitGroup(ctx *GroupContext) {} func (s *BaseSLQListener) ExitGroup(ctx *GroupContext) {}
// EnterAlias is called when production alias is entered.
func (s *BaseSLQListener) EnterAlias(ctx *AliasContext) {}
// ExitAlias is called when production alias is exited.
func (s *BaseSLQListener) ExitAlias(ctx *AliasContext) {}
// EnterSelElement is called when production selElement is entered. // EnterSelElement is called when production selElement is entered.
func (s *BaseSLQListener) EnterSelElement(ctx *SelElementContext) {} func (s *BaseSLQListener) EnterSelElement(ctx *SelElementContext) {}

View File

@ -31,6 +31,10 @@ func (v *BaseSLQVisitor) VisitFn(ctx *FnContext) interface{} {
return v.VisitChildren(ctx) return v.VisitChildren(ctx)
} }
func (v *BaseSLQVisitor) VisitFnElement(ctx *FnElementContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BaseSLQVisitor) VisitJoin(ctx *JoinContext) interface{} { func (v *BaseSLQVisitor) VisitJoin(ctx *JoinContext) interface{} {
return v.VisitChildren(ctx) return v.VisitChildren(ctx)
} }
@ -43,6 +47,10 @@ func (v *BaseSLQVisitor) VisitGroup(ctx *GroupContext) interface{} {
return v.VisitChildren(ctx) return v.VisitChildren(ctx)
} }
func (v *BaseSLQVisitor) VisitAlias(ctx *AliasContext) interface{} {
return v.VisitChildren(ctx)
}
func (v *BaseSLQVisitor) VisitSelElement(ctx *SelElementContext) interface{} { func (v *BaseSLQVisitor) VisitSelElement(ctx *SelElementContext) interface{} {
return v.VisitChildren(ctx) return v.VisitChildren(ctx)
} }

View File

@ -25,6 +25,9 @@ type SLQListener interface {
// EnterFn is called when entering the fn production. // EnterFn is called when entering the fn production.
EnterFn(c *FnContext) EnterFn(c *FnContext)
// EnterFnElement is called when entering the fnElement production.
EnterFnElement(c *FnElementContext)
// EnterJoin is called when entering the join production. // EnterJoin is called when entering the join production.
EnterJoin(c *JoinContext) EnterJoin(c *JoinContext)
@ -34,6 +37,9 @@ type SLQListener interface {
// EnterGroup is called when entering the group production. // EnterGroup is called when entering the group production.
EnterGroup(c *GroupContext) EnterGroup(c *GroupContext)
// EnterAlias is called when entering the alias production.
EnterAlias(c *AliasContext)
// EnterSelElement is called when entering the selElement production. // EnterSelElement is called when entering the selElement production.
EnterSelElement(c *SelElementContext) EnterSelElement(c *SelElementContext)
@ -76,6 +82,9 @@ type SLQListener interface {
// ExitFn is called when exiting the fn production. // ExitFn is called when exiting the fn production.
ExitFn(c *FnContext) ExitFn(c *FnContext)
// ExitFnElement is called when exiting the fnElement production.
ExitFnElement(c *FnElementContext)
// ExitJoin is called when exiting the join production. // ExitJoin is called when exiting the join production.
ExitJoin(c *JoinContext) ExitJoin(c *JoinContext)
@ -85,6 +94,9 @@ type SLQListener interface {
// ExitGroup is called when exiting the group production. // ExitGroup is called when exiting the group production.
ExitGroup(c *GroupContext) ExitGroup(c *GroupContext)
// ExitAlias is called when exiting the alias production.
ExitAlias(c *AliasContext)
// ExitSelElement is called when exiting the selElement production. // ExitSelElement is called when exiting the selElement production.
ExitSelElement(c *SelElementContext) ExitSelElement(c *SelElementContext)

File diff suppressed because it is too large Load Diff

View File

@ -25,6 +25,9 @@ type SLQVisitor interface {
// Visit a parse tree produced by SLQParser#fn. // Visit a parse tree produced by SLQParser#fn.
VisitFn(ctx *FnContext) interface{} VisitFn(ctx *FnContext) interface{}
// Visit a parse tree produced by SLQParser#fnElement.
VisitFnElement(ctx *FnElementContext) interface{}
// Visit a parse tree produced by SLQParser#join. // Visit a parse tree produced by SLQParser#join.
VisitJoin(ctx *JoinContext) interface{} VisitJoin(ctx *JoinContext) interface{}
@ -34,6 +37,9 @@ type SLQVisitor interface {
// Visit a parse tree produced by SLQParser#group. // Visit a parse tree produced by SLQParser#group.
VisitGroup(ctx *GroupContext) interface{} VisitGroup(ctx *GroupContext) interface{}
// Visit a parse tree produced by SLQParser#alias.
VisitAlias(ctx *AliasContext) interface{}
// Visit a parse tree produced by SLQParser#selElement. // Visit a parse tree produced by SLQParser#selElement.
VisitSelElement(ctx *SelElementContext) interface{} VisitSelElement(ctx *SelElementContext) interface{}

View File

@ -15,7 +15,7 @@ type Node interface {
// SetParent sets the node's parent, returning an error if illegal. // SetParent sets the node's parent, returning an error if illegal.
SetParent(n Node) error SetParent(n Node) error
// Children returns the node's children (may be empty). // Children returns the node's children (which may be empty).
Children() []Node Children() []Node
// SetChildren sets the node's children, returning an error if illegal. // SetChildren sets the node's children, returning an error if illegal.
@ -49,8 +49,17 @@ type Selectable interface {
type ColExpr interface { type ColExpr interface {
// IsColName returns true if the expr is a column name, e.g. "uid" or "users.uid". // IsColName returns true if the expr is a column name, e.g. "uid" or "users.uid".
IsColName() bool IsColName() bool
// ColExpr returns the column expression value. For a simple ColSelector ".first_name",
// this would be "first_name".
ColExpr() (string, error) ColExpr() (string, error)
// String returns a log/debug-friendly representation.
String() string String() string
// Alias returns the column alias, which may be empty.
// For example, given the selector ".first_name:given_name", the alias is "given_name".
Alias() string
} }
// baseNode is a base implementation of Node. // baseNode is a base implementation of Node.
@ -60,15 +69,18 @@ type baseNode struct {
ctx antlr.ParseTree ctx antlr.ParseTree
} }
// Parent implements Node.Parent.
func (bn *baseNode) Parent() Node { func (bn *baseNode) Parent() Node {
return bn.parent return bn.parent
} }
// SetParent implements Node.SetParent.
func (bn *baseNode) SetParent(parent Node) error { func (bn *baseNode) SetParent(parent Node) error {
bn.parent = parent bn.parent = parent
return nil return nil
} }
// Children implements Node.Children.
func (bn *baseNode) Children() []Node { func (bn *baseNode) Children() []Node {
return bn.children return bn.children
} }
@ -111,9 +123,9 @@ func nodeString(n Node) string {
return fmt.Sprintf("%T: %s", n, n.Text()) return fmt.Sprintf("%T: %s", n, n.Text())
} }
// replaceNode replaces old with new. That is, nu becomes a child // nodeReplace replaces old with new. That is, nu becomes a child
// of old's parent. // of old's parent.
func replaceNode(old, nu Node) error { func nodeReplace(old, nu Node) error {
err := nu.SetContext(old.Context()) err := nu.SetContext(old.Context())
if err != nil { if err != nil {
return err return err
@ -121,7 +133,7 @@ func replaceNode(old, nu Node) error {
parent := old.Parent() parent := old.Parent()
index := childIndex(parent, old) index := nodeChildIndex(parent, old)
if index < 0 { if index < 0 {
return errorf("parent %T(%q) does not appear to have child %T(%q)", parent, parent.Text(), old, old.Text()) return errorf("parent %T(%q) does not appear to have child %T(%q)", parent, parent.Text(), old, old.Text())
} }
@ -131,18 +143,43 @@ func replaceNode(old, nu Node) error {
return parent.SetChildren(siblings) return parent.SetChildren(siblings)
} }
// childIndex returns the index of child in parent's children, or -1. // nodeChildIndex returns the index of child in parent's children, or -1.
func childIndex(parent, child Node) int { func nodeChildIndex(parent, child Node) int {
index := -1
for i, node := range parent.Children() { for i, node := range parent.Children() {
if node == child { if node == child {
index = i return i
break
} }
} }
return index return -1
}
// nodeFirstChild returns the first child of parent, or nil.
func nodeFirstChild(parent Node) Node { //nolint:unused
if parent == nil {
return nil
}
children := parent.Children()
if len(children) == 0 {
return nil
}
return children[0]
}
// nodeFirstChild returns the last child of parent, or nil.
func nodeLastChild(parent Node) Node {
if parent == nil {
return nil
}
children := parent.Children()
if len(children) == 0 {
return nil
}
return children[len(children)-1]
} }
// nodesWithType returns a new slice containing each member of nodes that is // nodesWithType returns a new slice containing each member of nodes that is

View File

@ -20,7 +20,7 @@ func TestChildIndex(t *testing.T) {
require.Equal(t, 4, len(ast.Segments())) require.Equal(t, 4, len(ast.Segments()))
for i, seg := range ast.Segments() { for i, seg := range ast.Segments() {
index := childIndex(ast, seg) index := nodeChildIndex(ast, seg)
require.Equal(t, i, index) require.Equal(t, i, index)
} }
} }
@ -36,6 +36,6 @@ func TestNodesWithType(t *testing.T) {
func TestAvg(t *testing.T) { func TestAvg(t *testing.T) {
const input = `@mydb1 | .user, .address | join(.user.uid == .address.uid) | .uid, .username, .country | .[0:2] | avg(.uid)` //nolint:lll const input = `@mydb1 | .user, .address | join(.user.uid == .address.uid) | .uid, .username, .country | .[0:2] | avg(.uid)` //nolint:lll
ast := mustBuildAST(t, input) ast := mustParse(t, input)
require.NotNil(t, ast) require.NotNil(t, ast)
} }

View File

@ -116,14 +116,27 @@ var _ slq.SLQVisitor = (*parseTreeVisitor)(nil)
// generate the preliminary AST. // generate the preliminary AST.
type parseTreeVisitor struct { type parseTreeVisitor struct {
log lg.Log log lg.Log
// cur is the currently-active node of the AST. // cur is the currently-active node of the AST.
cur Node cur Node
AST *AST AST *AST
} }
// using is a convenience function that sets v.cur to cur,
// executes fn, and then restores v.cur to its previous value.
// The type of the returned value is declared as "any" instead of
// error, because that's the generated antlr code returns "any".
func (v *parseTreeVisitor) using(cur Node, fn func() any) any {
prev := v.cur
v.cur = cur
defer func() { v.cur = prev }()
return fn()
}
// Visit implements antlr.ParseTreeVisitor. // Visit implements antlr.ParseTreeVisitor.
func (v *parseTreeVisitor) Visit(ctx antlr.ParseTree) any { func (v *parseTreeVisitor) Visit(ctx antlr.ParseTree) any {
v.log.Debugf("visiting %T: %v: ", ctx, ctx.GetText()) v.log.Debugf("visiting %T: %v", ctx, ctx.GetText())
switch ctx := ctx.(type) { switch ctx := ctx.(type) {
case *slq.SegmentContext: case *slq.SegmentContext:
@ -136,12 +149,16 @@ func (v *parseTreeVisitor) Visit(ctx antlr.ParseTree) any {
return v.VisitDsTblElement(ctx) return v.VisitDsTblElement(ctx)
case *slq.SelElementContext: case *slq.SelElementContext:
return v.VisitSelElement(ctx) return v.VisitSelElement(ctx)
case *slq.FnElementContext:
return v.VisitFnElement(ctx)
case *slq.FnContext: case *slq.FnContext:
return v.VisitFn(ctx) return v.VisitFn(ctx)
case *slq.FnNameContext: case *slq.FnNameContext:
return v.VisitFnName(ctx) return v.VisitFnName(ctx)
case *slq.JoinContext: case *slq.JoinContext:
return v.VisitJoin(ctx) return v.VisitJoin(ctx)
case *slq.AliasContext:
return v.VisitAlias(ctx)
case *slq.JoinConstraintContext: case *slq.JoinConstraintContext:
return v.VisitJoinConstraint(ctx) return v.VisitJoinConstraint(ctx)
case *slq.CmprContext: case *slq.CmprContext:
@ -231,7 +248,15 @@ func (v *parseTreeVisitor) VisitSelElement(ctx *slq.SelElementContext) any {
selector := &Selector{} selector := &Selector{}
selector.parent = v.cur selector.parent = v.cur
selector.ctx = ctx.SEL() selector.ctx = ctx.SEL()
return v.cur.AddChild(selector)
var err any
if err = v.cur.AddChild(selector); err != nil {
return err
}
return v.using(selector, func() any {
return v.VisitChildren(ctx)
})
} }
// VisitElement implements slq.SLQVisitor. // VisitElement implements slq.SLQVisitor.
@ -239,9 +264,69 @@ func (v *parseTreeVisitor) VisitElement(ctx *slq.ElementContext) any {
return v.VisitChildren(ctx) return v.VisitChildren(ctx)
} }
// VisitAlias implements slq.SLQVisitor.
func (v *parseTreeVisitor) VisitAlias(ctx *slq.AliasContext) any {
alias := ctx.ID().GetText()
switch node := v.cur.(type) {
case *Selector:
node.alias = alias
case *Func:
node.alias = alias
default:
return errorf("alias not allowed for type %T: %v", node, ctx.GetText())
}
return nil
}
// VisitFnElement implements slq.SLQVisitor.
func (v *parseTreeVisitor) VisitFnElement(ctx *slq.FnElementContext) any {
v.log.Debugf("visiting FnElement: %v", ctx.GetText())
childCount := ctx.GetChildCount()
if childCount == 0 || childCount > 2 {
return errorf("parser: invalid function: expected 1 or 2 children, but got %d: %v",
childCount, ctx.GetText())
}
// e.g. count(*)
child1 := ctx.GetChild(0)
fnCtx, ok := child1.(*slq.FnContext)
if !ok {
return errorf("expected first child to be %T but was %T: %v", fnCtx, child1, ctx.GetText())
}
if err := v.VisitFn(fnCtx); err != nil {
return err
}
// Check if there's an alias
if childCount == 2 {
child2 := ctx.GetChild(1)
aliasCtx, ok := child2.(*slq.AliasContext)
if !ok {
return errorf("expected second child to be %T but was %T: %v", aliasCtx, child2, ctx.GetText())
}
// VisitAlias will expect v.cur to be a Func.
lastNode := nodeLastChild(v.cur)
fnNode, ok := lastNode.(*Func)
if !ok {
return errorf("expected %T but got %T: %v", fnNode, lastNode, ctx.GetText())
}
return v.using(fnNode, func() any {
return v.VisitAlias(aliasCtx)
})
}
return nil
}
// VisitFn implements slq.SLQVisitor. // VisitFn implements slq.SLQVisitor.
func (v *parseTreeVisitor) VisitFn(ctx *slq.FnContext) any { func (v *parseTreeVisitor) VisitFn(ctx *slq.FnContext) any {
v.log.Debugf("visiting function: %v", ctx.GetText()) v.log.Debugf("visiting Fn: %v", ctx.GetText())
fn := &Func{fnName: ctx.FnName().GetText()} fn := &Func{fnName: ctx.FnName().GetText()}
fn.ctx = ctx fn.ctx = ctx
@ -250,12 +335,10 @@ func (v *parseTreeVisitor) VisitFn(ctx *slq.FnContext) any {
return err return err
} }
prev := v.cur if err2 := v.using(fn, func() any {
v.cur = fn return v.VisitChildren(ctx)
err2 := v.VisitChildren(ctx) }); err2 != nil {
v.cur = prev return err2
if err2 != nil {
return err2.(error)
} }
return v.cur.AddChild(fn) return v.cur.AddChild(fn)
@ -348,7 +431,7 @@ func (v *parseTreeVisitor) VisitGroup(ctx *slq.GroupContext) any {
} }
for _, selCtx := range sels { for _, selCtx := range sels {
err = grp.AddChild(newColSelector(grp, selCtx)) err = grp.AddChild(newColSelector(grp, selCtx, "")) // FIXME: Handle alias appropriately
if err != nil { if err != nil {
return err return err
} }
@ -442,7 +525,7 @@ func (v *parseTreeVisitor) VisitJoinConstraint(ctx *slq.JoinConstraintContext) a
return err return err
} }
cmpr := newCmnr(joinCondition, ctx.Cmpr()) cmpr := newCmpr(joinCondition, ctx.Cmpr())
err = joinCondition.AddChild(cmpr) err = joinCondition.AddChild(cmpr)
if err != nil { if err != nil {
return err return err
@ -481,8 +564,7 @@ func (v *parseTreeVisitor) VisitTerminal(ctx antlr.TerminalNode) any {
return nil return nil
} }
v.log.Warnf("unknown terminal: %q", val) // Unknown terminal, but that's not a problem.
return nil return nil
} }

View File

@ -61,10 +61,19 @@ func buildInitialAST(t *testing.T, input string) (*AST, error) {
return v.AST, nil return v.AST, nil
} }
// mustBuildAST builds a full AST from the input SLQ, or fails on any error. // mustParse builds a full AST from the input SLQ, or fails on any error.
func mustBuildAST(t *testing.T, input string) *AST { func mustParse(t *testing.T, input string) *AST {
log := testlg.New(t).Strict(true) log := testlg.New(t).Strict(true)
ast, err := Parse(log, input)
require.NoError(t, err)
return ast
}
func TestSimpleQuery(t *testing.T) {
log := testlg.New(t).Strict(true)
const input = fixtSelect1
ptree, err := parseSLQ(log, input) ptree, err := parseSLQ(log, input)
require.Nil(t, err) require.Nil(t, err)
require.NotNil(t, ptree) require.NotNil(t, ptree)
@ -72,7 +81,6 @@ func mustBuildAST(t *testing.T, input string) *AST {
ast, err := buildAST(log, ptree) ast, err := buildAST(log, ptree)
require.Nil(t, err) require.Nil(t, err)
require.NotNil(t, ast) require.NotNil(t, ast)
return ast
} }
func TestParseBuild(t *testing.T) { func TestParseBuild(t *testing.T) {

View File

@ -17,17 +17,17 @@ import (
func TestRowRange1(t *testing.T) { func TestRowRange1(t *testing.T) {
log := testlg.New(t).Strict(true) log := testlg.New(t).Strict(true)
ast := mustBuildAST(t, fixtRowRange1) ast := mustParse(t, fixtRowRange1)
assert.Equal(t, 0, NewInspector(log, ast).CountNodes(typeRowRange)) assert.Equal(t, 0, NewInspector(log, ast).CountNodes(typeRowRange))
} }
func TestRowRange2(t *testing.T) { func TestRowRange2(t *testing.T) {
log := testlg.New(t).Strict(true) log := testlg.New(t).Strict(true)
ast := mustBuildAST(t, fixtRowRange2) ast := mustParse(t, fixtRowRange2)
ins := NewInspector(log, ast) insp := NewInspector(log, ast)
assert.Equal(t, 1, ins.CountNodes(typeRowRange)) assert.Equal(t, 1, insp.CountNodes(typeRowRange))
nodes := ins.FindNodes(typeRowRange) nodes := insp.FindNodes(typeRowRange)
assert.Equal(t, 1, len(nodes)) assert.Equal(t, 1, len(nodes))
rr, _ := nodes[0].(*RowRange) rr, _ := nodes[0].(*RowRange)
assert.Equal(t, 2, rr.Offset) assert.Equal(t, 2, rr.Offset)
@ -37,9 +37,9 @@ func TestRowRange2(t *testing.T) {
func TestRowRange3(t *testing.T) { func TestRowRange3(t *testing.T) {
log := testlg.New(t).Strict(true) log := testlg.New(t).Strict(true)
ast := mustBuildAST(t, fixtRowRange3) ast := mustParse(t, fixtRowRange3)
ins := NewInspector(log, ast) insp := NewInspector(log, ast)
rr, _ := ins.FindNodes(typeRowRange)[0].(*RowRange) rr, _ := insp.FindNodes(typeRowRange)[0].(*RowRange)
assert.Equal(t, 1, rr.Offset) assert.Equal(t, 1, rr.Offset)
assert.Equal(t, 2, rr.Limit) assert.Equal(t, 2, rr.Limit)
} }
@ -47,27 +47,27 @@ func TestRowRange3(t *testing.T) {
func TestRowRange4(t *testing.T) { func TestRowRange4(t *testing.T) {
log := testlg.New(t).Strict(true) log := testlg.New(t).Strict(true)
ast := mustBuildAST(t, fixtRowRange4) ast := mustParse(t, fixtRowRange4)
ins := NewInspector(log, ast) insp := NewInspector(log, ast)
rr, _ := ins.FindNodes(typeRowRange)[0].(*RowRange) rr, _ := insp.FindNodes(typeRowRange)[0].(*RowRange)
assert.Equal(t, 0, rr.Offset) assert.Equal(t, 0, rr.Offset)
assert.Equal(t, 3, rr.Limit) assert.Equal(t, 3, rr.Limit)
} }
func TestRowRange5(t *testing.T) { func TestRowRange5(t *testing.T) {
log := testlg.New(t).Strict(true) log := testlg.New(t).Strict(true)
ast := mustBuildAST(t, fixtRowRange5) ast := mustParse(t, fixtRowRange5)
ins := NewInspector(log, ast) insp := NewInspector(log, ast)
rr, _ := ins.FindNodes(typeRowRange)[0].(*RowRange) rr, _ := insp.FindNodes(typeRowRange)[0].(*RowRange)
assert.Equal(t, 0, rr.Offset) assert.Equal(t, 0, rr.Offset)
assert.Equal(t, 3, rr.Limit) assert.Equal(t, 3, rr.Limit)
} }
func TestRowRange6(t *testing.T) { func TestRowRange6(t *testing.T) {
log := testlg.New(t).Strict(true) log := testlg.New(t).Strict(true)
ast := mustBuildAST(t, fixtRowRange6) ast := mustParse(t, fixtRowRange6)
ins := NewInspector(log, ast) insp := NewInspector(log, ast)
rr, _ := ins.FindNodes(typeRowRange)[0].(*RowRange) rr, _ := insp.FindNodes(typeRowRange)[0].(*RowRange)
assert.Equal(t, 2, rr.Offset) assert.Equal(t, 2, rr.Offset)
assert.Equal(t, -1, rr.Limit) assert.Equal(t, -1, rr.Limit)
} }

View File

@ -8,7 +8,7 @@ import (
func TestSegment(t *testing.T) { func TestSegment(t *testing.T) {
// `@mydb1 | .user, .address | join(.uid == .uid) | .uid, .username, .country` // `@mydb1 | .user, .address | join(.uid == .uid) | .uid, .username, .country`
ast := mustBuildAST(t, fixtJoinQuery1) ast := mustParse(t, fixtJoinQuery1)
segs := ast.Segments() segs := ast.Segments()
assert.Equal(t, 4, len(segs)) assert.Equal(t, 4, len(segs))

View File

@ -20,6 +20,10 @@ var _ Node = (*Selector)(nil)
// selector node such as TblSelector or ColSelector. // selector node such as TblSelector or ColSelector.
type Selector struct { type Selector struct {
baseNode baseNode
// alias is the (optional) alias part. For example, given ".first_name:given_name",
// the alias value is "given_name". May be empy.
alias string
} }
func (s *Selector) String() string { func (s *Selector) String() string {
@ -71,12 +75,14 @@ var (
// ColSelector models a column selector such as ".user_id". // ColSelector models a column selector such as ".user_id".
type ColSelector struct { type ColSelector struct {
Selector Selector
alias string
} }
func newColSelector(parent Node, ctx antlr.ParseTree) *ColSelector { func newColSelector(parent Node, ctx antlr.ParseTree, alias string) *ColSelector {
col := &ColSelector{} col := &ColSelector{}
col.parent = parent col.parent = parent
col.ctx = ctx col.ctx = ctx
col.alias = alias
return col return col
} }
@ -86,17 +92,29 @@ func (s *ColSelector) ColExpr() (string, error) {
return s.Text()[1:], nil return s.Text()[1:], nil
} }
// IsColName always returns true.
func (s *ColSelector) IsColName() bool { func (s *ColSelector) IsColName() bool {
return true return true
} }
// Alias returns the column alias, which may be empty.
// For example, given the selector ".first_name:given_name", the alias is "given_name".
func (s *ColSelector) Alias() string {
return s.alias
}
// String returns a log/debug-friendly representation.
func (s *ColSelector) String() string { func (s *ColSelector) String() string {
return nodeString(s) str := nodeString(s)
if s.alias != "" {
str += ":" + s.alias
}
return str
} }
var _ Node = (*Cmpr)(nil) var _ Node = (*Cmpr)(nil)
// Cmpr models a comparison. // Cmpr models a comparison, such as ".age == 42".
type Cmpr struct { type Cmpr struct {
baseNode baseNode
} }
@ -105,7 +123,7 @@ func (c *Cmpr) String() string {
return nodeString(c) return nodeString(c)
} }
func newCmnr(parent Node, ctx slq.ICmprContext) *Cmpr { func newCmpr(parent Node, ctx slq.ICmprContext) *Cmpr {
leaf, _ := ctx.GetChild(0).(*antlr.TerminalNodeImpl) leaf, _ := ctx.GetChild(0).(*antlr.TerminalNodeImpl)
cmpr := &Cmpr{} cmpr := &Cmpr{}
cmpr.ctx = leaf cmpr.ctx = leaf

View File

@ -0,0 +1,55 @@
package ast
import (
"testing"
"github.com/neilotoole/sq/testh/tutil"
"github.com/stretchr/testify/require"
"github.com/neilotoole/lg/testlg"
)
func TestColumnAlias(t *testing.T) {
t.Parallel()
testCases := []struct {
in string
wantErr bool
wantExpr string
wantAlias string
}{
{
in: `@sakila | .actor | .first_name:given_name`,
wantExpr: "first_name",
wantAlias: "given_name",
},
}
for _, tc := range testCases {
tc := tc
t.Run(tutil.Name(tc.in), func(t *testing.T) {
t.Parallel()
log := testlg.New(t)
ast, err := Parse(log, tc.in)
if tc.wantErr {
require.Error(t, err)
return
}
require.NoError(t, err)
insp := NewInspector(log, ast)
nodes := insp.FindNodes(typeColSelector)
require.Equal(t, 1, len(nodes))
colSel, ok := nodes[0].(*ColSelector)
require.True(t, ok)
expr, _ := colSel.ColExpr()
require.Equal(t, tc.wantExpr, expr)
require.Equal(t, tc.wantAlias, colSel.Alias())
})
}
}

View File

@ -113,7 +113,7 @@ func (fb *BaseFragmentBuilder) Function(fn *ast.Func) (string, error) {
return buf.String(), nil return buf.String(), nil
} }
buf.WriteString(fn.FuncName()) buf.WriteString(strings.ToUpper(fn.FuncName()))
buf.WriteRune('(') buf.WriteRune('(')
for i, child := range children { for i, child := range children {
if i > 0 { if i > 0 {
@ -123,6 +123,8 @@ func (fb *BaseFragmentBuilder) Function(fn *ast.Func) (string, error) {
switch child := child.(type) { switch child := child.(type) {
case *ast.ColSelector: case *ast.ColSelector:
buf.WriteString(child.SelValue()) buf.WriteString(child.SelValue())
case *ast.Operator:
buf.WriteString(child.Text())
default: default:
fb.Log.Debugf("unknown AST child node type %T", child) fb.Log.Debugf("unknown AST child node type %T", child)
} }
@ -287,6 +289,13 @@ func (fb *BaseFragmentBuilder) SelectCols(cols []ast.ColExpr) (string, error) {
return "", errz.Errorf("unable to extract col expr from %q: %v", col, err) return "", errz.Errorf("unable to extract col expr from %q: %v", col, err)
} }
// aliasFrag holds the "AS alias" fragment (if applicable).
// For example "@sakila | actor | .first_name:given_name" becomes "SELECT first_name AS given_name".
var aliasFrag string
if col.Alias() != "" {
aliasFrag = fmt.Sprintf(" AS %s%s%s", fb.Quote, col.Alias(), fb.Quote)
}
fn, ok := col.(*ast.Func) fn, ok := col.(*ast.Func)
if ok { if ok {
// it's a function // it's a function
@ -294,12 +303,14 @@ func (fb *BaseFragmentBuilder) SelectCols(cols []ast.ColExpr) (string, error) {
if err != nil { if err != nil {
return "", err return "", err
} }
vals[i] += aliasFrag
continue continue
} }
if !col.IsColName() { if !col.IsColName() {
// it's a function or expression // it's a function or expression
vals[i] = colText // for now, we just return the raw text vals[i] = colText // for now, we just return the raw text
vals[i] += aliasFrag
continue continue
} }
@ -307,6 +318,7 @@ func (fb *BaseFragmentBuilder) SelectCols(cols []ast.ColExpr) (string, error) {
if !strings.ContainsRune(colText, '.') { if !strings.ContainsRune(colText, '.') {
// it's a regular (non-scoped) col name, e.g. "uid" // it's a regular (non-scoped) col name, e.g. "uid"
vals[i] = fmt.Sprintf("%s%s%s", fb.Quote, colText, fb.Quote) vals[i] = fmt.Sprintf("%s%s%s", fb.Quote, colText, fb.Quote)
vals[i] += aliasFrag
continue continue
} }
@ -317,6 +329,7 @@ func (fb *BaseFragmentBuilder) SelectCols(cols []ast.ColExpr) (string, error) {
} }
vals[i] = fmt.Sprintf("%s%s%s.%s%s%s", fb.Quote, parts[0], fb.Quote, fb.Quote, parts[1], fb.Quote) vals[i] = fmt.Sprintf("%s%s%s.%s%s%s", fb.Quote, parts[0], fb.Quote, fb.Quote, parts[1], fb.Quote)
vals[i] += aliasFrag
} }
text := "SELECT " + strings.Join(vals, ", ") text := "SELECT " + strings.Join(vals, ", ")

View File

@ -87,7 +87,7 @@ func narrowTblSel(log lg.Log, w *Walker, node Node) error {
} }
if seg.SegIndex() == 0 { if seg.SegIndex() == 0 {
return errorf("syntax error: illegal to have raw selector in first segment: %q", sel.Text()) return errorf("@HANDLE must be first element: %q", sel.Text())
} }
typ, err := seg.Prev().ChildType() typ, err := seg.Prev().ChildType()
@ -104,7 +104,7 @@ func narrowTblSel(log lg.Log, w *Walker, node Node) error {
// this means that this selector must be a table selector // this means that this selector must be a table selector
tblSel := newTblSelector(seg, sel.SelValue(), sel.Context()) tblSel := newTblSelector(seg, sel.SelValue(), sel.Context())
tblSel.DSName = ds.Text() tblSel.DSName = ds.Text()
err = replaceNode(sel, tblSel) err = nodeReplace(sel, tblSel)
if err != nil { if err != nil {
return err return err
} }
@ -115,7 +115,7 @@ func narrowTblSel(log lg.Log, w *Walker, node Node) error {
// narrowColSel takes a generic selector, and if appropriate, converts it to a ColSel. // narrowColSel takes a generic selector, and if appropriate, converts it to a ColSel.
func narrowColSel(log lg.Log, w *Walker, node Node) error { func narrowColSel(log lg.Log, w *Walker, node Node) error {
// node is guaranteed to be typeSelector // node is guaranteed to be type Selector
sel, ok := node.(*Selector) sel, ok := node.(*Selector)
if !ok { if !ok {
return errorf("expected *Selector but got %T", node) return errorf("expected *Selector but got %T", node)
@ -127,8 +127,8 @@ func narrowColSel(log lg.Log, w *Walker, node Node) error {
case *JoinConstraint, *Func: case *JoinConstraint, *Func:
// selector parent is JoinConstraint or Func, therefore this is a ColSelector // selector parent is JoinConstraint or Func, therefore this is a ColSelector
log.Debugf("selector parent is %T, therefore this is a ColSelector", parent) log.Debugf("selector parent is %T, therefore this is a ColSelector", parent)
colSel := newColSelector(sel.Parent(), sel.ctx) colSel := newColSelector(sel.Parent(), sel.ctx, sel.alias)
return replaceNode(sel, colSel) return nodeReplace(sel, colSel)
case *Segment: case *Segment:
// if the parent is a segment, this is a "top-level" selector. // if the parent is a segment, this is a "top-level" selector.
// Only top-level selectors after the final selectable seg are // Only top-level selectors after the final selectable seg are
@ -143,8 +143,8 @@ func narrowColSel(log lg.Log, w *Walker, node Node) error {
return nil return nil
} }
colSel := newColSelector(sel.Parent(), sel.ctx) colSel := newColSelector(sel.Parent(), sel.ctx, sel.alias)
return replaceNode(sel, colSel) return nodeReplace(sel, colSel)
default: default:
log.Warnf("skipping this selector, as parent is not of a relevant type, but is %T", parent) log.Warnf("skipping this selector, as parent is not of a relevant type, but is %T", parent)

View File

@ -38,7 +38,7 @@ type engine struct {
// prepare prepares the engine to execute queryModel. // prepare prepares the engine to execute queryModel.
// When this method returns, targetDB and targetSQL will be set, // When this method returns, targetDB and targetSQL will be set,
// as will any tasks (may be empty). The tasks must be executed // as will any tasks (which may be empty). The tasks must be executed
// against targetDB before targetSQL is executed (the engine.execute // against targetDB before targetSQL is executed (the engine.execute
// method does this work). // method does this work).
func (ng *engine) prepare(ctx context.Context, qm *queryModel) error { func (ng *engine) prepare(ctx context.Context, qm *queryModel) error {
@ -259,6 +259,20 @@ func (ng *engine) crossSourceJoin(ctx context.Context, fnJoin *ast.Join) (fromCl
return fromClause, joinDB, nil return fromClause, joinDB, nil
} }
// SLQ2SQL simulates execution of a SLQ query, but instead of executing
// the resulting SQL query, that ultimate SQL is returned. Effectively it is
// equivalent to libsq.ExecuteSLQ, but without the execution.
func SLQ2SQL(ctx context.Context, log lg.Log, dbOpener driver.DatabaseOpener,
joinDBOpener driver.JoinDatabaseOpener, srcs *source.Set, query string,
) (targetSQL string, err error) {
var ng *engine
ng, err = newEngine(ctx, log, dbOpener, joinDBOpener, srcs, query)
if err != nil {
return "", err
}
return ng.targetSQL, nil
}
// tasker is the interface for executing a DB task. // tasker is the interface for executing a DB task.
type tasker interface { type tasker interface {
// executeTask executes a task against the DB. // executeTask executes a task against the DB.

View File

@ -1,48 +0,0 @@
package libsq_test
import (
"testing"
"github.com/neilotoole/sq/testh/tutil"
"github.com/stretchr/testify/require"
"github.com/neilotoole/sq/libsq"
"github.com/neilotoole/sq/testh"
"github.com/neilotoole/sq/testh/sakila"
)
func TestSLQ2SQL(t *testing.T) {
testCases := []struct {
handles []string
slq string
wantSQL string
wantErr bool
}{
// Obviously we could use about 1,000 additional test cases.
{
handles: []string{sakila.SL3},
slq: `@sakila_sl3 | .actor, .film_actor | join(.film_actor.actor_id == .actor.actor_id)`,
wantSQL: `SELECT * FROM "actor" INNER JOIN "film_actor" ON "film_actor"."actor_id" = "actor"."actor_id"`,
},
}
for _, tc := range testCases {
tc := tc
t.Run(tutil.Name(tc.slq), func(t *testing.T) {
th := testh.New(t)
srcs := th.NewSourceSet(tc.handles...)
gotSQL, gotErr := libsq.EngineSLQ2SQL(th.Context, th.Log, th.Databases(), th.Databases(), srcs, tc.slq)
if tc.wantErr {
require.Error(t, gotErr)
return
}
require.NoError(t, gotErr)
require.Equal(t, tc.wantSQL, gotSQL)
})
}
}

View File

@ -1,26 +0,0 @@
package libsq
import (
"context"
"github.com/neilotoole/lg"
"github.com/neilotoole/sq/libsq/driver"
"github.com/neilotoole/sq/libsq/source"
)
// EngineSLQ2SQL is a dedicated testing function that simulates
// execution of a SLQ query, but instead of executing the resulting
// SQL query, that ultimate SQL is returned. Effectively it is
// equivalent to libsq.ExecuteSLQ, but without the execution.
// Admittedly, this is an ugly workaround.
func EngineSLQ2SQL(ctx context.Context, log lg.Log, dbOpener driver.DatabaseOpener,
joinDBOpener driver.JoinDatabaseOpener, srcs *source.Set, query string,
) (targetSQL string, err error) {
var ng *engine
ng, err = newEngine(ctx, log, dbOpener, joinDBOpener, srcs, query)
if err != nil {
return "", err
}
return ng.targetSQL, nil
}

View File

@ -80,6 +80,11 @@ func Lint() error {
return sh.RunV("golangci-lint", "run", "./...") return sh.RunV("golangci-lint", "run", "./...")
} }
// Fmt runs gofumpt on the source.
func Fmt() error {
return sh.RunV("gofumpt", "-l", "-w", ".")
}
// Generate generates SLQ parser Go files from the // Generate generates SLQ parser Go files from the
// antlr grammar. Note that the antlr generator tool is Java-based; you // antlr grammar. Note that the antlr generator tool is Java-based; you
// must have Java installed. // must have Java installed.