package sqlbuilder import ( "bytes" "fmt" "math" "strings" "github.com/neilotoole/lg" "github.com/neilotoole/sq/libsq/ast" "github.com/neilotoole/sq/libsq/core/errz" ) // baseOps is a map of SLQ operator (e.g. "==" or "!=") to its default SQL rendering. var baseOps = map[string]string{ `==`: `=`, } // BaseOps returns a default map of SLQ operator (e.g. "==" or "!=") to its default SQL rendering. // The returned map is a copy and can be safely modified by the caller. func BaseOps() map[string]string { ops := make(map[string]string, len(baseOps)) for k, v := range baseOps { ops[k] = v } return ops } // BaseFragmentBuilder is a default implementation of sqlbuilder.FragmentBuilder. type BaseFragmentBuilder struct { Log lg.Log // Quote is the driver-specific quote rune, e.g. " or ` Quote string ColQuote string Ops map[string]string } // Operator implements FragmentBuilder. func (fb *BaseFragmentBuilder) Operator(op *ast.Operator) (string, error) { if val, ok := fb.Ops[op.Text()]; ok { return val, nil } return op.Text(), nil } // Where implements FragmentBuilder. func (fb *BaseFragmentBuilder) Where(where *ast.Where) (string, error) { sql, err := fb.Expr(where.Expr()) if err != nil { return "", err } sql = "WHERE " + sql return sql, nil } // Expr implements FragmentBuilder. func (fb *BaseFragmentBuilder) Expr(expr *ast.Expr) (string, error) { var sql string for _, child := range expr.Children() { switch child := child.(type) { case *ast.Selector: val := child.SelValue() parts := strings.Split(val, ".") identifier := fb.ColQuote + strings.Join(parts, fb.ColQuote+"."+fb.ColQuote) + fb.ColQuote sql = sql + " " + identifier case *ast.Operator: val, err := fb.Operator(child) if err != nil { return "", err } sql = sql + " " + val case *ast.Expr: val, err := fb.Expr(child) if err != nil { return "", err } sql = sql + " " + val default: sql = sql + " " + child.Text() } } return sql, nil } // SelectAll implements FragmentBuilder. func (fb *BaseFragmentBuilder) SelectAll(tblSel *ast.TblSelector) (string, error) { sql := fmt.Sprintf("SELECT * FROM %v%s%v", fb.Quote, tblSel.SelValue(), fb.Quote) return sql, nil } // Function implements FragmentBuilder. func (fb *BaseFragmentBuilder) Function(fn *ast.Func) (string, error) { buf := &bytes.Buffer{} children := fn.Children() if len(children) == 0 { // no children, let's just grab the direct text // HACK: this stuff basically doesn't work at all... // but for COUNT(), here's a quick hack to make it work on some DBs if fn.Context().GetText() == "count()" { buf.WriteString("COUNT(*)") } else { buf.WriteString(fn.Context().GetText()) } return buf.String(), nil } buf.WriteString(fn.FuncName()) buf.WriteRune('(') for i, child := range children { if i > 0 { buf.WriteString(", ") } switch child := child.(type) { case *ast.ColSelector: buf.WriteString(child.SelValue()) default: fb.Log.Debugf("unknown AST child node type %T", child) } } buf.WriteRune(')') sql := buf.String() return sql, nil } // FromTable implements FragmentBuilder. func (fb *BaseFragmentBuilder) FromTable(tblSel *ast.TblSelector) (string, error) { tblName := tblSel.SelValue() if tblName == "" { return "", errz.Errorf("selector has empty table name: %q", tblSel.Text()) } clause := fmt.Sprintf("FROM %v%s%v", fb.Quote, tblSel.SelValue(), fb.Quote) return clause, nil } // Join implements FragmentBuilder. func (fb *BaseFragmentBuilder) Join(fnJoin *ast.Join) (string, error) { joinType := "INNER JOIN" onClause := "" if len(fnJoin.Children()) == 0 { joinType = "NATURAL JOIN" } else { joinExpr, ok := fnJoin.Children()[0].(*ast.JoinConstraint) if !ok { return "", errz.Errorf("expected *FnJoinExpr but got %T", fnJoin.Children()[0]) } leftOperand := "" operator := "" rightOperand := "" if len(joinExpr.Children()) == 1 { // It's a single col selector colSel, ok := joinExpr.Children()[0].(*ast.ColSelector) if !ok { return "", errz.Errorf("expected *ColSelector but got %T", joinExpr.Children()[0]) } leftOperand = fmt.Sprintf("%s%s%s.%s%s%s", fb.Quote, fnJoin.LeftTbl().SelValue(), fb.Quote, fb.Quote, colSel.SelValue(), fb.Quote) operator = "==" rightOperand = fmt.Sprintf("%s%s%s.%s%s%s", fb.Quote, fnJoin.RightTbl().SelValue(), fb.Quote, fb.Quote, colSel.SelValue(), fb.Quote) } else { var err error leftOperand, err = quoteTableOrColSelector(fb.Quote, joinExpr.Children()[0].Text()) if err != nil { return "", err } operator = joinExpr.Children()[1].Text() rightOperand, err = quoteTableOrColSelector(fb.Quote, joinExpr.Children()[2].Text()) if err != nil { return "", err } } if operator == "==" { operator = "=" } onClause = fmt.Sprintf("ON %s %s %s", leftOperand, operator, rightOperand) } sql := fmt.Sprintf("FROM %s%s%s %s %s%s%s", fb.Quote, fnJoin.LeftTbl().SelValue(), fb.Quote, joinType, fb.Quote, fnJoin.RightTbl().SelValue(), fb.Quote) sql = sqlAppend(sql, onClause) return sql, nil } // sqlAppend is a convenience function for building the SQL string. // The main purpose is to ensure that there's always a consistent amount // of whitespace. Thus, if existing has a space suffix and add has a // space prefix, the returned string will only have one space. If add // is the empty string or just whitespace, this function simply // returns existing. func sqlAppend(existing, add string) string { add = strings.TrimSpace(add) if add == "" { return existing } existing = strings.TrimSpace(existing) return existing + " " + add } // quoteTableOrColSelector returns a quote table, col, or table/col // selector for use in a SQL statement. For example: // // .table --> "table" // .col --> "col" // .table.col --> "table"."col" // // Thus, the selector must have exactly one or two periods. func quoteTableOrColSelector(quote, selector string) (string, error) { if len(selector) < 2 || selector[0] != '.' { return "", errz.Errorf("invalid selector: %s", selector) } parts := strings.Split(selector[1:], ".") switch len(parts) { case 1: return quote + parts[0] + quote, nil case 2: return quote + parts[0] + quote + "." + quote + parts[1] + quote, nil default: return "", errz.Errorf("invalid selector: %s", selector) } } // Range implements FragmentBuilder. func (fb *BaseFragmentBuilder) Range(rr *ast.RowRange) (string, error) { if rr == nil { return "", nil } if rr.Limit < 0 && rr.Offset < 0 { return "", nil } limit := "" offset := "" if rr.Limit > -1 { limit = fmt.Sprintf(" LIMIT %d", rr.Limit) } if rr.Offset > -1 { offset = fmt.Sprintf(" OFFSET %d", rr.Offset) if rr.Limit == -1 { // MySQL requires a LIMIT if OFFSET is used. Therefore // we make the LIMIT a very large number limit = fmt.Sprintf(" LIMIT %d", math.MaxInt64) } } sql := limit + offset return sql, nil } // SelectCols implements FragmentBuilder. func (fb *BaseFragmentBuilder) SelectCols(cols []ast.ColExpr) (string, error) { if len(cols) == 0 { return "SELECT *", nil } vals := make([]string, len(cols)) for i, col := range cols { colText, err := col.ColExpr() if err != nil { return "", errz.Errorf("unable to extract col expr from %q: %v", col, err) } fn, ok := col.(*ast.Func) if ok { // it's a function vals[i], err = fb.Function(fn) if err != nil { return "", err } continue } if !col.IsColName() { // it's a function or expression vals[i] = colText // for now, we just return the raw text continue } // it's a column name, e.g. "uid" or "user.uid" if !strings.ContainsRune(colText, '.') { // it's a regular (non-scoped) col name, e.g. "uid" vals[i] = fmt.Sprintf("%s%s%s", fb.Quote, colText, fb.Quote) continue } // the expr contains a period, so it's likely scoped, e.g. "user.uid" parts := strings.Split(colText, ".") if len(parts) != 2 { return "", errz.Errorf("expected scoped col expr %q to have 2 parts, but got: %v", col, parts) } vals[i] = fmt.Sprintf("%s%s%s.%s%s%s", fb.Quote, parts[0], fb.Quote, fb.Quote, parts[1], fb.Quote) } text := "SELECT " + strings.Join(vals, ", ") return text, nil } // BaseQueryBuilder is a default implementation // of sqlbuilder.QueryBuilder. type BaseQueryBuilder struct { SelectClause string FromClause string WhereClause string RangeClause string OrderByClause string } // SetSelect implements QueryBuilder. func (qb *BaseQueryBuilder) SetSelect(cols string) { qb.SelectClause = cols } // SetFrom implements QueryBuilder. func (qb *BaseQueryBuilder) SetFrom(from string) { qb.FromClause = from } // SetWhere implements QueryBuilder. func (qb *BaseQueryBuilder) SetWhere(where string) { qb.WhereClause = where } // SetRange implements QueryBuilder. func (qb *BaseQueryBuilder) SetRange(rng string) { qb.RangeClause = rng } // SQL implements QueryBuilder. func (qb *BaseQueryBuilder) SQL() (string, error) { buf := &bytes.Buffer{} buf.WriteString(qb.SelectClause) buf.WriteRune(' ') buf.WriteString(qb.FromClause) if qb.WhereClause != "" { buf.WriteRune(' ') buf.WriteString(qb.WhereClause) } if qb.OrderByClause != "" { buf.WriteRune(' ') buf.WriteString(qb.OrderByClause) } if qb.RangeClause != "" { buf.WriteRune(' ') buf.WriteString(qb.RangeClause) } return buf.String(), nil }