sq/libsq/ast/sqlbuilder/basebuilder.go
Neil O'Toole b06b631e76
Bug #87: generated SQL should always quote table and column names in join statement (#89)
* BaseFragmentBuilder now quotes table and col names for joins

* Refactored libsq.engine so that the SQL generated from SLQ input can be tested

* Deleted dead code; additional comments
2021-03-07 23:27:35 -07:00

379 lines
9.3 KiB
Go

package sqlbuilder
import (
"bytes"
"fmt"
"math"
"strings"
"github.com/neilotoole/lg"
"github.com/neilotoole/sq/libsq/ast"
"github.com/neilotoole/sq/libsq/core/errz"
)
// baseOps holds the default translation from an SLQ operator
// (such as "==") to the text emitted for it in SQL. Operators
// without an entry have no default translation here.
var baseOps = map[string]string{
	"==": "=",
}

// BaseOps returns a fresh copy of the default SLQ-operator-to-SQL
// mapping. The caller owns the returned map: adding, changing or
// deleting entries does not affect the package defaults or any map
// returned by another call.
func BaseOps() map[string]string {
	dup := make(map[string]string, len(baseOps))
	for slqOp := range baseOps {
		dup[slqOp] = baseOps[slqOp]
	}
	return dup
}
// BaseFragmentBuilder is a default implementation of sqlbuilder.FragmentBuilder.
// Its methods render individual SQL fragments (SELECT, FROM, JOIN, WHERE,
// LIMIT/OFFSET) from AST nodes.
type BaseFragmentBuilder struct {
// Log receives debug output, e.g. when Function encounters an AST
// child node type that it does not render.
Log lg.Log
// Quote is the driver-specific quote string, e.g. " or `. It is placed
// on both sides of table and column names by SelectAll, FromTable,
// Join and SelectCols.
Quote string
// ColQuote is the quote string that Expr places around each
// period-separated part of a selector.
ColQuote string
// Ops maps an SLQ operator to its SQL rendering. Operator consults
// this map and falls back to the operator's own text when there is
// no entry.
Ops map[string]string
}
// Operator implements FragmentBuilder. It returns the SQL rendering
// registered in fb.Ops for op; when fb.Ops has no entry for the
// operator, the operator's own text is returned unchanged. The
// returned error is always nil.
func (fb *BaseFragmentBuilder) Operator(op *ast.Operator) (string, error) {
	text := op.Text()
	if rendered, found := fb.Ops[text]; found {
		return rendered, nil
	}

	return text, nil
}
// Where implements FragmentBuilder. It renders the expression held
// by where via Expr and prefixes the result with "WHERE ". Any error
// from Expr is returned as-is.
func (fb *BaseFragmentBuilder) Where(where *ast.Where) (string, error) {
	cond, err := fb.Expr(where.Expr())
	if err != nil {
		return "", err
	}

	return "WHERE " + cond, nil
}
// Expr implements FragmentBuilder. It renders each child of expr in
// order and emits every rendered fragment preceded by a single space,
// so a non-empty result always starts with a space.
//
// Children are rendered as follows:
//   - *ast.Selector: each period-separated part is wrapped in fb.ColQuote.
//   - *ast.Operator: rendered via Operator.
//   - *ast.Expr: rendered recursively.
//   - anything else: the node's raw text.
//
// The first error encountered is returned along with an empty string.
func (fb *BaseFragmentBuilder) Expr(expr *ast.Expr) (string, error) {
	var sb strings.Builder

	for _, node := range expr.Children() {
		var (
			frag string
			err  error
		)

		switch n := node.(type) {
		case *ast.Selector:
			// Turn a.b into <q>a<q>.<q>b<q> by replacing each period
			// with a closing quote, a period, and an opening quote.
			quotedDot := fb.ColQuote + "." + fb.ColQuote
			frag = fb.ColQuote + strings.ReplaceAll(n.SelValue(), ".", quotedDot) + fb.ColQuote
		case *ast.Operator:
			frag, err = fb.Operator(n)
		case *ast.Expr:
			frag, err = fb.Expr(n)
		default:
			frag = n.Text()
		}

		if err != nil {
			return "", err
		}

		sb.WriteByte(' ')
		sb.WriteString(frag)
	}

	return sb.String(), nil
}
// SelectAll implements FragmentBuilder. It returns a
// "SELECT * FROM <table>" statement in which the table name taken
// from tblSel is wrapped in fb.Quote. The returned error is always nil.
func (fb *BaseFragmentBuilder) SelectAll(tblSel *ast.TblSelector) (string, error) {
	quotedTbl := fb.Quote + tblSel.SelValue() + fb.Quote
	return "SELECT * FROM " + quotedTbl, nil
}
// Function implements FragmentBuilder. It renders a function call.
//
// When fn has no child nodes, the raw text of the function's parse
// context is returned, except that "count()" is rendered as
// "COUNT(*)". This is a known hack that only works on some databases.
//
// Otherwise the result is "name(arg, arg, ...)". Only *ast.ColSelector
// children contribute an argument (their selector value, unquoted);
// any other child type is reported to fb.Log at debug level and
// contributes nothing, although its separator is still written.
// The returned error is always nil.
func (fb *BaseFragmentBuilder) Function(fn *ast.Func) (string, error) {
	args := fn.Children()
	if len(args) == 0 {
		// No children: fall back to the raw source text.
		// HACK: special-case count() so that it works on some DBs.
		if raw := fn.Context().GetText(); raw != "count()" {
			return raw, nil
		}
		return "COUNT(*)", nil
	}

	var buf bytes.Buffer
	buf.WriteString(fn.FuncName())
	buf.WriteByte('(')

	for i := range args {
		if i != 0 {
			buf.WriteString(", ")
		}

		if col, isCol := args[i].(*ast.ColSelector); isCol {
			buf.WriteString(col.SelValue())
			continue
		}

		fb.Log.Debugf("unknown AST child node type %T", args[i])
	}

	buf.WriteByte(')')
	return buf.String(), nil
}
// FromTable implements FragmentBuilder. It returns a "FROM <table>"
// clause in which the table name taken from tblSel is wrapped in
// fb.Quote. An error is returned if the selector yields an empty
// table name.
func (fb *BaseFragmentBuilder) FromTable(tblSel *ast.TblSelector) (string, error) {
	name := tblSel.SelValue()
	if name != "" {
		return "FROM " + fb.Quote + name + fb.Quote, nil
	}

	return "", errz.Errorf("selector has empty table name: %q", tblSel.Text())
}
// Join implements FragmentBuilder. It renders the FROM clause for a
// join of fnJoin's left and right tables; both table names are
// wrapped in fb.Quote.
//
// When fnJoin has no children, a NATURAL JOIN without an ON clause is
// produced. Otherwise the first child must be an *ast.JoinConstraint,
// an INNER JOIN is produced, and the ON clause is built from the
// constraint's children:
//
//   - one child: it must be an *ast.ColSelector, and the named column
//     of the left table is equated with the same column of the right
//     table;
//   - three or more children: the first three are taken as
//     left operand, operator and right operand, with the operands
//     quoted by quoteTableOrColSelector.
//
// Any other child count yields an error. The SLQ operator "==" is
// rendered as "=".
func (fb *BaseFragmentBuilder) Join(fnJoin *ast.Join) (string, error) {
	joinType := "INNER JOIN"
	onClause := ""

	if len(fnJoin.Children()) == 0 {
		joinType = "NATURAL JOIN"
	} else {
		joinExpr, ok := fnJoin.Children()[0].(*ast.JoinConstraint)
		if !ok {
			return "", errz.Errorf("expected *ast.JoinConstraint but got %T", fnJoin.Children()[0])
		}

		constraint := joinExpr.Children()
		var leftOperand, operator, rightOperand string

		switch {
		case len(constraint) == 1:
			// It's a single col selector: the same column name is
			// used on both sides of the join.
			colSel, ok := constraint[0].(*ast.ColSelector)
			if !ok {
				return "", errz.Errorf("expected *ColSelector but got %T", constraint[0])
			}

			leftOperand = fmt.Sprintf("%s%s%s.%s%s%s", fb.Quote, fnJoin.LeftTbl().SelValue(), fb.Quote, fb.Quote, colSel.SelValue(), fb.Quote)
			operator = "=="
			rightOperand = fmt.Sprintf("%s%s%s.%s%s%s", fb.Quote, fnJoin.RightTbl().SelValue(), fb.Quote, fb.Quote, colSel.SelValue(), fb.Quote)
		case len(constraint) >= 3:
			var err error
			leftOperand, err = quoteTableOrColSelector(fb.Quote, constraint[0].Text())
			if err != nil {
				return "", err
			}

			operator = constraint[1].Text()

			rightOperand, err = quoteTableOrColSelector(fb.Quote, constraint[2].Text())
			if err != nil {
				return "", err
			}
		default:
			// Guard against indexing past the end of the child slice,
			// which would otherwise panic.
			return "", errz.Errorf("invalid join constraint: expected 1 or 3 child nodes but got %d", len(constraint))
		}

		if operator == "==" {
			operator = "="
		}

		onClause = fmt.Sprintf("ON %s %s %s", leftOperand, operator, rightOperand)
	}

	sql := fmt.Sprintf("FROM %s%s%s %s %s%s%s", fb.Quote, fnJoin.LeftTbl().SelValue(), fb.Quote, joinType, fb.Quote, fnJoin.RightTbl().SelValue(), fb.Quote)
	sql = sqlAppend(sql, onClause)

	return sql, nil
}
// sqlAppend is a convenience function for building the SQL string.
// Its main purpose is to ensure that there's always a consistent
// amount of whitespace between the two pieces.
//
//   - If add is the empty string or just whitespace, existing is
//     returned unchanged (it is not trimmed).
//   - If existing is the empty string or just whitespace, the trimmed
//     add is returned on its own, with no leading space.
//   - Otherwise both strings are trimmed and joined with exactly one
//     space.
func sqlAppend(existing, add string) string {
	add = strings.TrimSpace(add)
	if add == "" {
		return existing
	}

	existing = strings.TrimSpace(existing)
	if existing == "" {
		// Nothing to separate from, so don't emit a separator.
		return add
	}

	return existing + " " + add
}
// quoteTableOrColSelector returns a quoted table, col, or table/col
// selector for use in a SQL statement. For example:
//
//	.table     --> "table"
//	.col       --> "col"
//	.table.col --> "table"."col"
//
// The selector must begin with a period and contain exactly one or
// two periods in total, each followed by a non-empty name. Selectors
// such as "table", ".", ".a." or ".a.b.c" are rejected with an error.
// The quote string is placed on both sides of each name; names are
// not otherwise escaped.
func quoteTableOrColSelector(quote string, selector string) (string, error) {
	if len(selector) < 2 || selector[0] != '.' {
		return "", errz.Errorf("invalid selector: %s", selector)
	}

	parts := strings.Split(selector[1:], ".")
	if len(parts) > 2 {
		return "", errz.Errorf("invalid selector: %s", selector)
	}

	for i, part := range parts {
		if part == "" {
			// An empty name would render as an empty quoted
			// identifier, which is not a usable table or column name.
			return "", errz.Errorf("invalid selector: %s", selector)
		}
		parts[i] = quote + part + quote
	}

	return strings.Join(parts, "."), nil
}
// Range implements FragmentBuilder. It renders the LIMIT/OFFSET
// portion of a query from rr.
//
// A negative Limit or Offset means "not set". If rr is nil, or neither
// value is set, the empty string is returned. Otherwise the result is
// the concatenation of " LIMIT n" and/or " OFFSET n"; note that each
// piece starts with a space. When an offset is set without a limit,
// a LIMIT of math.MaxInt64 is emitted, because MySQL requires a LIMIT
// whenever OFFSET is used. The returned error is always nil.
func (fb *BaseFragmentBuilder) Range(rr *ast.RowRange) (string, error) {
	if rr == nil {
		return "", nil
	}

	hasLimit, hasOffset := rr.Limit >= 0, rr.Offset >= 0
	if !hasLimit && !hasOffset {
		return "", nil
	}

	limit := ""
	offset := ""

	if hasLimit {
		limit = fmt.Sprintf(" LIMIT %d", rr.Limit)
	}

	if hasOffset {
		offset = fmt.Sprintf(" OFFSET %d", rr.Offset)

		if !hasLimit {
			// MySQL requires a LIMIT if OFFSET is used. Therefore
			// we make the LIMIT a very large number. The explicit
			// int64 conversion is needed because an untyped constant
			// passed as an interface value defaults to int, which
			// overflows on 32-bit platforms.
			limit = fmt.Sprintf(" LIMIT %d", int64(math.MaxInt64))
		}
	}

	return limit + offset, nil
}
// SelectCols implements FragmentBuilder. It renders the SELECT clause
// for cols, or "SELECT *" when cols is empty.
//
// Each element is rendered as follows, and the results are joined
// with ", ":
//   - an *ast.Func is rendered via Function;
//   - an element that is not a column name is emitted as its raw
//     col expr text;
//   - an unscoped column name (no period, e.g. uid) is wrapped in
//     fb.Quote;
//   - a scoped column name (e.g. user.uid) must have exactly two
//     period-separated parts, each of which is wrapped in fb.Quote.
//
// An error is returned if a col expr cannot be extracted, if Function
// fails, or if a scoped column name does not have exactly two parts.
func (fb *BaseFragmentBuilder) SelectCols(cols []ast.ColExpr) (string, error) {
	if len(cols) == 0 {
		return "SELECT *", nil
	}

	vals := make([]string, len(cols))

	for i, col := range cols {
		// The col expr is extracted up front, even for functions, so
		// that an extraction failure is always reported.
		text, err := col.ColExpr()
		if err != nil {
			return "", errz.Errorf("unable to extract col expr from %q: %v", col, err)
		}

		if fn, isFn := col.(*ast.Func); isFn {
			if vals[i], err = fb.Function(fn); err != nil {
				return "", err
			}
			continue
		}

		if !col.IsColName() {
			// Some other kind of expression: for now, pass the raw
			// text through untouched.
			vals[i] = text
			continue
		}

		// It's a column name, either plain ("uid") or scoped
		// ("user.uid"); the number of period-separated parts
		// tells us which.
		parts := strings.Split(text, ".")
		switch len(parts) {
		case 1:
			vals[i] = fb.Quote + text + fb.Quote
		case 2:
			vals[i] = fb.Quote + parts[0] + fb.Quote + "." + fb.Quote + parts[1] + fb.Quote
		default:
			return "", errz.Errorf("expected scoped col expr %q to have 2 parts, but got: %v", col, parts)
		}
	}

	return "SELECT " + strings.Join(vals, ", "), nil
}
// BaseQueryBuilder is a default implementation
// of sqlbuilder.QueryBuilder. It collects pre-rendered SQL clauses
// and assembles them into a single statement via SQL.
type BaseQueryBuilder struct {
	// SelectClause is always written first by SQL.
	SelectClause string
	// FromClause is always written after SelectClause.
	FromClause string
	// WhereClause is written only when non-empty.
	WhereClause string
	// RangeClause is written last, and only when non-empty.
	RangeClause string
	// OrderByClause is written after WhereClause and before
	// RangeClause, and only when non-empty.
	OrderByClause string
}

