mirror of
https://github.com/neilotoole/sq.git
synced 2024-12-29 19:24:26 +03:00
b06b631e76
* BaseFragmentBuilder now quotes table and col names for joins * Refactored libsq.engine so that the SQL generated from SLQ input can be tested * Deleted dead code; additional comments
379 lines
9.3 KiB
Go
379 lines
9.3 KiB
Go
package sqlbuilder
|
|
|
|
import (
|
|
"bytes"
|
|
"fmt"
|
|
"math"
|
|
"strings"
|
|
|
|
"github.com/neilotoole/lg"
|
|
|
|
"github.com/neilotoole/sq/libsq/ast"
|
|
"github.com/neilotoole/sq/libsq/core/errz"
|
|
)
|
|
|
|
// baseOps is a map of SLQ operator (e.g. "==" or "!=") to its default SQL rendering.
|
|
var baseOps = map[string]string{
|
|
`==`: `=`,
|
|
}
|
|
|
|
// BaseOps returns a default map of SLQ operator (e.g. "==" or "!=") to its default SQL rendering.
|
|
// The returned map is a copy and can be safely modified by the caller.
|
|
func BaseOps() map[string]string {
|
|
ops := make(map[string]string, len(baseOps))
|
|
for k, v := range baseOps {
|
|
ops[k] = v
|
|
}
|
|
return ops
|
|
}
|
|
|
|
// BaseFragmentBuilder is a default implementation of sqlbuilder.FragmentBuilder.
|
|
type BaseFragmentBuilder struct {
|
|
Log lg.Log
|
|
// Quote is the driver-specific quote rune, e.g. " or `
|
|
Quote string
|
|
ColQuote string
|
|
Ops map[string]string
|
|
}
|
|
|
|
// Operator implements FragmentBuilder.
|
|
func (fb *BaseFragmentBuilder) Operator(op *ast.Operator) (string, error) {
|
|
if val, ok := fb.Ops[op.Text()]; ok {
|
|
return val, nil
|
|
}
|
|
|
|
return op.Text(), nil
|
|
}
|
|
|
|
// Where implements FragmentBuilder.
|
|
func (fb *BaseFragmentBuilder) Where(where *ast.Where) (string, error) {
|
|
sql, err := fb.Expr(where.Expr())
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
sql = "WHERE " + sql
|
|
return sql, nil
|
|
}
|
|
|
|
// Expr implements FragmentBuilder.
|
|
func (fb *BaseFragmentBuilder) Expr(expr *ast.Expr) (string, error) {
|
|
var sql string
|
|
|
|
for _, child := range expr.Children() {
|
|
switch child := child.(type) {
|
|
case *ast.Selector:
|
|
val := child.SelValue()
|
|
parts := strings.Split(val, ".")
|
|
identifier := fb.ColQuote + strings.Join(parts, fb.ColQuote+"."+fb.ColQuote) + fb.ColQuote
|
|
sql = sql + " " + identifier
|
|
case *ast.Operator:
|
|
val, err := fb.Operator(child)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
sql = sql + " " + val
|
|
case *ast.Expr:
|
|
val, err := fb.Expr(child)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
sql = sql + " " + val
|
|
|
|
default:
|
|
sql = sql + " " + child.Text()
|
|
}
|
|
}
|
|
|
|
return sql, nil
|
|
}
|
|
|
|
// SelectAll implements FragmentBuilder.
|
|
func (fb *BaseFragmentBuilder) SelectAll(tblSel *ast.TblSelector) (string, error) {
|
|
sql := fmt.Sprintf("SELECT * FROM %v%s%v", fb.Quote, tblSel.SelValue(), fb.Quote)
|
|
return sql, nil
|
|
}
|
|
|
|
// Function implements FragmentBuilder.
|
|
func (fb *BaseFragmentBuilder) Function(fn *ast.Func) (string, error) {
|
|
|
|
buf := &bytes.Buffer{}
|
|
children := fn.Children()
|
|
|
|
if len(children) == 0 {
|
|
// no children, let's just grab the direct text
|
|
|
|
// HACK: this stuff basically doesn't work at all...
|
|
// but for COUNT(), here's a quick hack to make it work on some DBs
|
|
if fn.Context().GetText() == "count()" {
|
|
buf.WriteString("COUNT(*)")
|
|
} else {
|
|
buf.WriteString(fn.Context().GetText())
|
|
}
|
|
|
|
return buf.String(), nil
|
|
}
|
|
|
|
buf.WriteString(fn.FuncName())
|
|
buf.WriteRune('(')
|
|
for i, child := range children {
|
|
if i > 0 {
|
|
buf.WriteString(", ")
|
|
}
|
|
|
|
switch child := child.(type) {
|
|
case *ast.ColSelector:
|
|
buf.WriteString(child.SelValue())
|
|
default:
|
|
fb.Log.Debugf("unknown AST child node type %T", child)
|
|
}
|
|
}
|
|
|
|
buf.WriteRune(')')
|
|
sql := buf.String()
|
|
return sql, nil
|
|
}
|
|
|
|
// FromTable implements FragmentBuilder.
|
|
func (fb *BaseFragmentBuilder) FromTable(tblSel *ast.TblSelector) (string, error) {
|
|
tblName := tblSel.SelValue()
|
|
if tblName == "" {
|
|
return "", errz.Errorf("selector has empty table name: %q", tblSel.Text())
|
|
}
|
|
|
|
clause := fmt.Sprintf("FROM %v%s%v", fb.Quote, tblSel.SelValue(), fb.Quote)
|
|
return clause, nil
|
|
}
|
|
|
|
// Join implements FragmentBuilder.
|
|
func (fb *BaseFragmentBuilder) Join(fnJoin *ast.Join) (string, error) {
|
|
joinType := "INNER JOIN"
|
|
onClause := ""
|
|
|
|
if len(fnJoin.Children()) == 0 {
|
|
joinType = "NATURAL JOIN"
|
|
} else {
|
|
joinExpr, ok := fnJoin.Children()[0].(*ast.JoinConstraint)
|
|
if !ok {
|
|
return "", errz.Errorf("expected *FnJoinExpr but got %T", fnJoin.Children()[0])
|
|
}
|
|
|
|
leftOperand := ""
|
|
operator := ""
|
|
rightOperand := ""
|
|
|
|
if len(joinExpr.Children()) == 1 {
|
|
// It's a single col selector
|
|
colSel, ok := joinExpr.Children()[0].(*ast.ColSelector)
|
|
if !ok {
|
|
return "", errz.Errorf("expected *ColSelector but got %T", joinExpr.Children()[0])
|
|
}
|
|
|
|
leftOperand = fmt.Sprintf("%s%s%s.%s%s%s", fb.Quote, fnJoin.LeftTbl().SelValue(), fb.Quote, fb.Quote, colSel.SelValue(), fb.Quote)
|
|
operator = "=="
|
|
rightOperand = fmt.Sprintf("%s%s%s.%s%s%s", fb.Quote, fnJoin.RightTbl().SelValue(), fb.Quote, fb.Quote, colSel.SelValue(), fb.Quote)
|
|
} else {
|
|
var err error
|
|
|
|
leftOperand, err = quoteTableOrColSelector(fb.Quote, joinExpr.Children()[0].Text())
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
operator = joinExpr.Children()[1].Text()
|
|
|
|
rightOperand, err = quoteTableOrColSelector(fb.Quote, joinExpr.Children()[2].Text())
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
}
|
|
|
|
if operator == "==" {
|
|
operator = "="
|
|
}
|
|
|
|
onClause = fmt.Sprintf("ON %s %s %s", leftOperand, operator, rightOperand)
|
|
}
|
|
|
|
sql := fmt.Sprintf("FROM %s%s%s %s %s%s%s", fb.Quote, fnJoin.LeftTbl().SelValue(), fb.Quote, joinType, fb.Quote, fnJoin.RightTbl().SelValue(), fb.Quote)
|
|
sql = sqlAppend(sql, onClause)
|
|
|
|
return sql, nil
|
|
}
|
|
|
|
// sqlAppend is a convenience function for building the SQL string.
|
|
// The main purpose is to ensure that there's always a consistent amount
|
|
// of whitespace. Thus, if existing has a space suffix and add has a
|
|
// space prefix, the returned string will only have one space. If add
|
|
// is the empty string or just whitespace, this function simply
|
|
// returns existing.
|
|
func sqlAppend(existing, add string) string {
|
|
add = strings.TrimSpace(add)
|
|
if add == "" {
|
|
return existing
|
|
}
|
|
|
|
existing = strings.TrimSpace(existing)
|
|
return existing + " " + add
|
|
}
|
|
|
|
// quoteTableOrColSelector returns a quote table, col, or table/col
|
|
// selector for use in a SQL statement. For example:
|
|
//
|
|
// .table --> "table"
|
|
// .col --> "col"
|
|
// .table.col --> "table"."col"
|
|
//
|
|
// Thus, the selector must have exactly one or two periods.
|
|
func quoteTableOrColSelector(quote string, selector string) (string, error) {
|
|
if len(selector) < 2 || selector[0] != '.' {
|
|
return "", errz.Errorf("invalid selector: %s", selector)
|
|
}
|
|
|
|
parts := strings.Split(selector[1:], ".")
|
|
switch len(parts) {
|
|
case 1:
|
|
return quote + parts[0] + quote, nil
|
|
case 2:
|
|
return quote + parts[0] + quote + "." + quote + parts[1] + quote, nil
|
|
default:
|
|
return "", errz.Errorf("invalid selector: %s", selector)
|
|
}
|
|
}
|
|
|
|
// Range implements FragmentBuilder.
|
|
func (fb *BaseFragmentBuilder) Range(rr *ast.RowRange) (string, error) {
|
|
if rr == nil {
|
|
return "", nil
|
|
}
|
|
|
|
if rr.Limit < 0 && rr.Offset < 0 {
|
|
return "", nil
|
|
}
|
|
|
|
limit := ""
|
|
offset := ""
|
|
if rr.Limit > -1 {
|
|
limit = fmt.Sprintf(" LIMIT %d", rr.Limit)
|
|
}
|
|
if rr.Offset > -1 {
|
|
offset = fmt.Sprintf(" OFFSET %d", rr.Offset)
|
|
|
|
if rr.Limit == -1 {
|
|
// MySQL requires a LIMIT if OFFSET is used. Therefore
|
|
// we make the LIMIT a very large number
|
|
limit = fmt.Sprintf(" LIMIT %d", math.MaxInt64)
|
|
}
|
|
}
|
|
|
|
sql := limit + offset
|
|
|
|
return sql, nil
|
|
}
|
|
|
|
// SelectCols implements FragmentBuilder.
|
|
func (fb *BaseFragmentBuilder) SelectCols(cols []ast.ColExpr) (string, error) {
|
|
if len(cols) == 0 {
|
|
return "SELECT *", nil
|
|
}
|
|
|
|
vals := make([]string, len(cols))
|
|
|
|
for i, col := range cols {
|
|
colText, err := col.ColExpr()
|
|
if err != nil {
|
|
return "", errz.Errorf("unable to extract col expr from %q: %v", col, err)
|
|
}
|
|
|
|
fn, ok := col.(*ast.Func)
|
|
if ok {
|
|
// it's a function
|
|
vals[i], err = fb.Function(fn)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
continue
|
|
}
|
|
|
|
if !col.IsColName() {
|
|
// it's a function or expression
|
|
vals[i] = colText // for now, we just return the raw text
|
|
continue
|
|
}
|
|
|
|
// it's a column name, e.g. "uid" or "user.uid"
|
|
if !strings.ContainsRune(colText, '.') {
|
|
// it's a regular (non-scoped) col name, e.g. "uid"
|
|
vals[i] = fmt.Sprintf("%s%s%s", fb.Quote, colText, fb.Quote)
|
|
continue
|
|
}
|
|
|
|
// the expr contains a period, so it's likely scoped, e.g. "user.uid"
|
|
parts := strings.Split(colText, ".")
|
|
if len(parts) != 2 {
|
|
return "", errz.Errorf("expected scoped col expr %q to have 2 parts, but got: %v", col, parts)
|
|
}
|
|
|
|
vals[i] = fmt.Sprintf("%s%s%s.%s%s%s", fb.Quote, parts[0], fb.Quote, fb.Quote, parts[1], fb.Quote)
|
|
}
|
|
|
|
text := "SELECT " + strings.Join(vals, ", ")
|
|
return text, nil
|
|
}
|
|
|
|
// BaseQueryBuilder is a default implementation
|
|
// of sqlbuilder.QueryBuilder.
|
|
type BaseQueryBuilder struct {
|
|
SelectClause string
|
|
FromClause string
|
|
WhereClause string
|
|
RangeClause string
|
|
OrderByClause string
|
|
}
|
|
|
|
// SetSelect implements QueryBuilder.
|
|
func (qb *BaseQueryBuilder) SetSelect(cols string) {
|
|
qb.SelectClause = cols
|
|
}
|
|
|
|
// SetFrom implements QueryBuilder.
|
|
func (qb *BaseQueryBuilder) SetFrom(from string) {
|
|
qb.FromClause = from
|
|
}
|
|
|
|
// SetWhere implements QueryBuilder.
|
|
func (qb *BaseQueryBuilder) SetWhere(where string) {
|
|
qb.WhereClause = where
|
|
}
|
|
|
|
// SetRange implements QueryBuilder.
|
|
func (qb *BaseQueryBuilder) SetRange(rng string) {
|
|
qb.RangeClause = rng
|
|
}
|
|
|
|
// SQL implements QueryBuilder.
|
|
func (qb *BaseQueryBuilder) SQL() (string, error) {
|
|
buf := &bytes.Buffer{}
|
|
|
|
buf.WriteString(qb.SelectClause)
|
|
buf.WriteRune(' ')
|
|
buf.WriteString(qb.FromClause)
|
|
|
|
if qb.WhereClause != "" {
|
|
buf.WriteRune(' ')
|
|
buf.WriteString(qb.WhereClause)
|
|
}
|
|
|
|
if qb.OrderByClause != "" {
|
|
buf.WriteRune(' ')
|
|
buf.WriteString(qb.OrderByClause)
|
|
}
|
|
|
|
if qb.RangeClause != "" {
|
|
buf.WriteRune(' ')
|
|
buf.WriteString(qb.RangeClause)
|
|
}
|
|
|
|
return buf.String(), nil
|
|
}
|