Move constraint name and expression into a new CheckConstraint struct (#107)

Move the `ConstraintName` and `Check` `string` fields on an
`alter_column` operation into a new `CheckConstraint` struct and make
validation a method on that new struct.

This is to facilitate being able to create tables and columns with
`CHECK` constraints in later PRs (#108, #109).
This commit is contained in:
Andrew Farries 2023-09-19 10:49:40 +01:00 committed by GitHub
parent 2a6a0e8c33
commit 0c7ecf2887
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 121 additions and 73 deletions

View File

@ -5,8 +5,10 @@
"alter_column": {
"table": "posts",
"column": "title",
"constraint_name": "title_length",
"check": "length(title) > 3",
"check": {
"name": "title_length",
"constraint": "length(title) > 3"
},
"up": "(SELECT CASE WHEN length(title) <= 3 THEN LPAD(title, 4, '-') ELSE title END)",
"down": "title"
}

18
pkg/migrations/check.go Normal file
View File

@ -0,0 +1,18 @@
package migrations
type CheckConstraint struct {
Name string `json:"name"`
Constraint string `json:"constraint"`
}
func (c *CheckConstraint) Validate() error {
if c.Name == "" {
return FieldRequiredError{Name: "name"}
}
if c.Constraint == "" {
return FieldRequiredError{Name: "constraint"}
}
return nil
}

View File

@ -100,6 +100,23 @@ func (e ColumnReferenceError) Error() string {
e.Err.Error())
}
type CheckConstraintError struct {
Table string
Column string
Err error
}
func (e CheckConstraintError) Unwrap() error {
return e.Err
}
func (e CheckConstraintError) Error() string {
return fmt.Sprintf("check constraint on column %q in table %q is invalid: %s",
e.Table,
e.Column,
e.Err.Error())
}
type NoUpSQLAllowedError struct{}
func (e NoUpSQLAllowedError) Error() string {

View File

@ -9,16 +9,15 @@ import (
)
type OpAlterColumn struct {
Table string `json:"table"`
Column string `json:"column"`
Name string `json:"name"`
Type string `json:"type"`
ConstraintName string `json:"constraint_name"`
Check string `json:"check"`
References *ForeignKeyReference `json:"references"`
NotNull *bool `json:"not_null"`
Up string `json:"up"`
Down string `json:"down"`
Table string `json:"table"`
Column string `json:"column"`
Name string `json:"name"`
Type string `json:"type"`
Check *CheckConstraint `json:"check"`
References *ForeignKeyReference `json:"references"`
NotNull *bool `json:"not_null"`
Up string `json:"up"`
Down string `json:"down"`
}
var _ Operation = (*OpAlterColumn)(nil)
@ -95,14 +94,13 @@ func (o *OpAlterColumn) innerOperation() Operation {
Down: o.Down,
}
case o.Check != "":
case o.Check != nil:
return &OpSetCheckConstraint{
Table: o.Table,
Column: o.Column,
ConstraintName: o.ConstraintName,
Check: o.Check,
Up: o.Up,
Down: o.Down,
Table: o.Table,
Column: o.Column,
Check: *o.Check,
Up: o.Up,
Down: o.Down,
}
case o.References != nil:
@ -136,7 +134,7 @@ func (o *OpAlterColumn) numChanges() int {
if o.Type != "" {
fieldsSet++
}
if o.Check != "" {
if o.Check != nil {
fieldsSet++
}
if o.References != nil {

View File

@ -38,12 +38,14 @@ func TestDropConstraint(t *testing.T) {
Name: "02_add_check_constraint",
Operations: migrations.Operations{
&migrations.OpAlterColumn{
Table: "posts",
Column: "title",
ConstraintName: "check_title_length",
Check: "length(title) > 3",
Up: "(SELECT CASE WHEN length(title) <= 3 THEN LPAD(title, 4, '-') ELSE title END)",
Down: "title",
Table: "posts",
Column: "title",
Check: &migrations.CheckConstraint{
Name: "check_title_length",
Constraint: "length(title) > 3",
},
Up: "(SELECT CASE WHEN length(title) <= 3 THEN LPAD(title, 4, '-') ELSE title END)",
Down: "title",
},
},
},
@ -165,12 +167,14 @@ func TestDropConstraint(t *testing.T) {
Name: "02_add_check_constraint",
Operations: migrations.Operations{
&migrations.OpAlterColumn{
Table: "posts",
Column: "title",
ConstraintName: "check_title_length",
Check: "length(title) > 3",
Up: "(SELECT CASE WHEN length(title) <= 3 THEN LPAD(title, 4, '-') ELSE title END)",
Down: "title",
Table: "posts",
Column: "title",
Check: &migrations.CheckConstraint{
Name: "check_title_length",
Constraint: "length(title) > 3",
},
Up: "(SELECT CASE WHEN length(title) <= 3 THEN LPAD(title, 4, '-') ELSE title END)",
Down: "title",
},
},
},
@ -397,12 +401,14 @@ func TestDropConstraintValidation(t *testing.T) {
Name: "02_add_check_constraint",
Operations: migrations.Operations{
&migrations.OpAlterColumn{
Table: "posts",
Column: "title",
ConstraintName: "check_title_length",
Check: "length(title) > 3",
Up: "(SELECT CASE WHEN length(title) <= 3 THEN LPAD(title, 4, '-') ELSE title END)",
Down: "title",
Table: "posts",
Column: "title",
Check: &migrations.CheckConstraint{
Name: "check_title_length",
Constraint: "length(title) > 3",
},
Up: "(SELECT CASE WHEN length(title) <= 3 THEN LPAD(title, 4, '-') ELSE title END)",
Down: "title",
},
},
}

View File

@ -11,12 +11,11 @@ import (
)
type OpSetCheckConstraint struct {
Table string `json:"table"`
Column string `json:"column"`
ConstraintName string `json:"constraint_name"`
Check string `json:"check"`
Up string `json:"up"`
Down string `json:"down"`
Table string `json:"table"`
Column string `json:"column"`
Check CheckConstraint `json:"check"`
Up string `json:"up"`
Down string `json:"down"`
}
var _ Operation = (*OpSetCheckConstraint)(nil)
@ -83,7 +82,7 @@ func (o *OpSetCheckConstraint) Complete(ctx context.Context, conn *sql.DB) error
// Validate the check constraint
_, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE IF EXISTS %s VALIDATE CONSTRAINT %s",
pq.QuoteIdentifier(o.Table),
pq.QuoteIdentifier(o.ConstraintName)))
pq.QuoteIdentifier(o.Check.Name)))
if err != nil {
return err
}
@ -149,12 +148,12 @@ func (o *OpSetCheckConstraint) Rollback(ctx context.Context, conn *sql.DB) error
}
func (o *OpSetCheckConstraint) Validate(ctx context.Context, s *schema.Schema) error {
if o.Check == "" {
return FieldRequiredError{Name: "check"}
}
if o.ConstraintName == "" {
return FieldRequiredError{Name: "constraint_name"}
if err := o.Check.Validate(); err != nil {
return CheckConstraintError{
Table: o.Table,
Column: o.Column,
Err: err,
}
}
if o.Up == "" {
@ -171,8 +170,8 @@ func (o *OpSetCheckConstraint) Validate(ctx context.Context, s *schema.Schema) e
func (o *OpSetCheckConstraint) addCheckConstraint(ctx context.Context, conn *sql.DB) error {
_, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %s ADD CONSTRAINT %s CHECK (%s) NOT VALID",
pq.QuoteIdentifier(o.Table),
pq.QuoteIdentifier(o.ConstraintName),
rewriteCheckExpression(o.Check, o.Column, TemporaryName(o.Column)),
pq.QuoteIdentifier(o.Check.Name),
rewriteCheckExpression(o.Check.Constraint, o.Column, TemporaryName(o.Column)),
))
return err

View File

@ -37,12 +37,14 @@ func TestSetCheckConstraint(t *testing.T) {
Name: "02_add_check_constraint",
Operations: migrations.Operations{
&migrations.OpAlterColumn{
Table: "posts",
Column: "title",
ConstraintName: "check_title_length",
Check: "length(title) > 3",
Up: "(SELECT CASE WHEN length(title) <= 3 THEN LPAD(title, 4, '-') ELSE title END)",
Down: "title",
Table: "posts",
Column: "title",
Check: &migrations.CheckConstraint{
Name: "check_title_length",
Constraint: "length(title) > 3",
},
Up: "(SELECT CASE WHEN length(title) <= 3 THEN LPAD(title, 4, '-') ELSE title END)",
Down: "title",
},
},
},
@ -168,14 +170,16 @@ func TestSetCheckConstraintValidation(t *testing.T) {
&migrations.OpAlterColumn{
Table: "posts",
Column: "title",
Check: "length(title) > 3",
Up: "(SELECT CASE WHEN length(title) <= 3 THEN LPAD(title, 4, '-') ELSE title END)",
Down: "title",
Check: &migrations.CheckConstraint{
Constraint: "length(title) > 3",
},
Up: "(SELECT CASE WHEN length(title) <= 3 THEN LPAD(title, 4, '-') ELSE title END)",
Down: "title",
},
},
},
},
wantStartErr: migrations.FieldRequiredError{Name: "constraint_name"},
wantStartErr: migrations.FieldRequiredError{Name: "name"},
},
{
name: "up SQL is mandatory",
@ -185,11 +189,13 @@ func TestSetCheckConstraintValidation(t *testing.T) {
Name: "02_add_check_constraint",
Operations: migrations.Operations{
&migrations.OpAlterColumn{
Table: "posts",
Column: "title",
ConstraintName: "check_title_length",
Check: "length(title) > 3",
Down: "title",
Table: "posts",
Column: "title",
Check: &migrations.CheckConstraint{
Name: "check_title_length",
Constraint: "length(title) > 3",
},
Down: "title",
},
},
},
@ -204,11 +210,13 @@ func TestSetCheckConstraintValidation(t *testing.T) {
Name: "02_add_check_constraint",
Operations: migrations.Operations{
&migrations.OpAlterColumn{
Table: "posts",
Column: "title",
ConstraintName: "check_title_length",
Check: "length(title) > 3",
Up: "(SELECT CASE WHEN length(title) <= 3 THEN LPAD(title, 4, '-') ELSE title END)",
Table: "posts",
Column: "title",
Check: &migrations.CheckConstraint{
Name: "check_title_length",
Constraint: "length(title) > 3",
},
Up: "(SELECT CASE WHEN length(title) <= 3 THEN LPAD(title, 4, '-') ELSE title END)",
},
},
},