Make naming CHECK constraints mandatory (#99)

Make it required to supply a name for the `CHECK` constraint when adding
one with the `set_check_constraint` operation.

It should be possible to drop constraints with a later migration (not
yet implemented), so requiring a name and not relying on automatic
generation of constraint names will make this easier.

The same thing was done for indexes in
https://github.com/xataio/pg-roll/pull/59
This commit is contained in:
Andrew Farries 2023-09-14 11:06:13 +01:00 committed by GitHub
parent 85f9ae0df3
commit c1b8c65dd5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 61 additions and 44 deletions

View File

@ -5,6 +5,7 @@
"alter_column": {
"table": "posts",
"column": "title",
"constraint_name": "title_length",
"check": "length(title) > 3",
"up": "(SELECT CASE WHEN length(title) <= 3 THEN LPAD(title, 4, '-') ELSE title END)",
"down": "title"

View File

@ -9,15 +9,16 @@ import (
)
type OpAlterColumn struct {
Table string `json:"table"`
Column string `json:"column"`
Name string `json:"name"`
Type string `json:"type"`
Check string `json:"check"`
References *ColumnReference `json:"references"`
NotNull *bool `json:"not_null"`
Up string `json:"up"`
Down string `json:"down"`
Table string `json:"table"`
Column string `json:"column"`
Name string `json:"name"`
Type string `json:"type"`
ConstraintName string `json:"constraint_name"`
Check string `json:"check"`
References *ColumnReference `json:"references"`
NotNull *bool `json:"not_null"`
Up string `json:"up"`
Down string `json:"down"`
}
var _ Operation = (*OpAlterColumn)(nil)
@ -96,11 +97,12 @@ func (o *OpAlterColumn) innerOperation() Operation {
case o.Check != "":
return &OpSetCheckConstraint{
Table: o.Table,
Column: o.Column,
Check: o.Check,
Up: o.Up,
Down: o.Down,
Table: o.Table,
Column: o.Column,
ConstraintName: o.ConstraintName,
Check: o.Check,
Up: o.Up,
Down: o.Down,
}
case o.References != nil:

View File

@ -11,11 +11,12 @@ import (
)
type OpSetCheckConstraint struct {
Table string `json:"table"`
Column string `json:"column"`
Check string `json:"check"`
Up string `json:"up"`
Down string `json:"down"`
Table string `json:"table"`
Column string `json:"column"`
ConstraintName string `json:"constraint_name"`
Check string `json:"check"`
Up string `json:"up"`
Down string `json:"down"`
}
var _ Operation = (*OpSetCheckConstraint)(nil)
@ -79,12 +80,10 @@ func (o *OpSetCheckConstraint) Start(ctx context.Context, conn *sql.DB, stateSch
}
func (o *OpSetCheckConstraint) Complete(ctx context.Context, conn *sql.DB) error {
tempName := TemporaryName(o.Column)
// Validate the check constraint
_, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE IF EXISTS %s VALIDATE CONSTRAINT %s",
pq.QuoteIdentifier(o.Table),
pq.QuoteIdentifier(CheckConstraintName(o.Table, tempName))))
pq.QuoteIdentifier(o.ConstraintName)))
if err != nil {
return err
}
@ -120,13 +119,6 @@ func (o *OpSetCheckConstraint) Complete(ctx context.Context, conn *sql.DB) error
return err
}
// Rename the check constraint to use the final (non-temporary) column name.
_, err = conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE IF EXISTS %s RENAME CONSTRAINT %s TO %s",
pq.QuoteIdentifier(o.Table),
pq.QuoteIdentifier(CheckConstraintName(o.Table, tempName)),
pq.QuoteIdentifier(CheckConstraintName(o.Table, o.Column)),
))
return err
}
@ -161,6 +153,10 @@ func (o *OpSetCheckConstraint) Validate(ctx context.Context, s *schema.Schema) e
return FieldRequiredError{Name: "check"}
}
if o.ConstraintName == "" {
return FieldRequiredError{Name: "constraint_name"}
}
if o.Up == "" {
return FieldRequiredError{Name: "up"}
}
@ -172,14 +168,10 @@ func (o *OpSetCheckConstraint) Validate(ctx context.Context, s *schema.Schema) e
return nil
}
func CheckConstraintName(tableName, columnName string) string {
return fmt.Sprintf("_pgroll_check_%s_%s", tableName, columnName)
}
func (o *OpSetCheckConstraint) addCheckConstraint(ctx context.Context, conn *sql.DB) error {
_, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %s ADD CONSTRAINT %s CHECK (%s) NOT VALID",
pq.QuoteIdentifier(o.Table),
pq.QuoteIdentifier(CheckConstraintName(o.Table, TemporaryName(o.Column))),
pq.QuoteIdentifier(o.ConstraintName),
rewriteCheckExpression(o.Check, o.Column, TemporaryName(o.Column)),
))

View File

@ -37,11 +37,12 @@ func TestSetCheckConstraint(t *testing.T) {
Name: "02_add_check_constraint",
Operations: migrations.Operations{
&migrations.OpAlterColumn{
Table: "posts",
Column: "title",
Check: "length(title) > 3",
Up: "(SELECT CASE WHEN length(title) <= 3 THEN LPAD(title, 4, '-') ELSE title END)",
Down: "title",
Table: "posts",
Column: "title",
ConstraintName: "check_title_length",
Check: "length(title) > 3",
Up: "(SELECT CASE WHEN length(title) <= 3 THEN LPAD(title, 4, '-') ELSE title END)",
Down: "title",
},
},
},
@ -158,7 +159,7 @@ func TestSetCheckConstraintValidation(t *testing.T) {
ExecuteTests(t, TestCases{
{
name: "up SQL is mandatory",
name: "name of the check constraint is mandatory",
migrations: []migrations.Migration{
createTableMigration,
{
@ -168,11 +169,31 @@ func TestSetCheckConstraintValidation(t *testing.T) {
Table: "posts",
Column: "title",
Check: "length(title) > 3",
Up: "(SELECT CASE WHEN length(title) <= 3 THEN LPAD(title, 4, '-') ELSE title END)",
Down: "title",
},
},
},
},
wantStartErr: migrations.FieldRequiredError{Name: "constraint_name"},
},
{
name: "up SQL is mandatory",
migrations: []migrations.Migration{
createTableMigration,
{
Name: "02_add_check_constraint",
Operations: migrations.Operations{
&migrations.OpAlterColumn{
Table: "posts",
Column: "title",
ConstraintName: "check_title_length",
Check: "length(title) > 3",
Down: "title",
},
},
},
},
wantStartErr: migrations.FieldRequiredError{Name: "up"},
},
{
@ -183,10 +204,11 @@ func TestSetCheckConstraintValidation(t *testing.T) {
Name: "02_add_check_constraint",
Operations: migrations.Operations{
&migrations.OpAlterColumn{
Table: "posts",
Column: "title",
Check: "length(title) > 3",
Up: "(SELECT CASE WHEN length(title) <= 3 THEN LPAD(title, 4, '-') ELSE title END)",
Table: "posts",
Column: "title",
ConstraintName: "check_title_length",
Check: "length(title) > 3",
Up: "(SELECT CASE WHEN length(title) <= 3 THEN LPAD(title, 4, '-') ELSE title END)",
},
},
},