mirror of
https://github.com/xataio/pgroll.git
synced 2024-09-11 13:55:28 +03:00
Support adding columns with UNIQUE
, NOT NULL
and DEFAULT
constraints (#30)
Allow the **add column** operation to add columns with `NOT NULL`, `UNIQUE` and `DEFAULT` constraints by re-using the SQL generation code that adds columns to tables.
This commit is contained in:
parent
87f0be5ea5
commit
dce42da85a
@ -10,6 +10,17 @@
|
||||
"nullable": true
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"add_column": {
|
||||
"table": "products",
|
||||
"column": {
|
||||
"name": "stock",
|
||||
"type": "int",
|
||||
"nullable": false,
|
||||
"default": "100"
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
|
@ -20,12 +20,9 @@ var _ Operation = (*OpAddColumn)(nil)
|
||||
|
||||
func (o *OpAddColumn) Start(ctx context.Context, conn *sql.DB, s *schema.Schema) error {
|
||||
table := s.GetTable(o.Table)
|
||||
if o.Column.Nullable {
|
||||
if err := addNullableColumn(ctx, conn, o, table); err != nil {
|
||||
return fmt.Errorf("failed to start add column operation: %w", err)
|
||||
}
|
||||
} else {
|
||||
return errors.New("addition of non-nullable columns not implemented")
|
||||
|
||||
if err := addColumn(ctx, conn, *o, table); err != nil {
|
||||
return fmt.Errorf("failed to start add column operation: %w", err)
|
||||
}
|
||||
|
||||
table.AddColumn(o.Column.Name, schema.Column{
|
||||
@ -65,14 +62,23 @@ func (o *OpAddColumn) Validate(ctx context.Context, s *schema.Schema) error {
|
||||
return ColumnAlreadyExistsError{Name: o.Column.Name, Table: o.Table}
|
||||
}
|
||||
|
||||
if !o.Column.Nullable && o.Column.Default == nil {
|
||||
return errors.New("adding non-nullable columns without a default is not supported")
|
||||
}
|
||||
|
||||
if o.Column.PrimaryKey {
|
||||
return errors.New("adding primary key columns is not supported")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func addNullableColumn(ctx context.Context, conn *sql.DB, o *OpAddColumn, t *schema.Table) error {
|
||||
_, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s %s",
|
||||
func addColumn(ctx context.Context, conn *sql.DB, o OpAddColumn, t *schema.Table) error {
|
||||
o.Column.Name = TemporaryName(o.Column.Name)
|
||||
|
||||
_, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s",
|
||||
pq.QuoteIdentifier(t.Name),
|
||||
pq.QuoteIdentifier(TemporaryName(o.Column.Name)),
|
||||
o.Column.Type,
|
||||
ColumnToSQL(o.Column),
|
||||
))
|
||||
return err
|
||||
}
|
||||
|
@ -50,30 +50,6 @@ func (o *OpCreateTable) Start(ctx context.Context, conn *sql.DB, s *schema.Schem
|
||||
return nil
|
||||
}
|
||||
|
||||
func columnsToSQL(cols []Column) string {
|
||||
var sql string
|
||||
for i, col := range cols {
|
||||
if i > 0 {
|
||||
sql += ", "
|
||||
}
|
||||
sql += fmt.Sprintf("%s %s", pq.QuoteIdentifier(col.Name), col.Type)
|
||||
|
||||
if col.PrimaryKey {
|
||||
sql += " PRIMARY KEY"
|
||||
}
|
||||
if col.Unique {
|
||||
sql += " UNIQUE"
|
||||
}
|
||||
if !col.Nullable {
|
||||
sql += " NOT NULL"
|
||||
}
|
||||
if col.Default != nil {
|
||||
sql += fmt.Sprintf(" DEFAULT %s", *col.Default)
|
||||
}
|
||||
}
|
||||
return sql
|
||||
}
|
||||
|
||||
func (o *OpCreateTable) Complete(ctx context.Context, conn *sql.DB) error {
|
||||
tempName := TemporaryName(o.Name)
|
||||
_, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE IF EXISTS %s RENAME TO %s",
|
||||
@ -97,3 +73,32 @@ func (o *OpCreateTable) Validate(ctx context.Context, s *schema.Schema) error {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func columnsToSQL(cols []Column) string {
|
||||
var sql string
|
||||
for i, col := range cols {
|
||||
if i > 0 {
|
||||
sql += ", "
|
||||
}
|
||||
sql += ColumnToSQL(col)
|
||||
}
|
||||
return sql
|
||||
}
|
||||
|
||||
func ColumnToSQL(col Column) string {
|
||||
sql := fmt.Sprintf("%s %s", pq.QuoteIdentifier(col.Name), col.Type)
|
||||
|
||||
if col.PrimaryKey {
|
||||
sql += " PRIMARY KEY"
|
||||
}
|
||||
if col.Unique {
|
||||
sql += " UNIQUE"
|
||||
}
|
||||
if !col.Nullable {
|
||||
sql += " NOT NULL"
|
||||
}
|
||||
if col.Default != nil {
|
||||
sql += fmt.Sprintf(" DEFAULT %s", pq.QuoteLiteral(*col.Default))
|
||||
}
|
||||
return sql
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user