mirror of
https://github.com/neilotoole/sq.git
synced 2025-01-05 15:24:00 +03:00
7c56377b40
* Field alignment
188 lines
4.3 KiB
Go
188 lines
4.3 KiB
Go
package ast
|
|
|
|
import (
|
|
"strings"
|
|
|
|
"github.com/neilotoole/sq/libsq/ast/internal/slq"
|
|
)
|
|
|
|
const (
|
|
FuncNameAvg = "avg"
|
|
FuncNameCount = "count"
|
|
FuncNameCountUnique = "count_unique"
|
|
FuncNameMax = "max"
|
|
FuncNameMin = "min"
|
|
FuncNameSchema = "schema"
|
|
FuncNameCatalog = "catalog"
|
|
FuncNameSum = "sum"
|
|
FuncNameRowNum = "rownum"
|
|
)
|
|
|
|
var (
|
|
_ Node = (*FuncNode)(nil)
|
|
_ ResultColumn = (*FuncNode)(nil)
|
|
)
|
|
|
|
// FuncNode models a function. For example, "COUNT()".
|
|
type FuncNode struct {
|
|
fnName string
|
|
alias string
|
|
baseNode
|
|
proprietary bool
|
|
}
|
|
|
|
// resultColumn implements ast.ResultColumn.
|
|
//
|
|
// REVISIT: should ast.FuncNode implement ast.ResultColumn?
|
|
func (fn *FuncNode) resultColumn() {
|
|
}
|
|
|
|
// FuncName returns the function name.
|
|
func (fn *FuncNode) FuncName() string {
|
|
return fn.fnName
|
|
}
|
|
|
|
// IsProprietary returns true if this is a DB-proprietary function, as
|
|
// opposed to a portable function. For example, SQLite has
|
|
// a "strftime" function. In the SLQ, this is referenced
|
|
// as "_strftime": SLQ uses the underscore to indicate a proprietary
|
|
// function.
|
|
func (fn *FuncNode) IsProprietary() bool {
|
|
return fn.proprietary
|
|
}
|
|
|
|
// String returns a log/debug-friendly representation.
|
|
func (fn *FuncNode) String() string {
|
|
str := nodeString(fn)
|
|
if fn.alias != "" {
|
|
str += ":" + fn.alias
|
|
}
|
|
return str
|
|
}
|
|
|
|
// Text implements ResultColumn.
|
|
func (fn *FuncNode) Text() string {
|
|
return fn.ctx.GetText()
|
|
}
|
|
|
|
// Alias implements ResultColumn.
|
|
func (fn *FuncNode) Alias() string {
|
|
return fn.alias
|
|
}
|
|
|
|
// SetChildren implements Node.
|
|
func (fn *FuncNode) SetChildren(children []Node) error {
|
|
fn.doSetChildren(children)
|
|
return nil
|
|
}
|
|
|
|
// AddChild implements Node.
|
|
func (fn *FuncNode) AddChild(child Node) error {
|
|
// TODO: add check for valid FuncNode child types
|
|
fn.addChild(child)
|
|
return child.SetParent(fn)
|
|
}
|
|
|
|
// VisitFuncName implements slq.SLQVisitor.
|
|
func (v *parseTreeVisitor) VisitFuncName(_ *slq.FuncNameContext) any {
|
|
// no-op
|
|
return nil
|
|
}
|
|
|
|
// VisitFuncElement implements slq.SLQVisitor.
|
|
func (v *parseTreeVisitor) VisitFuncElement(ctx *slq.FuncElementContext) any {
|
|
childCount := ctx.GetChildCount()
|
|
if childCount == 0 || childCount > 2 {
|
|
return errorf("parser: invalid function: expected 1 or 2 children, but got %d: %v",
|
|
childCount, ctx.GetText())
|
|
}
|
|
|
|
// e.g. count(*)
|
|
child1 := ctx.GetChild(0)
|
|
fnCtx, ok := child1.(*slq.FuncContext)
|
|
if !ok {
|
|
return errorf("expected first child to be %T but was %T: %v", fnCtx, child1, ctx.GetText())
|
|
}
|
|
|
|
if err := v.VisitFunc(fnCtx); err != nil {
|
|
return err
|
|
}
|
|
|
|
// Check if there's an alias
|
|
if childCount == 2 {
|
|
child2 := ctx.GetChild(1)
|
|
aliasCtx, ok := child2.(*slq.AliasContext)
|
|
if !ok {
|
|
return errorf("expected second child to be %T but was %T: %v", aliasCtx, child2, ctx.GetText())
|
|
}
|
|
|
|
// VisitAlias will expect v.cur to be a FuncNode.
|
|
lastNode := nodeLastChild(v.cur)
|
|
fnNode, ok := lastNode.(*FuncNode)
|
|
if !ok {
|
|
return errorf("expected %T but got %T: %v", fnNode, lastNode, ctx.GetText())
|
|
}
|
|
|
|
return v.using(fnNode, func() any {
|
|
return v.VisitAlias(aliasCtx)
|
|
})
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// VisitFunc implements slq.SLQVisitor.
|
|
func (v *parseTreeVisitor) VisitFunc(ctx *slq.FuncContext) any {
|
|
node := &FuncNode{fnName: ctx.FuncName().GetText()}
|
|
if node.fnName[0] == '_' {
|
|
node.fnName = node.fnName[1:]
|
|
}
|
|
|
|
node.ctx = ctx
|
|
node.text = ctx.GetText()
|
|
if err := node.SetParent(v.cur); err != nil {
|
|
return err
|
|
}
|
|
|
|
if err := v.using(node, func() any {
|
|
return v.VisitChildren(ctx)
|
|
}); err != nil {
|
|
return err
|
|
}
|
|
|
|
if node.alias == "" {
|
|
node.alias = ctx.GetText()
|
|
node.alias = strings.TrimPrefix(node.alias, "_")
|
|
}
|
|
|
|
return v.cur.AddChild(node)
|
|
}
|
|
|
|
// VisitCountFunc implements antlr.ParseTreeVisitor.
|
|
// Although the "count" func has special handling in the grammar (because
|
|
// it has a no-arg form, e.g. ".actor | count"), a regular FuncNode is
|
|
// inserted into the AST.
|
|
func (v *parseTreeVisitor) VisitCountFunc(ctx *slq.CountFuncContext) interface{} {
|
|
node := &FuncNode{fnName: "count"}
|
|
node.ctx = ctx
|
|
node.text = ctx.GetText()
|
|
|
|
if err := v.cur.AddChild(node); err != nil {
|
|
return err
|
|
}
|
|
|
|
if err := v.using(node, func() any {
|
|
return v.VisitChildren(ctx)
|
|
}); err != nil {
|
|
return err
|
|
}
|
|
|
|
if len(node.Children()) == 0 && ctx.Alias() == nil {
|
|
// If there's no children, and no alias, we explicitly set the
|
|
// alias to "count".
|
|
node.alias = "count"
|
|
}
|
|
|
|
return nil
|
|
}
|