sq/libsq/ast/inspector.go
2020-08-06 11:58:47 -06:00

122 lines
2.8 KiB
Go

package ast
import (
"reflect"
"github.com/neilotoole/lg"
)
// Inspector provides functionality for AST interrogation.
type Inspector struct {
log lg.Log
ast *AST
}
// NewInspector returns an Inspector instance for ast.
func NewInspector(log lg.Log, ast *AST) *Inspector {
return &Inspector{log: log, ast: ast}
}
// CountNodes counts the number of nodes having typ.
func (in *Inspector) CountNodes(typ reflect.Type) int {
count := 0
w := NewWalker(in.log, in.ast)
w.AddVisitor(typ, func(log lg.Log, w *Walker, node Node) error {
count++
return nil
})
_ = w.Walk()
return count
}
// FindNodes returns all of the nodes having typ.
func (in *Inspector) FindNodes(typ reflect.Type) []Node {
var nodes []Node
w := NewWalker(in.log, in.ast)
w.AddVisitor(typ, func(log lg.Log, w *Walker, node Node) error {
nodes = append(nodes, node)
return nil
})
_ = w.Walk()
return nodes
}
// FindWhereClauses returns all the WHERE clauses in the AST.
func (in *Inspector) FindWhereClauses() ([]*Where, error) {
ws := []*Where{}
for _, seg := range in.ast.Segments() {
// Where clauses must be the only child of a segment
if len(seg.Children()) == 1 {
if w, ok := seg.Children()[0].(*Where); ok {
ws = append(ws, w)
}
}
}
return ws, nil
}
// FindColExprSegment returns the segment containing col expressions (such as
// ".uid, .email"). This is typically the last segment. It's also possible that
// there is no such segment (which usually results in a SELECT * FROM).
func (in *Inspector) FindColExprSegment() (*Segment, error) {
segs := in.ast.Segments()
// work backwards from the end
for i := len(segs) - 1; i > 0; i-- {
elems := segs[i].Children()
numColExprs := 0
for _, elem := range elems {
if _, ok := elem.(ColExpr); !ok {
if numColExprs > 0 {
return nil, errorf("found non-homogenous col expr segment [%d]: also has element %T", i, elem)
}
// else it's not a col expr segment, break
break
}
numColExprs++
}
if numColExprs > 0 {
return segs[i], nil
}
}
return nil, nil
}
// FindSelectableSegments returns the segments that have at least one child
// that implements Selectable.
func (in *Inspector) FindSelectableSegments() []*Segment {
segs := in.ast.Segments()
selSegs := make([]*Segment, 0, 2)
for _, seg := range segs {
for _, child := range seg.Children() {
if _, ok := child.(Selectable); ok {
selSegs = append(selSegs, seg)
break
}
}
}
return selSegs
}
// FindFinalSelectableSegment returns the final segment that
// has at lest one child that implements Selectable.
func (in *Inspector) FindFinalSelectableSegment() (*Segment, error) {
selectableSegs := in.FindSelectableSegments()
if len(selectableSegs) == 0 {
return nil, errorf("no selectable segments")
}
selectableSeg := selectableSegs[len(selectableSegs)-1]
return selectableSeg, nil
}