Support adding columns with UNIQUE, NOT NULL and DEFAULT constraints (#30)

Allow the **add column** operation to add columns with `NOT NULL`,
`UNIQUE` and `DEFAULT` constraints by re-using the SQL generation code
that adds columns to tables.
This commit is contained in:
Andrew Farries 2023-07-13 08:38:43 +01:00 committed by GitHub
parent 87f0be5ea5
commit dce42da85a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 56 additions and 34 deletions

View File

@ -10,6 +10,17 @@
"nullable": true
}
}
},
{
"add_column": {
"table": "products",
"column": {
"name": "stock",
"type": "int",
"nullable": false,
"default": "100"
}
}
}
]
}

View File

@ -20,12 +20,9 @@ var _ Operation = (*OpAddColumn)(nil)
func (o *OpAddColumn) Start(ctx context.Context, conn *sql.DB, s *schema.Schema) error {
table := s.GetTable(o.Table)
if o.Column.Nullable {
if err := addNullableColumn(ctx, conn, o, table); err != nil {
return fmt.Errorf("failed to start add column operation: %w", err)
}
} else {
return errors.New("addition of non-nullable columns not implemented")
if err := addColumn(ctx, conn, *o, table); err != nil {
return fmt.Errorf("failed to start add column operation: %w", err)
}
table.AddColumn(o.Column.Name, schema.Column{
@ -65,14 +62,23 @@ func (o *OpAddColumn) Validate(ctx context.Context, s *schema.Schema) error {
return ColumnAlreadyExistsError{Name: o.Column.Name, Table: o.Table}
}
if !o.Column.Nullable && o.Column.Default == nil {
return errors.New("adding non-nullable columns without a default is not supported")
}
if o.Column.PrimaryKey {
return errors.New("adding primary key columns is not supported")
}
return nil
}
func addNullableColumn(ctx context.Context, conn *sql.DB, o *OpAddColumn, t *schema.Table) error {
_, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s %s",
func addColumn(ctx context.Context, conn *sql.DB, o OpAddColumn, t *schema.Table) error {
o.Column.Name = TemporaryName(o.Column.Name)
_, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s",
pq.QuoteIdentifier(t.Name),
pq.QuoteIdentifier(TemporaryName(o.Column.Name)),
o.Column.Type,
ColumnToSQL(o.Column),
))
return err
}

View File

@ -50,30 +50,6 @@ func (o *OpCreateTable) Start(ctx context.Context, conn *sql.DB, s *schema.Schem
return nil
}
func columnsToSQL(cols []Column) string {
var sql string
for i, col := range cols {
if i > 0 {
sql += ", "
}
sql += fmt.Sprintf("%s %s", pq.QuoteIdentifier(col.Name), col.Type)
if col.PrimaryKey {
sql += " PRIMARY KEY"
}
if col.Unique {
sql += " UNIQUE"
}
if !col.Nullable {
sql += " NOT NULL"
}
if col.Default != nil {
sql += fmt.Sprintf(" DEFAULT %s", *col.Default)
}
}
return sql
}
func (o *OpCreateTable) Complete(ctx context.Context, conn *sql.DB) error {
tempName := TemporaryName(o.Name)
_, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE IF EXISTS %s RENAME TO %s",
@ -97,3 +73,32 @@ func (o *OpCreateTable) Validate(ctx context.Context, s *schema.Schema) error {
}
return nil
}
func columnsToSQL(cols []Column) string {
var sql string
for i, col := range cols {
if i > 0 {
sql += ", "
}
sql += ColumnToSQL(col)
}
return sql
}
func ColumnToSQL(col Column) string {
sql := fmt.Sprintf("%s %s", pq.QuoteIdentifier(col.Name), col.Type)
if col.PrimaryKey {
sql += " PRIMARY KEY"
}
if col.Unique {
sql += " UNIQUE"
}
if !col.Nullable {
sql += " NOT NULL"
}
if col.Default != nil {
sql += fmt.Sprintf(" DEFAULT %s", pq.QuoteLiteral(*col.Default))
}
return sql
}