sq/libsq/ast/inspector.go

204 lines
4.9 KiB
Go
Raw Normal View History

2016-10-17 07:14:01 +03:00
package ast
2020-08-06 20:58:47 +03:00
import (
"reflect"
2016-10-17 07:14:01 +03:00
"github.com/samber/lo"
"github.com/ryboe/q"
2020-08-06 20:58:47 +03:00
"github.com/neilotoole/lg"
)
// Inspector provides functionality for AST interrogation.
2016-10-17 07:14:01 +03:00
type Inspector struct {
2020-08-06 20:58:47 +03:00
log lg.Log
2016-10-17 07:14:01 +03:00
ast *AST
}
2020-08-06 20:58:47 +03:00
// NewInspector returns an Inspector instance for ast.
func NewInspector(log lg.Log, ast *AST) *Inspector {
return &Inspector{log: log, ast: ast}
2016-10-17 07:14:01 +03:00
}
2020-08-06 20:58:47 +03:00
// CountNodes counts the number of nodes having typ.
func (in *Inspector) CountNodes(typ reflect.Type) int {
2016-10-17 07:14:01 +03:00
count := 0
2020-08-06 20:58:47 +03:00
w := NewWalker(in.log, in.ast)
w.AddVisitor(typ, func(log lg.Log, w *Walker, node Node) error {
2016-10-17 07:14:01 +03:00
count++
if typ == typeSelectorNode {
// found it
// FIXME: delete this
q.Q("found it", node)
}
2016-10-17 07:14:01 +03:00
return nil
})
2020-08-06 20:58:47 +03:00
_ = w.Walk()
2016-10-17 07:14:01 +03:00
return count
}
// FindNodes returns the nodes having typ.
2020-08-06 20:58:47 +03:00
func (in *Inspector) FindNodes(typ reflect.Type) []Node {
var nodes []Node
w := NewWalker(in.log, in.ast)
w.AddVisitor(typ, func(log lg.Log, w *Walker, node Node) error {
2016-10-17 07:14:01 +03:00
nodes = append(nodes, node)
return nil
})
if err := w.Walk(); err != nil {
// Should never happen
panic(err)
}
2016-10-17 07:14:01 +03:00
return nodes
2020-08-06 20:58:47 +03:00
}
// FindHandles returns the distinct handles mentioned in the AST. Handles are
// gathered from the text of every HandleNode and from every TblSelectorNode
// that carries a non-empty handle. It panics if either walk fails.
func (in *Inspector) FindHandles() []string {
	var handles []string

	fromHandleNode := func(_ lg.Log, _ *Walker, node Node) error {
		handles = append(handles, node.Text())
		return nil
	}

	fromTblSelector := func(_ lg.Log, _ *Walker, node Node) error {
		tbl, _ := node.(*TblSelectorNode)
		if tbl.handle == "" {
			return nil
		}
		handles = append(handles, tbl.handle)
		return nil
	}

	if err := walkWith(in.log, in.ast, typeHandleNode, fromHandleNode); err != nil {
		panic(err)
	}
	if err := walkWith(in.log, in.ast, typeTblSelectorNode, fromTblSelector); err != nil {
		panic(err)
	}

	// Drop duplicates, keeping first-seen order.
	return lo.Uniq(handles)
}
2020-08-06 20:58:47 +03:00
// FindWhereClauses returns all the WHERE clauses in the AST.
func (in *Inspector) FindWhereClauses() ([]*WhereNode, error) {
var nodes []*WhereNode
2016-10-17 07:14:01 +03:00
2020-08-06 20:58:47 +03:00
for _, seg := range in.ast.Segments() {
// WhereNode clauses must be the only child of a segment
2020-08-06 20:58:47 +03:00
if len(seg.Children()) == 1 {
if w, ok := seg.Children()[0].(*WhereNode); ok {
nodes = append(nodes, w)
2020-08-06 20:58:47 +03:00
}
}
}
return nodes, nil
2016-10-17 07:14:01 +03:00
}
// FindColExprSegment returns the segment containing col expressions (such as
// ".uid, .email"). This is typically the last segment. It's also possible that
2020-08-06 20:58:47 +03:00
// there is no such segment (which usually results in a SELECT * FROM).
func (in *Inspector) FindColExprSegment() (*SegmentNode, error) {
2016-10-17 07:14:01 +03:00
segs := in.ast.Segments()
// work backwards from the end
for i := len(segs) - 1; i > 0; i-- {
elems := segs[i].Children()
numColExprs := 0
for _, elem := range elems {
if _, ok := elem.(ResultColumn); !ok {
2016-10-17 07:14:01 +03:00
if numColExprs > 0 {
return nil, errorf("found non-homogenous col expr segment [%d]: also has element %T", i, elem)
}
2020-08-06 20:58:47 +03:00
2016-10-17 07:14:01 +03:00
// else it's not a col expr segment, break
break
}
2020-08-06 20:58:47 +03:00
2016-10-17 07:14:01 +03:00
numColExprs++
}
if numColExprs > 0 {
return segs[i], nil
}
}
return nil, nil //nolint:nilnil
2016-10-17 07:14:01 +03:00
}
// FindOrderByNode returns the OrderByNode, or nil if not found.
//
// Segments are scanned in order, and the first segment whose children (as
// filtered by nodesWithType) include an OrderByNode supplies the result.
// It returns nil, nil if no segment has one, and an error if a segment has
// more than one OrderByNode child.
func (in *Inspector) FindOrderByNode() (*OrderByNode, error) {
	for _, seg := range in.ast.Segments() {
		nodes := nodesWithType(seg.Children(), typeOrderByNode)

		switch len(nodes) {
		case 0:
			// No OrderByNode in this segment, continue searching.
			continue
		case 1:
			// Found it.
			node, _ := nodes[0].(*OrderByNode)
			return node, nil
		default:
			// Shouldn't be possible. Both format verbs need an argument:
			// the segment and the number of OrderByNode children found.
			return nil, errorf("Segment {%s} has %d OrderByNode children, but should have a max of 1",
				seg, len(nodes))
		}
	}

	return nil, nil //nolint:nilnil
}
// FindGroupByNode returns the GroupByNode, or nil if not found.
//
// Segments are scanned in order, and the first segment whose children (as
// filtered by nodesWithType) include a GroupByNode supplies the result.
// It returns nil, nil if no segment has one, and an error if a segment has
// more than one GroupByNode child.
func (in *Inspector) FindGroupByNode() (*GroupByNode, error) {
	for _, seg := range in.ast.Segments() {
		nodes := nodesWithType(seg.Children(), typeGroupByNode)

		switch len(nodes) {
		case 0:
			// No GroupByNode in this segment, continue searching.
			continue
		case 1:
			// Found it.
			node, _ := nodes[0].(*GroupByNode)
			return node, nil
		default:
			// Shouldn't be possible. Both format verbs need an argument:
			// the segment and the number of GroupByNode children found.
			return nil, errorf("Segment {%s} has %d GroupByNode children, but should have a max of 1",
				seg, len(nodes))
		}
	}

	return nil, nil //nolint:nilnil
}
2016-10-17 07:14:01 +03:00
// FindSelectableSegments returns the segments that have at least one child
// that implements Tabler.
func (in *Inspector) FindSelectableSegments() []*SegmentNode {
2016-10-17 07:14:01 +03:00
segs := in.ast.Segments()
selSegs := make([]*SegmentNode, 0, 2)
2016-10-17 07:14:01 +03:00
for _, seg := range segs {
for _, child := range seg.Children() {
if _, ok := child.(Tabler); ok {
2016-10-17 07:14:01 +03:00
selSegs = append(selSegs, seg)
break
}
}
}
return selSegs
}
2020-08-06 20:58:47 +03:00
// FindFinalSelectableSegment returns the final segment that
// has at lest one child that implements Tabler.
func (in *Inspector) FindFinalSelectableSegment() (*SegmentNode, error) {
2016-10-17 07:14:01 +03:00
selectableSegs := in.FindSelectableSegments()
if len(selectableSegs) == 0 {
return nil, errorf("no selectable segments")
}
selectableSeg := selectableSegs[len(selectableSegs)-1]
return selectableSeg, nil
}