// SetSelect implements QueryBuilder. It stores cols as the SELECT clause.
func (qb *BaseQueryBuilder) SetSelect(cols string) { qb.SelectClause = cols }

// SetFrom implements QueryBuilder. It stores from as the FROM clause.
func (qb *BaseQueryBuilder) SetFrom(from string) { qb.FromClause = from }

// SetWhere implements QueryBuilder. It stores where as the WHERE clause.
func (qb *BaseQueryBuilder) SetWhere(where string) { qb.WhereClause = where }

// SetRange implements QueryBuilder. It stores rng as the range
// (LIMIT/OFFSET) clause.
func (qb *BaseQueryBuilder) SetRange(rng string) { qb.RangeClause = rng }

// SQL implements QueryBuilder. It concatenates the stored clauses in
// the order SELECT, FROM, WHERE, ORDER BY, range. The SELECT and FROM
// clauses are always written, separated by one space; each of the
// remaining clauses is written, preceded by one space, only if it is
// non-empty. Clauses are written verbatim, so any whitespace they
// carry is preserved. The returned error is always nil.
func (qb *BaseQueryBuilder) SQL() (string, error) {
	var buf bytes.Buffer

	buf.WriteString(qb.SelectClause)
	buf.WriteByte(' ')
	buf.WriteString(qb.FromClause)

	optional := [...]string{qb.WhereClause, qb.OrderByClause, qb.RangeClause}
	for _, clause := range optional {
		if clause == "" {
			continue
		}
		buf.WriteByte(' ')
		buf.WriteString(clause)
	}

	return buf.String(), nil
}