mirror of
https://github.com/neilotoole/sq.git
synced 2024-12-30 03:32:59 +03:00
122 lines
2.8 KiB
Go
122 lines
2.8 KiB
Go
package ast
|
|
|
|
import (
|
|
"reflect"
|
|
|
|
"github.com/neilotoole/lg"
|
|
)
|
|
|
|
// Inspector provides functionality for AST interrogation.
|
|
type Inspector struct {
|
|
log lg.Log
|
|
ast *AST
|
|
}
|
|
|
|
// NewInspector returns an Inspector instance for ast.
|
|
func NewInspector(log lg.Log, ast *AST) *Inspector {
|
|
return &Inspector{log: log, ast: ast}
|
|
}
|
|
|
|
// CountNodes counts the number of nodes having typ.
|
|
func (in *Inspector) CountNodes(typ reflect.Type) int {
|
|
count := 0
|
|
w := NewWalker(in.log, in.ast)
|
|
w.AddVisitor(typ, func(log lg.Log, w *Walker, node Node) error {
|
|
count++
|
|
return nil
|
|
})
|
|
|
|
_ = w.Walk()
|
|
return count
|
|
}
|
|
|
|
// FindNodes returns all of the nodes having typ.
|
|
func (in *Inspector) FindNodes(typ reflect.Type) []Node {
|
|
var nodes []Node
|
|
w := NewWalker(in.log, in.ast)
|
|
w.AddVisitor(typ, func(log lg.Log, w *Walker, node Node) error {
|
|
nodes = append(nodes, node)
|
|
return nil
|
|
})
|
|
|
|
_ = w.Walk()
|
|
return nodes
|
|
}
|
|
|
|
// FindWhereClauses returns all the WHERE clauses in the AST.
|
|
func (in *Inspector) FindWhereClauses() ([]*Where, error) {
|
|
ws := []*Where{}
|
|
|
|
for _, seg := range in.ast.Segments() {
|
|
// Where clauses must be the only child of a segment
|
|
if len(seg.Children()) == 1 {
|
|
if w, ok := seg.Children()[0].(*Where); ok {
|
|
ws = append(ws, w)
|
|
}
|
|
}
|
|
}
|
|
|
|
return ws, nil
|
|
}
|
|
|
|
// FindColExprSegment returns the segment containing col expressions (such as
|
|
// ".uid, .email"). This is typically the last segment. It's also possible that
|
|
// there is no such segment (which usually results in a SELECT * FROM).
|
|
func (in *Inspector) FindColExprSegment() (*Segment, error) {
|
|
segs := in.ast.Segments()
|
|
|
|
// work backwards from the end
|
|
for i := len(segs) - 1; i > 0; i-- {
|
|
elems := segs[i].Children()
|
|
numColExprs := 0
|
|
|
|
for _, elem := range elems {
|
|
if _, ok := elem.(ColExpr); !ok {
|
|
if numColExprs > 0 {
|
|
return nil, errorf("found non-homogenous col expr segment [%d]: also has element %T", i, elem)
|
|
}
|
|
|
|
// else it's not a col expr segment, break
|
|
break
|
|
}
|
|
|
|
numColExprs++
|
|
}
|
|
|
|
if numColExprs > 0 {
|
|
return segs[i], nil
|
|
}
|
|
}
|
|
|
|
return nil, nil
|
|
}
|
|
|
|
// FindSelectableSegments returns the segments that have at least one child
|
|
// that implements Selectable.
|
|
func (in *Inspector) FindSelectableSegments() []*Segment {
|
|
segs := in.ast.Segments()
|
|
selSegs := make([]*Segment, 0, 2)
|
|
|
|
for _, seg := range segs {
|
|
for _, child := range seg.Children() {
|
|
if _, ok := child.(Selectable); ok {
|
|
selSegs = append(selSegs, seg)
|
|
break
|
|
}
|
|
}
|
|
}
|
|
|
|
return selSegs
|
|
}
|
|
|
|
// FindFinalSelectableSegment returns the final segment that
|
|
// has at lest one child that implements Selectable.
|
|
func (in *Inspector) FindFinalSelectableSegment() (*Segment, error) {
|
|
selectableSegs := in.FindSelectableSegments()
|
|
if len(selectableSegs) == 0 {
|
|
return nil, errorf("no selectable segments")
|
|
}
|
|
selectableSeg := selectableSegs[len(selectableSegs)-1]
|
|
return selectableSeg, nil
|
|
}
|