sq/libsq/ast/inspector.go

215 lines
5.1 KiB
Go
Raw Normal View History

2016-10-17 07:14:01 +03:00
package ast
2020-08-06 20:58:47 +03:00
import (
"reflect"
2016-10-17 07:14:01 +03:00
"golang.org/x/exp/slog"
"github.com/samber/lo"
"github.com/ryboe/q"
2020-08-06 20:58:47 +03:00
)
// Inspector provides functionality for AST interrogation.
2016-10-17 07:14:01 +03:00
type Inspector struct {
log *slog.Logger
2016-10-17 07:14:01 +03:00
ast *AST
}
2020-08-06 20:58:47 +03:00
// NewInspector returns an Inspector instance for ast.
func NewInspector(log *slog.Logger, ast *AST) *Inspector {
2020-08-06 20:58:47 +03:00
return &Inspector{log: log, ast: ast}
2016-10-17 07:14:01 +03:00
}
2020-08-06 20:58:47 +03:00
// CountNodes counts the number of nodes having typ.
func (in *Inspector) CountNodes(typ reflect.Type) int {
2016-10-17 07:14:01 +03:00
count := 0
2020-08-06 20:58:47 +03:00
w := NewWalker(in.log, in.ast)
w.AddVisitor(typ, func(log *slog.Logger, w *Walker, node Node) error {
2016-10-17 07:14:01 +03:00
count++
if typ == typeSelectorNode {
// found it
// FIXME: delete this
q.Q("found it", node)
}
2016-10-17 07:14:01 +03:00
return nil
})
2020-08-06 20:58:47 +03:00
_ = w.Walk()
2016-10-17 07:14:01 +03:00
return count
}
// FindNodes returns the nodes having typ.
2020-08-06 20:58:47 +03:00
func (in *Inspector) FindNodes(typ reflect.Type) []Node {
var nodes []Node
w := NewWalker(in.log, in.ast)
w.AddVisitor(typ, func(log *slog.Logger, w *Walker, node Node) error {
2016-10-17 07:14:01 +03:00
nodes = append(nodes, node)
return nil
})
if err := w.Walk(); err != nil {
// Should never happen
panic(err)
}
2016-10-17 07:14:01 +03:00
return nodes
2020-08-06 20:58:47 +03:00
}
// FindHandles returns every distinct handle mentioned in the AST, in order of
// first appearance. Handles are gathered from two sources:
//   - the text of each HandleNode, and
//   - the non-empty handle of each TblSelectorNode.
//
// HandleNode occurrences are gathered before TblSelectorNode occurrences.
// FindHandles panics if either walk returns an error.
func (in *Inspector) FindHandles() []string {
	var handles []string

	fromHandleNode := func(_ *slog.Logger, _ *Walker, node Node) error {
		handles = append(handles, node.Text())
		return nil
	}
	fromTblSelector := func(_ *slog.Logger, _ *Walker, node Node) error {
		if sel, _ := node.(*TblSelectorNode); sel.handle != "" {
			handles = append(handles, sel.handle)
		}
		return nil
	}

	if err := walkWith(in.log, in.ast, typeHandleNode, fromHandleNode); err != nil {
		panic(err)
	}
	if err := walkWith(in.log, in.ast, typeTblSelectorNode, fromTblSelector); err != nil {
		panic(err)
	}

	return lo.Uniq(handles)
}
2020-08-06 20:58:47 +03:00
// FindWhereClauses returns all the WHERE clauses in the AST.
func (in *Inspector) FindWhereClauses() ([]*WhereNode, error) {
var nodes []*WhereNode
2016-10-17 07:14:01 +03:00
2020-08-06 20:58:47 +03:00
for _, seg := range in.ast.Segments() {
// WhereNode clauses must be the only child of a segment
2020-08-06 20:58:47 +03:00
if len(seg.Children()) == 1 {
if w, ok := seg.Children()[0].(*WhereNode); ok {
nodes = append(nodes, w)
2020-08-06 20:58:47 +03:00
}
}
}
return nodes, nil
2016-10-17 07:14:01 +03:00
}
// FindColExprSegment returns the segment containing col expressions (such as
// ".uid, .email"). This is typically the last segment. It's also possible that
2020-08-06 20:58:47 +03:00
// there is no such segment (which usually results in a SELECT * FROM).
func (in *Inspector) FindColExprSegment() (*SegmentNode, error) {
2016-10-17 07:14:01 +03:00
segs := in.ast.Segments()
// work backwards from the end
for i := len(segs) - 1; i > 0; i-- {
elems := segs[i].Children()
numColExprs := 0
for _, elem := range elems {
if _, ok := elem.(ResultColumn); !ok {
2016-10-17 07:14:01 +03:00
if numColExprs > 0 {
return nil, errorf("found non-homogenous col expr segment [%d]: also has element %T", i, elem)
}
2020-08-06 20:58:47 +03:00
2016-10-17 07:14:01 +03:00
// else it's not a col expr segment, break
break
}
2020-08-06 20:58:47 +03:00
2016-10-17 07:14:01 +03:00
numColExprs++
}
if numColExprs > 0 {
return segs[i], nil
}
}
return nil, nil //nolint:nilnil
2016-10-17 07:14:01 +03:00
}
// FindOrderByNode scans the segments in order and returns the OrderByNode
// child of the first segment that has one. It returns nil and a nil error if
// no segment has an OrderByNode. A segment with more than one OrderByNode
// child produces an error.
func (in *Inspector) FindOrderByNode() (*OrderByNode, error) {
	for _, seg := range in.ast.Segments() {
		matches := nodesWithType(seg.Children(), typeOrderByNode)
		if len(matches) == 0 {
			continue
		}
		if len(matches) > 1 {
			// Shouldn't be possible
			return nil, errorf("segment {%s} has %d OrderByNode children, but max is 1",
				seg, len(matches))
		}

		node, _ := matches[0].(*OrderByNode)
		return node, nil
	}

	return nil, nil //nolint:nilnil
}
// FindGroupByNode scans the segments in order and returns the GroupByNode
// child of the first segment that has one. It returns nil and a nil error if
// no segment has a GroupByNode. A segment with more than one GroupByNode
// child produces an error.
func (in *Inspector) FindGroupByNode() (*GroupByNode, error) {
	for _, seg := range in.ast.Segments() {
		matches := nodesWithType(seg.Children(), typeGroupByNode)
		if len(matches) == 0 {
			continue
		}
		if len(matches) > 1 {
			// Shouldn't be possible
			return nil, errorf("segment {%s} has %d GroupByNode children, but max is 1",
				seg, len(matches))
		}

		node, _ := matches[0].(*GroupByNode)
		return node, nil
	}

	return nil, nil //nolint:nilnil
}
// FindTablerSegments returns the segments that have at least one child
// that implements Tabler.
func (in *Inspector) FindTablerSegments() []*SegmentNode {
2016-10-17 07:14:01 +03:00
segs := in.ast.Segments()
selSegs := make([]*SegmentNode, 0, 2)
2016-10-17 07:14:01 +03:00
for _, seg := range segs {
for _, child := range seg.Children() {
if _, ok := child.(Tabler); ok {
2016-10-17 07:14:01 +03:00
selSegs = append(selSegs, seg)
break
}
}
}
return selSegs
}
// FindFinalTablerSegment returns the final segment that
// has at lest one child that implements Tabler.
func (in *Inspector) FindFinalTablerSegment() (*SegmentNode, error) {
selectableSegs := in.FindTablerSegments()
2016-10-17 07:14:01 +03:00
if len(selectableSegs) == 0 {
return nil, errorf("no selectable segments")
}
selectableSeg := selectableSegs[len(selectableSegs)-1]
return selectableSeg, nil
}
// FindUniqueNode returns the first UniqueNode encountered while walking the
// AST. It returns nil and a nil error if the AST has none.
func (in *Inspector) FindUniqueNode() (*UniqueNode, error) {
	if found := in.FindNodes(typeUniqueNode); len(found) > 0 {
		return found[0].(*UniqueNode), nil
	}
	return nil, nil //nolint:nilnil
}