mirror of
https://github.com/xataio/pgroll.git
synced 2024-10-05 17:47:59 +03:00
Move constraint name and expression into a new CheckConstraint
struct (#107)
Move the `ConstraintName` and `Check` `string` fields on an `alter_column` operation into a new `CheckConstraint` struct and make validation a method on that new struct. This is to facilitate being able to create tables and columns with `CHECK` constraints in later PRs (#108, #109).
This commit is contained in:
parent
2a6a0e8c33
commit
0c7ecf2887
@ -5,8 +5,10 @@
|
||||
"alter_column": {
|
||||
"table": "posts",
|
||||
"column": "title",
|
||||
"constraint_name": "title_length",
|
||||
"check": "length(title) > 3",
|
||||
"check": {
|
||||
"name": "title_length",
|
||||
"constraint": "length(title) > 3"
|
||||
},
|
||||
"up": "(SELECT CASE WHEN length(title) <= 3 THEN LPAD(title, 4, '-') ELSE title END)",
|
||||
"down": "title"
|
||||
}
|
||||
|
18
pkg/migrations/check.go
Normal file
18
pkg/migrations/check.go
Normal file
@ -0,0 +1,18 @@
|
||||
package migrations
|
||||
|
||||
type CheckConstraint struct {
|
||||
Name string `json:"name"`
|
||||
Constraint string `json:"constraint"`
|
||||
}
|
||||
|
||||
func (c *CheckConstraint) Validate() error {
|
||||
if c.Name == "" {
|
||||
return FieldRequiredError{Name: "name"}
|
||||
}
|
||||
|
||||
if c.Constraint == "" {
|
||||
return FieldRequiredError{Name: "constraint"}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
@ -100,6 +100,23 @@ func (e ColumnReferenceError) Error() string {
|
||||
e.Err.Error())
|
||||
}
|
||||
|
||||
type CheckConstraintError struct {
|
||||
Table string
|
||||
Column string
|
||||
Err error
|
||||
}
|
||||
|
||||
func (e CheckConstraintError) Unwrap() error {
|
||||
return e.Err
|
||||
}
|
||||
|
||||
func (e CheckConstraintError) Error() string {
|
||||
return fmt.Sprintf("check constraint on column %q in table %q is invalid: %s",
|
||||
e.Table,
|
||||
e.Column,
|
||||
e.Err.Error())
|
||||
}
|
||||
|
||||
type NoUpSQLAllowedError struct{}
|
||||
|
||||
func (e NoUpSQLAllowedError) Error() string {
|
||||
|
@ -9,16 +9,15 @@ import (
|
||||
)
|
||||
|
||||
type OpAlterColumn struct {
|
||||
Table string `json:"table"`
|
||||
Column string `json:"column"`
|
||||
Name string `json:"name"`
|
||||
Type string `json:"type"`
|
||||
ConstraintName string `json:"constraint_name"`
|
||||
Check string `json:"check"`
|
||||
References *ForeignKeyReference `json:"references"`
|
||||
NotNull *bool `json:"not_null"`
|
||||
Up string `json:"up"`
|
||||
Down string `json:"down"`
|
||||
Table string `json:"table"`
|
||||
Column string `json:"column"`
|
||||
Name string `json:"name"`
|
||||
Type string `json:"type"`
|
||||
Check *CheckConstraint `json:"check"`
|
||||
References *ForeignKeyReference `json:"references"`
|
||||
NotNull *bool `json:"not_null"`
|
||||
Up string `json:"up"`
|
||||
Down string `json:"down"`
|
||||
}
|
||||
|
||||
var _ Operation = (*OpAlterColumn)(nil)
|
||||
@ -95,14 +94,13 @@ func (o *OpAlterColumn) innerOperation() Operation {
|
||||
Down: o.Down,
|
||||
}
|
||||
|
||||
case o.Check != "":
|
||||
case o.Check != nil:
|
||||
return &OpSetCheckConstraint{
|
||||
Table: o.Table,
|
||||
Column: o.Column,
|
||||
ConstraintName: o.ConstraintName,
|
||||
Check: o.Check,
|
||||
Up: o.Up,
|
||||
Down: o.Down,
|
||||
Table: o.Table,
|
||||
Column: o.Column,
|
||||
Check: *o.Check,
|
||||
Up: o.Up,
|
||||
Down: o.Down,
|
||||
}
|
||||
|
||||
case o.References != nil:
|
||||
@ -136,7 +134,7 @@ func (o *OpAlterColumn) numChanges() int {
|
||||
if o.Type != "" {
|
||||
fieldsSet++
|
||||
}
|
||||
if o.Check != "" {
|
||||
if o.Check != nil {
|
||||
fieldsSet++
|
||||
}
|
||||
if o.References != nil {
|
||||
|
@ -38,12 +38,14 @@ func TestDropConstraint(t *testing.T) {
|
||||
Name: "02_add_check_constraint",
|
||||
Operations: migrations.Operations{
|
||||
&migrations.OpAlterColumn{
|
||||
Table: "posts",
|
||||
Column: "title",
|
||||
ConstraintName: "check_title_length",
|
||||
Check: "length(title) > 3",
|
||||
Up: "(SELECT CASE WHEN length(title) <= 3 THEN LPAD(title, 4, '-') ELSE title END)",
|
||||
Down: "title",
|
||||
Table: "posts",
|
||||
Column: "title",
|
||||
Check: &migrations.CheckConstraint{
|
||||
Name: "check_title_length",
|
||||
Constraint: "length(title) > 3",
|
||||
},
|
||||
Up: "(SELECT CASE WHEN length(title) <= 3 THEN LPAD(title, 4, '-') ELSE title END)",
|
||||
Down: "title",
|
||||
},
|
||||
},
|
||||
},
|
||||
@ -165,12 +167,14 @@ func TestDropConstraint(t *testing.T) {
|
||||
Name: "02_add_check_constraint",
|
||||
Operations: migrations.Operations{
|
||||
&migrations.OpAlterColumn{
|
||||
Table: "posts",
|
||||
Column: "title",
|
||||
ConstraintName: "check_title_length",
|
||||
Check: "length(title) > 3",
|
||||
Up: "(SELECT CASE WHEN length(title) <= 3 THEN LPAD(title, 4, '-') ELSE title END)",
|
||||
Down: "title",
|
||||
Table: "posts",
|
||||
Column: "title",
|
||||
Check: &migrations.CheckConstraint{
|
||||
Name: "check_title_length",
|
||||
Constraint: "length(title) > 3",
|
||||
},
|
||||
Up: "(SELECT CASE WHEN length(title) <= 3 THEN LPAD(title, 4, '-') ELSE title END)",
|
||||
Down: "title",
|
||||
},
|
||||
},
|
||||
},
|
||||
@ -397,12 +401,14 @@ func TestDropConstraintValidation(t *testing.T) {
|
||||
Name: "02_add_check_constraint",
|
||||
Operations: migrations.Operations{
|
||||
&migrations.OpAlterColumn{
|
||||
Table: "posts",
|
||||
Column: "title",
|
||||
ConstraintName: "check_title_length",
|
||||
Check: "length(title) > 3",
|
||||
Up: "(SELECT CASE WHEN length(title) <= 3 THEN LPAD(title, 4, '-') ELSE title END)",
|
||||
Down: "title",
|
||||
Table: "posts",
|
||||
Column: "title",
|
||||
Check: &migrations.CheckConstraint{
|
||||
Name: "check_title_length",
|
||||
Constraint: "length(title) > 3",
|
||||
},
|
||||
Up: "(SELECT CASE WHEN length(title) <= 3 THEN LPAD(title, 4, '-') ELSE title END)",
|
||||
Down: "title",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
@ -11,12 +11,11 @@ import (
|
||||
)
|
||||
|
||||
type OpSetCheckConstraint struct {
|
||||
Table string `json:"table"`
|
||||
Column string `json:"column"`
|
||||
ConstraintName string `json:"constraint_name"`
|
||||
Check string `json:"check"`
|
||||
Up string `json:"up"`
|
||||
Down string `json:"down"`
|
||||
Table string `json:"table"`
|
||||
Column string `json:"column"`
|
||||
Check CheckConstraint `json:"check"`
|
||||
Up string `json:"up"`
|
||||
Down string `json:"down"`
|
||||
}
|
||||
|
||||
var _ Operation = (*OpSetCheckConstraint)(nil)
|
||||
@ -83,7 +82,7 @@ func (o *OpSetCheckConstraint) Complete(ctx context.Context, conn *sql.DB) error
|
||||
// Validate the check constraint
|
||||
_, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE IF EXISTS %s VALIDATE CONSTRAINT %s",
|
||||
pq.QuoteIdentifier(o.Table),
|
||||
pq.QuoteIdentifier(o.ConstraintName)))
|
||||
pq.QuoteIdentifier(o.Check.Name)))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -149,12 +148,12 @@ func (o *OpSetCheckConstraint) Rollback(ctx context.Context, conn *sql.DB) error
|
||||
}
|
||||
|
||||
func (o *OpSetCheckConstraint) Validate(ctx context.Context, s *schema.Schema) error {
|
||||
if o.Check == "" {
|
||||
return FieldRequiredError{Name: "check"}
|
||||
}
|
||||
|
||||
if o.ConstraintName == "" {
|
||||
return FieldRequiredError{Name: "constraint_name"}
|
||||
if err := o.Check.Validate(); err != nil {
|
||||
return CheckConstraintError{
|
||||
Table: o.Table,
|
||||
Column: o.Column,
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
if o.Up == "" {
|
||||
@ -171,8 +170,8 @@ func (o *OpSetCheckConstraint) Validate(ctx context.Context, s *schema.Schema) e
|
||||
func (o *OpSetCheckConstraint) addCheckConstraint(ctx context.Context, conn *sql.DB) error {
|
||||
_, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %s ADD CONSTRAINT %s CHECK (%s) NOT VALID",
|
||||
pq.QuoteIdentifier(o.Table),
|
||||
pq.QuoteIdentifier(o.ConstraintName),
|
||||
rewriteCheckExpression(o.Check, o.Column, TemporaryName(o.Column)),
|
||||
pq.QuoteIdentifier(o.Check.Name),
|
||||
rewriteCheckExpression(o.Check.Constraint, o.Column, TemporaryName(o.Column)),
|
||||
))
|
||||
|
||||
return err
|
||||
|
@ -37,12 +37,14 @@ func TestSetCheckConstraint(t *testing.T) {
|
||||
Name: "02_add_check_constraint",
|
||||
Operations: migrations.Operations{
|
||||
&migrations.OpAlterColumn{
|
||||
Table: "posts",
|
||||
Column: "title",
|
||||
ConstraintName: "check_title_length",
|
||||
Check: "length(title) > 3",
|
||||
Up: "(SELECT CASE WHEN length(title) <= 3 THEN LPAD(title, 4, '-') ELSE title END)",
|
||||
Down: "title",
|
||||
Table: "posts",
|
||||
Column: "title",
|
||||
Check: &migrations.CheckConstraint{
|
||||
Name: "check_title_length",
|
||||
Constraint: "length(title) > 3",
|
||||
},
|
||||
Up: "(SELECT CASE WHEN length(title) <= 3 THEN LPAD(title, 4, '-') ELSE title END)",
|
||||
Down: "title",
|
||||
},
|
||||
},
|
||||
},
|
||||
@ -168,14 +170,16 @@ func TestSetCheckConstraintValidation(t *testing.T) {
|
||||
&migrations.OpAlterColumn{
|
||||
Table: "posts",
|
||||
Column: "title",
|
||||
Check: "length(title) > 3",
|
||||
Up: "(SELECT CASE WHEN length(title) <= 3 THEN LPAD(title, 4, '-') ELSE title END)",
|
||||
Down: "title",
|
||||
Check: &migrations.CheckConstraint{
|
||||
Constraint: "length(title) > 3",
|
||||
},
|
||||
Up: "(SELECT CASE WHEN length(title) <= 3 THEN LPAD(title, 4, '-') ELSE title END)",
|
||||
Down: "title",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
wantStartErr: migrations.FieldRequiredError{Name: "constraint_name"},
|
||||
wantStartErr: migrations.FieldRequiredError{Name: "name"},
|
||||
},
|
||||
{
|
||||
name: "up SQL is mandatory",
|
||||
@ -185,11 +189,13 @@ func TestSetCheckConstraintValidation(t *testing.T) {
|
||||
Name: "02_add_check_constraint",
|
||||
Operations: migrations.Operations{
|
||||
&migrations.OpAlterColumn{
|
||||
Table: "posts",
|
||||
Column: "title",
|
||||
ConstraintName: "check_title_length",
|
||||
Check: "length(title) > 3",
|
||||
Down: "title",
|
||||
Table: "posts",
|
||||
Column: "title",
|
||||
Check: &migrations.CheckConstraint{
|
||||
Name: "check_title_length",
|
||||
Constraint: "length(title) > 3",
|
||||
},
|
||||
Down: "title",
|
||||
},
|
||||
},
|
||||
},
|
||||
@ -204,11 +210,13 @@ func TestSetCheckConstraintValidation(t *testing.T) {
|
||||
Name: "02_add_check_constraint",
|
||||
Operations: migrations.Operations{
|
||||
&migrations.OpAlterColumn{
|
||||
Table: "posts",
|
||||
Column: "title",
|
||||
ConstraintName: "check_title_length",
|
||||
Check: "length(title) > 3",
|
||||
Up: "(SELECT CASE WHEN length(title) <= 3 THEN LPAD(title, 4, '-') ELSE title END)",
|
||||
Table: "posts",
|
||||
Column: "title",
|
||||
Check: &migrations.CheckConstraint{
|
||||
Name: "check_title_length",
|
||||
Constraint: "length(title) > 3",
|
||||
},
|
||||
Up: "(SELECT CASE WHEN length(title) <= 3 THEN LPAD(title, 4, '-') ELSE title END)",
|
||||
},
|
||||
},
|
||||
},
|
||||
|
Loading…
Reference in New Issue
Block a user