2016-10-17 07:14:01 +03:00
|
|
|
package ast
|
|
|
|
|
2020-08-06 20:58:47 +03:00
|
|
|
import (
|
|
|
|
"reflect"
|
2016-10-17 07:14:01 +03:00
|
|
|
|
2023-07-03 18:34:19 +03:00
|
|
|
"github.com/neilotoole/sq/libsq/core/errz"
|
|
|
|
|
2023-03-27 05:03:40 +03:00
|
|
|
"github.com/samber/lo"
|
2020-08-06 20:58:47 +03:00
|
|
|
)
|
|
|
|
|
|
|
|
// Inspector provides functionality for AST interrogation.
|
2016-10-17 07:14:01 +03:00
|
|
|
type Inspector struct {
|
|
|
|
ast *AST
|
|
|
|
}
|
|
|
|
|
2020-08-06 20:58:47 +03:00
|
|
|
// NewInspector returns an Inspector instance for ast.
|
2023-04-07 11:00:49 +03:00
|
|
|
func NewInspector(ast *AST) *Inspector {
|
|
|
|
return &Inspector{ast: ast}
|
2016-10-17 07:14:01 +03:00
|
|
|
}
|
|
|
|
|
2020-08-06 20:58:47 +03:00
|
|
|
// CountNodes counts the number of nodes having typ.
|
|
|
|
func (in *Inspector) CountNodes(typ reflect.Type) int {
|
2016-10-17 07:14:01 +03:00
|
|
|
count := 0
|
2023-04-07 11:00:49 +03:00
|
|
|
w := NewWalker(in.ast)
|
|
|
|
w.AddVisitor(typ, func(w *Walker, node Node) error {
|
2016-10-17 07:14:01 +03:00
|
|
|
count++
|
|
|
|
return nil
|
|
|
|
})
|
|
|
|
|
2020-08-06 20:58:47 +03:00
|
|
|
_ = w.Walk()
|
2016-10-17 07:14:01 +03:00
|
|
|
return count
|
|
|
|
}
|
|
|
|
|
2023-03-26 04:20:53 +03:00
|
|
|
// FindNodes returns the nodes having typ.
|
2020-08-06 20:58:47 +03:00
|
|
|
func (in *Inspector) FindNodes(typ reflect.Type) []Node {
|
|
|
|
var nodes []Node
|
2023-04-07 11:00:49 +03:00
|
|
|
w := NewWalker(in.ast)
|
|
|
|
w.AddVisitor(typ, func(w *Walker, node Node) error {
|
2016-10-17 07:14:01 +03:00
|
|
|
nodes = append(nodes, node)
|
|
|
|
return nil
|
|
|
|
})
|
|
|
|
|
2023-03-27 05:03:40 +03:00
|
|
|
if err := w.Walk(); err != nil {
|
|
|
|
// Should never happen
|
|
|
|
panic(err)
|
|
|
|
}
|
2016-10-17 07:14:01 +03:00
|
|
|
return nodes
|
2020-08-06 20:58:47 +03:00
|
|
|
}
|
|
|
|
|
2023-03-27 05:03:40 +03:00
|
|
|
// FindHandles returns all handles mentioned in the AST.
|
|
|
|
func (in *Inspector) FindHandles() []string {
|
|
|
|
var handles []string
|
|
|
|
|
2023-04-07 11:00:49 +03:00
|
|
|
if err := walkWith(in.ast, typeHandleNode, func(walker *Walker, node Node) error {
|
2023-03-27 05:03:40 +03:00
|
|
|
handles = append(handles, node.Text())
|
|
|
|
return nil
|
|
|
|
}); err != nil {
|
|
|
|
panic(err)
|
|
|
|
}
|
|
|
|
|
2023-04-07 11:00:49 +03:00
|
|
|
if err := walkWith(in.ast, typeTblSelectorNode, func(walker *Walker, node Node) error {
|
2023-03-27 05:03:40 +03:00
|
|
|
n, _ := node.(*TblSelectorNode)
|
|
|
|
if n.handle != "" {
|
|
|
|
handles = append(handles, n.handle)
|
|
|
|
}
|
|
|
|
return nil
|
|
|
|
}); err != nil {
|
|
|
|
panic(err)
|
|
|
|
}
|
|
|
|
|
|
|
|
return lo.Uniq(handles)
|
|
|
|
}
|
|
|
|
|
2020-08-06 20:58:47 +03:00
|
|
|
// FindWhereClauses returns all the WHERE clauses in the AST.
|
2023-03-26 04:20:53 +03:00
|
|
|
func (in *Inspector) FindWhereClauses() ([]*WhereNode, error) {
|
|
|
|
var nodes []*WhereNode
|
2016-10-17 07:14:01 +03:00
|
|
|
|
2020-08-06 20:58:47 +03:00
|
|
|
for _, seg := range in.ast.Segments() {
|
2023-03-26 04:20:53 +03:00
|
|
|
// WhereNode clauses must be the only child of a segment
|
2020-08-06 20:58:47 +03:00
|
|
|
if len(seg.Children()) == 1 {
|
2023-03-26 04:20:53 +03:00
|
|
|
if w, ok := seg.Children()[0].(*WhereNode); ok {
|
|
|
|
nodes = append(nodes, w)
|
2020-08-06 20:58:47 +03:00
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2023-03-26 04:20:53 +03:00
|
|
|
return nodes, nil
|
2016-10-17 07:14:01 +03:00
|
|
|
}
|
|
|
|
|
|
|
|
// FindColExprSegment returns the segment containing col expressions (such as
|
|
|
|
// ".uid, .email"). This is typically the last segment. It's also possible that
|
2020-08-06 20:58:47 +03:00
|
|
|
// there is no such segment (which usually results in a SELECT * FROM).
|
2023-03-22 09:17:34 +03:00
|
|
|
func (in *Inspector) FindColExprSegment() (*SegmentNode, error) {
|
2016-10-17 07:14:01 +03:00
|
|
|
segs := in.ast.Segments()
|
|
|
|
|
|
|
|
// work backwards from the end
|
2023-06-18 04:28:11 +03:00
|
|
|
for i := len(segs) - 1; i >= 0; i-- {
|
2016-10-17 07:14:01 +03:00
|
|
|
elems := segs[i].Children()
|
|
|
|
numColExprs := 0
|
|
|
|
|
|
|
|
for _, elem := range elems {
|
2023-03-22 09:17:34 +03:00
|
|
|
if _, ok := elem.(ResultColumn); !ok {
|
2016-10-17 07:14:01 +03:00
|
|
|
if numColExprs > 0 {
|
|
|
|
return nil, errorf("found non-homogenous col expr segment [%d]: also has element %T", i, elem)
|
|
|
|
}
|
2020-08-06 20:58:47 +03:00
|
|
|
|
2016-10-17 07:14:01 +03:00
|
|
|
// else it's not a col expr segment, break
|
|
|
|
break
|
|
|
|
}
|
2020-08-06 20:58:47 +03:00
|
|
|
|
2016-10-17 07:14:01 +03:00
|
|
|
numColExprs++
|
|
|
|
}
|
|
|
|
|
|
|
|
if numColExprs > 0 {
|
|
|
|
return segs[i], nil
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2022-12-18 03:51:33 +03:00
|
|
|
return nil, nil //nolint:nilnil
|
2016-10-17 07:14:01 +03:00
|
|
|
}
|
|
|
|
|
2023-03-26 04:20:53 +03:00
|
|
|
// FindOrderByNode returns the OrderByNode, or nil if not found.
|
|
|
|
func (in *Inspector) FindOrderByNode() (*OrderByNode, error) {
|
|
|
|
segs := in.ast.Segments()
|
|
|
|
|
|
|
|
for i := range segs {
|
|
|
|
nodes := nodesWithType(segs[i].Children(), typeOrderByNode)
|
|
|
|
switch len(nodes) {
|
|
|
|
case 0:
|
|
|
|
// No OrderByNode in this segment, continue searching.
|
|
|
|
continue
|
|
|
|
case 1:
|
|
|
|
// Found it
|
|
|
|
node, _ := nodes[0].(*OrderByNode)
|
|
|
|
return node, nil
|
|
|
|
default:
|
|
|
|
// Shouldn't be possible
|
2023-04-02 22:49:45 +03:00
|
|
|
return nil, errorf("segment {%s} has %d OrderByNode children, but max is 1",
|
|
|
|
segs[i], len(nodes))
|
2023-03-26 04:20:53 +03:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
return nil, nil //nolint:nilnil
|
|
|
|
}
|
|
|
|
|
2023-03-26 11:01:41 +03:00
|
|
|
// FindGroupByNode returns the GroupByNode, or nil if not found.
|
|
|
|
func (in *Inspector) FindGroupByNode() (*GroupByNode, error) {
|
|
|
|
segs := in.ast.Segments()
|
|
|
|
|
|
|
|
for i := range segs {
|
|
|
|
nodes := nodesWithType(segs[i].Children(), typeGroupByNode)
|
|
|
|
switch len(nodes) {
|
|
|
|
case 0:
|
|
|
|
// No GroupByNode in this segment, continue searching.
|
|
|
|
continue
|
|
|
|
case 1:
|
|
|
|
// Found it
|
|
|
|
node, _ := nodes[0].(*GroupByNode)
|
|
|
|
return node, nil
|
|
|
|
default:
|
|
|
|
// Shouldn't be possible
|
2023-04-02 22:49:45 +03:00
|
|
|
return nil, errorf("segment {%s} has %d GroupByNode children, but max is 1",
|
|
|
|
segs[i], len(nodes))
|
2023-03-26 11:01:41 +03:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
return nil, nil //nolint:nilnil
|
|
|
|
}
|
|
|
|
|
2023-07-03 18:34:19 +03:00
|
|
|
// FindTableSegments returns the segments that have at least one child
|
|
|
|
// that is a ast.TblSelectorNode.
|
|
|
|
func (in *Inspector) FindTableSegments() []*SegmentNode {
|
2016-10-17 07:14:01 +03:00
|
|
|
segs := in.ast.Segments()
|
2023-03-22 09:17:34 +03:00
|
|
|
selSegs := make([]*SegmentNode, 0, 2)
|
2016-10-17 07:14:01 +03:00
|
|
|
|
|
|
|
for _, seg := range segs {
|
|
|
|
for _, child := range seg.Children() {
|
2023-07-03 18:34:19 +03:00
|
|
|
if _, ok := child.(*TblSelectorNode); ok {
|
2016-10-17 07:14:01 +03:00
|
|
|
selSegs = append(selSegs, seg)
|
|
|
|
break
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
return selSegs
|
|
|
|
}
|
|
|
|
|
2023-06-18 04:28:11 +03:00
|
|
|
// FindFirstHandle returns the first handle mentioned in the query,
|
|
|
|
// or returns empty string.
|
|
|
|
func (in *Inspector) FindFirstHandle() (handle string) {
|
|
|
|
nodes := in.FindNodes(typeHandleNode)
|
|
|
|
if len(nodes) > 0 {
|
|
|
|
handle = nodes[0].(*HandleNode).Handle()
|
|
|
|
return handle
|
|
|
|
}
|
|
|
|
|
|
|
|
nodes = in.FindNodes(typeTblSelectorNode)
|
|
|
|
for _, node := range nodes {
|
|
|
|
handle = node.(*TblSelectorNode).Handle()
|
|
|
|
if handle != "" {
|
|
|
|
return handle
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
return ""
|
|
|
|
}
|
|
|
|
|
2023-07-03 18:34:19 +03:00
|
|
|
// FindFirstTableSelector returns the first top-level (child of a segment)
|
|
|
|
// table selector node.
|
|
|
|
func (in *Inspector) FindFirstTableSelector() *TblSelectorNode {
|
|
|
|
segs := in.ast.Segments()
|
|
|
|
if len(segs) == 0 {
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
var tblSelNode *TblSelectorNode
|
|
|
|
var ok bool
|
|
|
|
|
|
|
|
for _, seg := range segs {
|
|
|
|
for _, child := range seg.Children() {
|
|
|
|
if tblSelNode, ok = child.(*TblSelectorNode); ok {
|
|
|
|
return tblSelNode
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
// FindFinalTableSegment returns the final segment that
|
|
|
|
// has at least one child that is an ast.TblSelectorNode.
|
|
|
|
func (in *Inspector) FindFinalTableSegment() (*SegmentNode, error) {
|
|
|
|
selectableSegs := in.FindTableSegments()
|
2016-10-17 07:14:01 +03:00
|
|
|
if len(selectableSegs) == 0 {
|
|
|
|
return nil, errorf("no selectable segments")
|
|
|
|
}
|
|
|
|
selectableSeg := selectableSegs[len(selectableSegs)-1]
|
|
|
|
return selectableSeg, nil
|
|
|
|
}
|
2023-03-28 09:48:24 +03:00
|
|
|
|
2023-07-03 18:34:19 +03:00
|
|
|
// FindJoins returns all ast.JoinNode instances.
|
|
|
|
func (in *Inspector) FindJoins() ([]*JoinNode, error) {
|
|
|
|
nodes := in.FindNodes(typeJoinNode)
|
|
|
|
joinNodes := make([]*JoinNode, len(nodes))
|
|
|
|
var ok bool
|
|
|
|
for i := range nodes {
|
|
|
|
joinNodes[i], ok = nodes[i].(*JoinNode)
|
|
|
|
if !ok {
|
|
|
|
return nil, errz.Errorf("expected %T but got %T", (*JoinNode)(nil), nodes[i])
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
return joinNodes, nil
|
|
|
|
}
|
|
|
|
|
2023-03-28 09:48:24 +03:00
|
|
|
// FindUniqueNode returns any UniqueNode, or nil.
|
|
|
|
func (in *Inspector) FindUniqueNode() (*UniqueNode, error) {
|
|
|
|
nodes := in.FindNodes(typeUniqueNode)
|
|
|
|
if len(nodes) == 0 {
|
|
|
|
return nil, nil //nolint:nilnil
|
|
|
|
}
|
|
|
|
return nodes[0].(*UniqueNode), nil
|
|
|
|
}
|
2023-06-18 04:28:11 +03:00
|
|
|
|
|
|
|
// FindRowRangeNode returns the single RowRangeNode, or nil.
|
|
|
|
// An error can be returned if the AST is in an illegal state.
|
|
|
|
func (in *Inspector) FindRowRangeNode() (*RowRangeNode, error) {
|
|
|
|
nodes := in.FindNodes(typeRowRangeNode)
|
|
|
|
switch len(nodes) {
|
|
|
|
case 0:
|
|
|
|
return nil, nil //nolint:nilnil
|
|
|
|
case 1:
|
|
|
|
return nodes[0].(*RowRangeNode), nil
|
|
|
|
default:
|
|
|
|
// Shouldn't be possible
|
|
|
|
return nil, errorf("illegal query: only one %T allowed, but found %d", (*RowRangeNode)(nil), len(nodes))
|
|
|
|
}
|
|
|
|
}
|