mirror of
https://github.com/neilotoole/sq.git
synced 2024-12-19 14:11:45 +03:00
44d27207f8
* Column-only queries
309 lines
6.5 KiB
Go
309 lines
6.5 KiB
Go
package ast
|
|
|
|
import (
|
|
"fmt"
|
|
|
|
"github.com/neilotoole/sq/libsq/ast/internal/slq"
|
|
|
|
"github.com/antlr/antlr4/runtime/Go/antlr/v4"
|
|
)
|
|
|
|
// VisitJoin implements slq.SLQVisitor.
|
|
func (v *parseTreeVisitor) VisitJoin(ctx *slq.JoinContext) any {
|
|
// parent node must be a segment
|
|
seg, ok := v.cur.(*SegmentNode)
|
|
if !ok {
|
|
return errorf("parent of JOIN() must be SegmentNode, but got: %T", v.cur)
|
|
}
|
|
|
|
join := &JoinNode{seg: seg, ctx: ctx}
|
|
err := seg.AddChild(join)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
expr := ctx.JoinConstraint()
|
|
if expr == nil {
|
|
return nil
|
|
}
|
|
|
|
// the join contains a constraint, let's hit it
|
|
v.cur = join
|
|
err2 := v.VisitJoinConstraint(expr.(*slq.JoinConstraintContext))
|
|
if err2 != nil {
|
|
return err2
|
|
}
|
|
// set cur back to previous
|
|
v.cur = seg
|
|
return nil
|
|
}
|
|
|
|
// VisitJoinConstraint implements slq.SLQVisitor.
|
|
func (v *parseTreeVisitor) VisitJoinConstraint(ctx *slq.JoinConstraintContext) any {
|
|
joinNode, ok := v.cur.(*JoinNode)
|
|
if !ok {
|
|
return errorf("JOIN constraint must have JOIN parent, but got %T", v.cur)
|
|
}
|
|
|
|
// the constraint could be empty
|
|
children := ctx.GetChildren()
|
|
if len(children) == 0 {
|
|
return nil
|
|
}
|
|
|
|
// the constraint could be a single SEL (in which case, there's no comparison operator)
|
|
if ctx.Cmpr() == nil {
|
|
// there should be exactly one SEL
|
|
sels := ctx.AllSelector()
|
|
if len(sels) != 1 {
|
|
return errorf("JOIN constraint without a comparison operator must have exactly one selector")
|
|
}
|
|
|
|
joinExprNode := &JoinConstraint{join: joinNode, ctx: ctx}
|
|
|
|
colSelNode, err := newSelectorNode(joinExprNode, sels[0])
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if err := joinExprNode.AddChild(colSelNode); err != nil {
|
|
return err
|
|
}
|
|
|
|
return joinNode.AddChild(joinExprNode)
|
|
}
|
|
|
|
// We've got a comparison operator
|
|
sels := ctx.AllSelector()
|
|
if len(sels) != 2 {
|
|
// REVISIT: probably unnecessary, should be caught by the parser
|
|
return errorf("JOIN constraint must have 2 operands (left & right), but got %d", len(sels))
|
|
}
|
|
|
|
join, ok := v.cur.(*JoinNode)
|
|
if !ok {
|
|
return errorf("JoinConstraint must have JoinNode parent, but got %T", v.cur)
|
|
}
|
|
joinCondition := &JoinConstraint{join: join, ctx: ctx}
|
|
|
|
leftSel, err := newSelectorNode(joinCondition, sels[0])
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if err = joinCondition.AddChild(leftSel); err != nil {
|
|
return err
|
|
}
|
|
|
|
cmpr := newCmprNode(joinCondition, ctx.Cmpr())
|
|
if err = joinCondition.AddChild(cmpr); err != nil {
|
|
return err
|
|
}
|
|
|
|
rightSel, err := newSelectorNode(joinCondition, sels[1])
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if err = joinCondition.AddChild(rightSel); err != nil {
|
|
return err
|
|
}
|
|
|
|
return join.AddChild(joinCondition)
|
|
}
|
|
|
|
var _ Node = (*JoinNode)(nil)
|
|
|
|
// JoinNode models a SQL JOIN node. It has a child of type JoinConstraint.
|
|
type JoinNode struct {
|
|
seg *SegmentNode
|
|
ctx antlr.ParseTree
|
|
constraint *JoinConstraint
|
|
leftTbl *TblSelectorNode
|
|
rightTbl *TblSelectorNode
|
|
}
|
|
|
|
// LeftTbl is the selector for the left table of the join.
|
|
func (jn *JoinNode) LeftTbl() *TblSelectorNode {
|
|
return jn.leftTbl
|
|
}
|
|
|
|
// RightTbl is the selector for the right table of the join.
|
|
func (jn *JoinNode) RightTbl() *TblSelectorNode {
|
|
return jn.rightTbl
|
|
}
|
|
|
|
// Tabler implements the Tabler marker interface.
|
|
func (jn *JoinNode) tabler() {
|
|
// no-op
|
|
}
|
|
|
|
func (jn *JoinNode) Parent() Node {
|
|
return jn.seg
|
|
}
|
|
|
|
func (jn *JoinNode) SetParent(parent Node) error {
|
|
seg, ok := parent.(*SegmentNode)
|
|
if !ok {
|
|
return errorf("%T requires parent of type %s", jn, typeSegmentNode)
|
|
}
|
|
jn.seg = seg
|
|
return nil
|
|
}
|
|
|
|
func (jn *JoinNode) Children() []Node {
|
|
if jn.constraint == nil {
|
|
return []Node{}
|
|
}
|
|
|
|
return []Node{jn.constraint}
|
|
}
|
|
|
|
func (jn *JoinNode) AddChild(node Node) error {
|
|
jc, ok := node.(*JoinConstraint)
|
|
if !ok {
|
|
return errorf("JOIN() child must be *JoinConstraint, but got: %T", node)
|
|
}
|
|
|
|
if jn.constraint != nil {
|
|
return errorf("JOIN() has max 1 child: failed to add: %T", node)
|
|
}
|
|
|
|
jn.constraint = jc
|
|
return nil
|
|
}
|
|
|
|
func (jn *JoinNode) SetChildren(children []Node) error {
|
|
if len(children) == 0 {
|
|
jn.constraint = nil
|
|
return nil
|
|
}
|
|
|
|
if len(children) > 1 {
|
|
return errorf("JOIN() can have only one child: failed to add %d children", len(children))
|
|
}
|
|
|
|
expr, ok := children[0].(*JoinConstraint)
|
|
if !ok {
|
|
return errorf("JOIN() child must be *FnJoinExpr, but got: %T", children[0])
|
|
}
|
|
|
|
jn.constraint = expr
|
|
return nil
|
|
}
|
|
|
|
func (jn *JoinNode) Context() antlr.ParseTree {
|
|
return jn.ctx
|
|
}
|
|
|
|
func (jn *JoinNode) SetContext(ctx antlr.ParseTree) error {
|
|
jn.ctx = ctx
|
|
return nil
|
|
}
|
|
|
|
func (jn *JoinNode) Text() string {
|
|
return jn.ctx.GetText()
|
|
}
|
|
|
|
func (jn *JoinNode) Segment() *SegmentNode {
|
|
return jn.seg
|
|
}
|
|
|
|
func (jn *JoinNode) String() string {
|
|
text := nodeString(jn)
|
|
|
|
leftTblName := ""
|
|
rightTblName := ""
|
|
|
|
if jn.leftTbl != nil {
|
|
leftTblName, _ = jn.leftTbl.SelValue()
|
|
}
|
|
if jn.rightTbl != nil {
|
|
rightTblName, _ = jn.rightTbl.SelValue()
|
|
}
|
|
|
|
text += fmt.Sprintf(" | left_table: {%s} | right_table: {%s}", leftTblName, rightTblName)
|
|
return text
|
|
}
|
|
|
|
var _ Node = (*JoinConstraint)(nil)
|
|
|
|
// JoinConstraint models a join's constraint.
|
|
// For example the elements inside the parentheses
|
|
// in "join(.uid == .user_id)".
|
|
type JoinConstraint struct {
|
|
// join is the parent node
|
|
join *JoinNode
|
|
ctx antlr.ParseTree
|
|
children []Node
|
|
}
|
|
|
|
func (n *JoinConstraint) Parent() Node {
|
|
return n.join
|
|
}
|
|
|
|
func (n *JoinConstraint) SetParent(parent Node) error {
|
|
join, ok := parent.(*JoinNode)
|
|
if !ok {
|
|
return errorf("%T requires parent of type %s", n, typeJoinNode)
|
|
}
|
|
n.join = join
|
|
return nil
|
|
}
|
|
|
|
func (n *JoinConstraint) Children() []Node {
|
|
return n.children
|
|
}
|
|
|
|
func (n *JoinConstraint) AddChild(child Node) error {
|
|
nodeCtx := child.Context()
|
|
|
|
switch nodeCtx.(type) {
|
|
case *antlr.TerminalNodeImpl:
|
|
case *slq.SelectorContext:
|
|
default:
|
|
return errorf("cannot add child node %T to %T", nodeCtx, n.ctx)
|
|
}
|
|
|
|
n.children = append(n.children, child)
|
|
return nil
|
|
}
|
|
|
|
func (n *JoinConstraint) SetChildren(children []Node) error {
|
|
for _, child := range children {
|
|
nodeCtx := child.Context()
|
|
|
|
switch nodeCtx.(type) {
|
|
case *antlr.TerminalNodeImpl:
|
|
case *slq.SelectorContext:
|
|
default:
|
|
return errorf("cannot add child node %T to %T", nodeCtx, n.ctx)
|
|
}
|
|
}
|
|
|
|
if len(children) == 0 {
|
|
n.children = children
|
|
return nil
|
|
}
|
|
|
|
n.children = children
|
|
return nil
|
|
}
|
|
|
|
func (n *JoinConstraint) Context() antlr.ParseTree {
|
|
return n.ctx
|
|
}
|
|
|
|
func (n *JoinConstraint) SetContext(ctx antlr.ParseTree) error {
|
|
n.ctx = ctx // TODO: check for correct type
|
|
return nil
|
|
}
|
|
|
|
func (n *JoinConstraint) Text() string {
|
|
return n.ctx.GetText()
|
|
}
|
|
|
|
func (n *JoinConstraint) String() string {
|
|
return nodeString(n)
|
|
}
|