sq/libsq/ast/segment.go

194 lines
4.2 KiB
Go
Raw Normal View History

2016-10-17 07:14:01 +03:00
package ast
import (
"fmt"
"reflect"
"strings"
"github.com/antlr4-go/antlr/v4"
"github.com/neilotoole/sq/libsq/ast/internal/slq"
2016-10-17 07:14:01 +03:00
)
// VisitSegment implements slq.SLQVisitor. It creates a SegmentNode for
// ctx, adds that node to the visitor's AST, makes it the visitor's
// current node, and then visits the children of ctx.
func (v *parseTreeVisitor) VisitSegment(ctx *slq.SegmentContext) any {
	node := newSegmentNode(v.ast, ctx)
	v.ast.AddSegment(node)

	// Subsequent visits operate relative to the newly-added segment.
	v.cur = node

	result := v.VisitChildren(ctx)
	return result
}
// newSegmentNode returns a new SegmentNode whose parent is ast, and
// whose parse context and text are taken from ctx.
func newSegmentNode(ast *AST, ctx *slq.SegmentContext) *SegmentNode {
	node := new(SegmentNode)
	node.bn.text = ctx.GetText()
	node.bn.parent = ast
	node.bn.ctx = ctx
	return node
}
2020-08-06 20:58:47 +03:00
var _ Node = (*SegmentNode)(nil)
// SegmentNode models a segment of a query (the elements separated by pipes).
// For example, ".user | .uid, .username" is two segments: ".user",
// and ".uid, .username".
type SegmentNode struct {
2020-08-06 20:58:47 +03:00
bn baseNode
2016-10-17 07:14:01 +03:00
}
// Parent implements ast.Node.
func (s *SegmentNode) Parent() Node {
2016-10-17 07:14:01 +03:00
return s.bn.Parent()
}
// SetParent implements ast.Node.
func (s *SegmentNode) SetParent(parent Node) error {
2016-10-17 07:14:01 +03:00
ast, ok := parent.(*AST)
if !ok {
2020-08-06 20:58:47 +03:00
return errorf("%T requires parent of type %s", s, typeAST)
2016-10-17 07:14:01 +03:00
}
return s.bn.SetParent(ast)
}
// Children implements ast.Node.
func (s *SegmentNode) Children() []Node {
2016-10-17 07:14:01 +03:00
return s.bn.Children()
}
// AddChild implements ast.Node.
func (s *SegmentNode) AddChild(child Node) error {
2016-10-17 07:14:01 +03:00
s.bn.addChild(child)
return child.SetParent(s)
}
// SetChildren implements ast.Node.
func (s *SegmentNode) SetChildren(children []Node) error {
s.bn.doSetChildren(children)
2016-10-17 07:14:01 +03:00
return nil
}
// context implements ast.Node.
func (s *SegmentNode) context() antlr.ParseTree {
return s.bn.context()
2016-10-17 07:14:01 +03:00
}
// setContext implements ast.Node.
func (s *SegmentNode) setContext(ctx antlr.ParseTree) error {
2016-10-17 07:14:01 +03:00
segCtx, ok := ctx.(*slq.SegmentContext)
if !ok {
return errorf("expected *parser.SegmentContext, but got %T", ctx)
}
return s.bn.setContext(segCtx)
2016-10-17 07:14:01 +03:00
}
// ChildType returns the expected Type of the segment's elements, based
// on the content of the segment's node's children. The type should be something
// like SelectorNode|FuncNode.
func (s *SegmentNode) ChildType() (reflect.Type, error) {
2016-10-17 07:14:01 +03:00
if len(s.Children()) == 0 {
return nil, nil
}
2020-08-06 20:58:47 +03:00
_, err := s.uniformChildren()
2016-10-17 07:14:01 +03:00
if err != nil {
return nil, err
}
return reflect.TypeOf(s.Children()[0]), nil
}
2020-08-06 20:58:47 +03:00
// uniformChildren returns true if all the nodes of the segment
// are of a uniform type.
func (s *SegmentNode) uniformChildren() (bool, error) {
2016-10-17 07:14:01 +03:00
if len(s.Children()) == 0 {
return true, nil
}
2020-08-06 20:58:47 +03:00
typs := map[string]struct{}{}
2016-10-17 07:14:01 +03:00
for _, elem := range s.Children() {
2020-08-06 20:58:47 +03:00
typs[reflect.TypeOf(elem).String()] = struct{}{}
2016-10-17 07:14:01 +03:00
}
2020-08-06 20:58:47 +03:00
if len(typs) > 1 {
var str []string
for typ := range typs {
str = append(str, typ)
2016-10-17 07:14:01 +03:00
}
return false, fmt.Errorf("segment [%d] has more than one element node type: [%s]", s.SegIndex(),
strings.Join(str, ", "))
2016-10-17 07:14:01 +03:00
}
return true, nil
}
2020-08-06 20:58:47 +03:00
// SegIndex returns the index of this segment.
func (s *SegmentNode) SegIndex() int {
2016-10-17 07:14:01 +03:00
for i, seg := range s.bn.parent.Children() {
if s == seg {
return i
}
}
return -1
}
// String returns a log/debug-friendly representation.
func (s *SegmentNode) String() string {
2016-10-17 07:14:01 +03:00
if len(s.Children()) == 1 {
return fmt.Sprintf("segment[%d]: [1 element]", s.SegIndex())
}
return fmt.Sprintf("segment[%d]: [%d elements]", s.SegIndex(), len(s.Children()))
}
// Text implements ast.Node.
func (s *SegmentNode) Text() string {
return s.bn.context().GetText()
2016-10-17 07:14:01 +03:00
}
2020-08-06 20:58:47 +03:00
// Prev returns the previous segment, or nil if this is
// the first segment.
func (s *SegmentNode) Prev() *SegmentNode {
2016-10-17 07:14:01 +03:00
parent := s.Parent()
children := parent.Children()
index := -1
for i, child := range children {
childSeg, ok := child.(*SegmentNode)
2016-10-17 07:14:01 +03:00
if !ok {
// should never happen
panic("sibling is not *ast.SegmentNode")
2016-10-17 07:14:01 +03:00
}
if childSeg == s {
index = i
break
}
}
if index == -1 {
2020-08-06 20:58:47 +03:00
// Should never happen
2016-10-17 07:14:01 +03:00
panic(fmt.Sprintf("did not find index for this segment: %s", s))
}
if index == 0 {
return nil
}
return children[index-1].(*SegmentNode)
2016-10-17 07:14:01 +03:00
}
2020-08-06 20:58:47 +03:00
// Next returns the next segment, or nil if this is the last segment.
func (s *SegmentNode) Next() *SegmentNode {
2016-10-17 07:14:01 +03:00
for i, seg := range s.bn.parent.Children() {
if seg == s {
if i >= len(s.bn.parent.Children())-1 {
return nil
}
return s.bn.parent.Children()[i+1].(*SegmentNode)
2016-10-17 07:14:01 +03:00
}
}
return nil
}