sq/libsq/ast/func.go
Neil O'Toole 7c56377b40
Struct alignment (#369)
* Field alignment
2024-01-27 00:11:24 -07:00

188 lines
4.3 KiB
Go

package ast
import (
"strings"
"github.com/neilotoole/sq/libsq/ast/internal/slq"
)
// Names of functions that SLQ recognizes by name.
const (
	// FuncNameAvg is the name of the "avg" function.
	FuncNameAvg = "avg"

	// FuncNameCatalog is the name of the "catalog" function.
	FuncNameCatalog = "catalog"

	// FuncNameCount is the name of the "count" function.
	FuncNameCount = "count"

	// FuncNameCountUnique is the name of the "count_unique" function.
	FuncNameCountUnique = "count_unique"

	// FuncNameMax is the name of the "max" function.
	FuncNameMax = "max"

	// FuncNameMin is the name of the "min" function.
	FuncNameMin = "min"

	// FuncNameRowNum is the name of the "rownum" function.
	FuncNameRowNum = "rownum"

	// FuncNameSchema is the name of the "schema" function.
	FuncNameSchema = "schema"

	// FuncNameSum is the name of the "sum" function.
	FuncNameSum = "sum"
)
var (
_ Node = (*FuncNode)(nil)
_ ResultColumn = (*FuncNode)(nil)
)
// FuncNode models a function call in the AST, for example "COUNT()".
type FuncNode struct {
	// fnName is the function name; alias is the result column alias
	// (see Alias).
	fnName, alias string

	baseNode

	// proprietary marks a DB-proprietary (as opposed to portable)
	// function; see IsProprietary.
	proprietary bool
}
// resultColumn implements ast.ResultColumn. It intentionally does nothing.
//
// REVISIT: should ast.FuncNode implement ast.ResultColumn?
func (fn *FuncNode) resultColumn() {}
// FuncName returns the name of the function, e.g. "count". When the
// function was written with a leading underscore (e.g. "_strftime"), the
// underscore is not part of the returned name.
func (fn *FuncNode) FuncName() string {
	name := fn.fnName
	return name
}
// IsProprietary reports whether the function is specific to a particular
// database rather than portable across databases. SLQ marks such functions
// with a leading underscore: for example, SQLite's "strftime" function is
// written as "_strftime" in SLQ.
func (fn *FuncNode) IsProprietary() bool {
	return fn.proprietary
}
// String returns a representation of the node suitable for logging and
// debugging. If the node has an alias, it is appended after a colon.
func (fn *FuncNode) String() string {
	s := nodeString(fn)
	if fn.alias == "" {
		return s
	}
	return s + ":" + fn.alias
}
// Text implements ResultColumn. It returns the text of the parse-tree
// context from which this node was built.
func (fn *FuncNode) Text() string {
	c := fn.ctx
	return c.GetText()
}
// Alias implements ResultColumn. It returns the alias of the function's
// result column; the empty string means no alias has been set.
func (fn *FuncNode) Alias() string {
	a := fn.alias
	return a
}
// SetChildren implements Node. It delegates to doSetChildren and always
// returns a nil error.
func (fn *FuncNode) SetChildren(children []Node) error {
	fn.doSetChildren(children)

	return nil
}
// AddChild implements Node. It registers child via addChild and then makes
// this node the parent of child, returning any error from SetParent.
func (fn *FuncNode) AddChild(child Node) error {
	// TODO: add check for valid FuncNode child types
	fn.addChild(child)
	err := child.SetParent(fn)
	return err
}
// VisitFuncName implements slq.SLQVisitor. It is deliberately a no-op:
// VisitFunc reads the function name directly from its context.
func (v *parseTreeVisitor) VisitFuncName(_ *slq.FuncNameContext) any {
	return nil
}
// VisitFuncElement implements slq.SLQVisitor.
//
// A function element is a function call (e.g. "count(*)"), optionally
// followed by an alias. The call is delegated to VisitFunc. If an alias is
// present, it is visited with the FuncNode that VisitFunc added to v.cur
// passed to v.using.
func (v *parseTreeVisitor) VisitFuncElement(ctx *slq.FuncElementContext) any {
	n := ctx.GetChildCount()
	if n == 0 || n > 2 {
		return errorf("parser: invalid function: expected 1 or 2 children, but got %d: %v",
			n, ctx.GetText())
	}

	// First child: the function call itself, e.g. count(*).
	first := ctx.GetChild(0)
	funcCtx, isFunc := first.(*slq.FuncContext)
	if !isFunc {
		return errorf("expected first child to be %T but was %T: %v", funcCtx, first, ctx.GetText())
	}
	if res := v.VisitFunc(funcCtx); res != nil {
		return res
	}

	if n == 1 {
		// No alias to process.
		return nil
	}

	// Second child: the alias.
	second := ctx.GetChild(1)
	aliasCtx, isAlias := second.(*slq.AliasContext)
	if !isAlias {
		return errorf("expected second child to be %T but was %T: %v", aliasCtx, second, ctx.GetText())
	}

	// VisitAlias will expect v.cur to be a FuncNode, so fetch the node that
	// VisitFunc just added.
	last := nodeLastChild(v.cur)
	fnNode, isFuncNode := last.(*FuncNode)
	if !isFuncNode {
		return errorf("expected %T but got %T: %v", fnNode, last, ctx.GetText())
	}

	return v.using(fnNode, func() any {
		return v.VisitAlias(aliasCtx)
	})
}
// VisitFunc implements slq.SLQVisitor.
//
// It builds a FuncNode for the function call in ctx.
//
// A leading underscore on the function name (e.g. "_strftime") marks a
// DB-proprietary function. The underscore is stripped from the node's name,
// and the node is flagged so that IsProprietary reports true.
//
// If no alias was set while visiting the children, the alias defaults to the
// call's text with any leading underscore removed. Finally, the node is added
// as a child of v.cur.
func (v *parseTreeVisitor) VisitFunc(ctx *slq.FuncContext) any {
	name := ctx.FuncName().GetText()
	node := &FuncNode{fnName: name}

	// HasPrefix rather than name[0]: indexing would panic on an empty name.
	if strings.HasPrefix(name, "_") {
		node.fnName = name[1:]
		// Without this, IsProprietary would always return false.
		node.proprietary = true
	}
	node.ctx = ctx
	node.text = ctx.GetText()

	if err := node.SetParent(v.cur); err != nil {
		return err
	}

	if err := v.using(node, func() any {
		return v.VisitChildren(ctx)
	}); err != nil {
		return err
	}

	if node.alias == "" {
		node.alias = strings.TrimPrefix(ctx.GetText(), "_")
	}

	return v.cur.AddChild(node)
}
// VisitCountFunc implements slq.SLQVisitor.
//
// Although the "count" func has special handling in the grammar (because
// it has a no-arg form, e.g. ".actor | count"), a regular FuncNode is
// inserted into the AST.
//
// The node is added to v.cur before its children are visited. If the call
// has neither children nor an alias, the alias is set to "count".
func (v *parseTreeVisitor) VisitCountFunc(ctx *slq.CountFuncContext) any {
	node := &FuncNode{fnName: FuncNameCount}
	node.ctx = ctx
	node.text = ctx.GetText()
	if err := v.cur.AddChild(node); err != nil {
		return err
	}

	if err := v.using(node, func() any {
		return v.VisitChildren(ctx)
	}); err != nil {
		return err
	}

	if len(node.Children()) == 0 && ctx.Alias() == nil {
		// If there are no children and no alias, we explicitly set the
		// alias to "count".
		node.alias = FuncNameCount
	}

	return nil
}