From 4e5cc5c4ff96dd481991583bf9278b335f025d1f Mon Sep 17 00:00:00 2001
From: fabianlindfors
Date: Wed, 22 Nov 2023 00:04:45 +0100
Subject: [PATCH] Add support for complex up function in add_column

---
 README.md                      |  20 ++++
 src/migrations/add_column.rs   | 164 +++++++++++++++++++++++----------
 src/migrations/alter_column.rs |   2 +-
 src/migrations/common.rs       |  23 +++--
 tests/add_column.rs            |  99 ++++++++++++++++++++
 5 files changed, 248 insertions(+), 60 deletions(-)

diff --git a/README.md b/README.md
index fc4f2f6..eb46467 100644
--- a/README.md
+++ b/README.md
@@ -351,6 +351,26 @@ up = "data['path']['to']['value'] #>> '{}'"
 type = "TEXT"
 ```
 
+_Example: duplicate `email` column from `users` to `profiles` table_
+
+```toml
+# `profiles` has `user_id` column which maps to `users.id`
+[[actions]]
+type = "add_column"
+table = "profiles"
+
+    [actions.column]
+    name = "email"
+    type = "TEXT"
+    nullable = false
+
+    # When `users` is updated in the old schema, we write the email value to `profiles`
+    [actions.up]
+    table = "users"
+    value = "email"
+    where = "user_id = id"
+```
+
 #### Alter column
 
 The `alter_column` action enables many different changes to an existing column, for example renaming, changing type and changing existing values.
diff --git a/src/migrations/add_column.rs b/src/migrations/add_column.rs
index 537d6bc..1e1b845 100644
--- a/src/migrations/add_column.rs
+++ b/src/migrations/add_column.rs
@@ -9,8 +9,19 @@ use serde::{Deserialize, Serialize};
 #[derive(Serialize, Deserialize, Debug)]
 pub struct AddColumn {
     pub table: String,
-    pub up: Option<String>,
     pub column: Column,
+    pub up: Option<Transformation>,
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+#[serde(untagged)]
+pub enum Transformation {
+    Simple(String),
+    Update {
+        table: String,
+        value: String,
+        r#where: String,
+    },
 }
 
 impl AddColumn {
@@ -87,52 +98,109 @@ impl Action for AddColumn {
         db.run(&query).context("failed to add column")?;
 
         if let Some(up) = &self.up {
-            let declarations: Vec<String> = table
-                .columns
-                .iter()
-                .map(|column| {
-                    format!(
-                        "{alias} public.{table}.{real_name}%TYPE := NEW.{real_name};",
-                        table = table.real_name,
-                        alias = column.name,
-                        real_name = column.real_name,
-                    )
-                })
-                .collect();
+            if let Transformation::Simple(up) = up {
+                let declarations: Vec<String> = table
+                    .columns
+                    .iter()
+                    .map(|column| {
+                        format!(
+                            "{alias} public.{table}.{real_name}%TYPE := NEW.{real_name};",
+                            table = table.real_name,
+                            alias = column.name,
+                            real_name = column.real_name,
+                        )
+                    })
+                    .collect();
 
-            // Add triggers to fill in values as they are inserted/updated
-            let query = format!(
-                r#"
-                CREATE OR REPLACE FUNCTION {trigger_name}()
-                RETURNS TRIGGER AS $$
-                BEGIN
-                    IF NOT reshape.is_new_schema() THEN
-                        DECLARE
-                            {declarations}
-                        BEGIN
-                            NEW.{temp_column_name} = {up};
-                        END;
-                    END IF;
-                    RETURN NEW;
-                END
-                $$ language 'plpgsql';
+                // Add triggers to fill in values as they are inserted/updated
+                let query = format!(
+                    r#"
+                    CREATE OR REPLACE FUNCTION {trigger_name}()
+                    RETURNS TRIGGER AS $$
+                    BEGIN
+                        IF NOT reshape.is_new_schema() THEN
+                            DECLARE
+                                {declarations}
+                            BEGIN
+                                NEW.{temp_column_name} = {up};
+                            END;
+                        END IF;
+                        RETURN NEW;
+                    END
+                    $$ language 'plpgsql';
 
-                DROP TRIGGER IF EXISTS "{trigger_name}" ON "{table}";
-                CREATE TRIGGER "{trigger_name}" BEFORE UPDATE OR INSERT ON "{table}" FOR EACH ROW EXECUTE PROCEDURE {trigger_name}();
-                "#,
-                temp_column_name = temp_column_name,
-                trigger_name = self.trigger_name(ctx),
-                up = up,
-                table = self.table,
-                declarations = declarations.join("\n"),
-            );
-            db.run(&query).context("failed to create up trigger")?;
-        }
+                    DROP TRIGGER IF EXISTS "{trigger_name}" ON "{table}";
+                    CREATE TRIGGER "{trigger_name}" BEFORE UPDATE OR INSERT ON "{table}" FOR EACH ROW EXECUTE PROCEDURE {trigger_name}();
+                    "#,
+                    temp_column_name = temp_column_name,
+                    trigger_name = self.trigger_name(ctx),
+                    up = up,
+                    table = self.table,
+                    declarations = declarations.join("\n"),
+                );
+                db.run(&query).context("failed to create up trigger")?;
 
-        // Backfill values in batches
-        if self.up.is_some() {
-            common::batch_touch_rows(db, &table.real_name, &temp_column_name)
-                .context("failed to batch update existing rows")?;
+                // Backfill values in batches
+                common::batch_touch_rows(db, &table.real_name, Some(&temp_column_name))
+                    .context("failed to batch update existing rows")?;
+            }
+
+            if let Transformation::Update {
+                table: from_table,
+                value,
+                r#where,
+            } = up
+            {
+                let from_table = schema.get_table(db, &from_table)?;
+
+                let declarations: Vec<String> = from_table
+                    .columns
+                    .iter()
+                    .map(|column| {
+                        format!(
+                            "{alias} public.{table}.{real_name}%TYPE := NEW.{real_name};",
+                            table = from_table.real_name,
+                            alias = column.name,
+                            real_name = column.real_name,
+                        )
+                    })
+                    .collect();
+
+                // Add triggers to fill in values as they are inserted/updated
+                let query = format!(
+                    r#"
+                    CREATE OR REPLACE FUNCTION {trigger_name}()
+                    RETURNS TRIGGER AS $$
+                    #variable_conflict use_variable
+                    BEGIN
+                        IF NOT reshape.is_new_schema() THEN
+                            DECLARE
+                                {declarations}
+                            BEGIN
+                                UPDATE public."{changed_table_real}"
+                                SET "{temp_column_name}" = {value}
+                                WHERE {where};
+                            END;
+                        END IF;
+                        RETURN NEW;
+                    END
+                    $$ language 'plpgsql';
+
+                    DROP TRIGGER IF EXISTS "{trigger_name}" ON "{from_table_real}";
+                    CREATE TRIGGER "{trigger_name}" BEFORE UPDATE OR INSERT ON "{from_table_real}" FOR EACH ROW EXECUTE PROCEDURE {trigger_name}();
+                    "#,
+                    changed_table_real = table.real_name,
+                    from_table_real = from_table.real_name,
+                    trigger_name = self.trigger_name(ctx),
+                    declarations = declarations.join("\n"),
+                    temp_column_name = temp_column_name,
+                );
+                db.run(&query).context("failed to create up trigger")?;
+
+                // Backfill values in batches by touching the from table
+                common::batch_touch_rows(db, &from_table.real_name, None)
+                    .context("failed to batch update existing rows")?;
+            }
         }
 
         // Add a temporary NOT NULL constraint if the column shouldn't be nullable.
@@ -167,10 +235,8 @@ impl Action for AddColumn {
         // Remove triggers and procedures
         let query = format!(
             r#"
-            DROP TRIGGER IF EXISTS "{trigger_name}" ON "{table}";
-            DROP FUNCTION IF EXISTS "{trigger_name}";
+            DROP FUNCTION IF EXISTS "{trigger_name}" CASCADE;
             "#,
-            table = self.table,
             trigger_name = self.trigger_name(ctx),
         );
         transaction
@@ -262,10 +328,8 @@ impl Action for AddColumn {
         // Remove triggers and procedures
         let query = format!(
             r#"
-            DROP TRIGGER IF EXISTS "{trigger_name}" ON "{table}";
-            DROP FUNCTION IF EXISTS "{trigger_name}";
+            DROP FUNCTION IF EXISTS "{trigger_name}" CASCADE;
             "#,
-            table = self.table,
             trigger_name = self.trigger_name(ctx),
         );
         db.run(&query).context("failed to drop up trigger")?;
diff --git a/src/migrations/alter_column.rs b/src/migrations/alter_column.rs
index e3d7bd1..fb30594 100644
--- a/src/migrations/alter_column.rs
+++ b/src/migrations/alter_column.rs
@@ -150,7 +150,7 @@ impl Action for AlterColumn {
             .context("failed to create up and down triggers")?;
 
         // Backfill values in batches by touching the previous column
-        common::batch_touch_rows(db, &table.real_name, &column.real_name)
+        common::batch_touch_rows(db, &table.real_name, Some(&column.real_name))
             .context("failed to batch update existing rows")?;
 
         // Duplicate any indices to the temporary column
diff --git a/src/migrations/common.rs b/src/migrations/common.rs
index 8170c8c..70ff54f 100644
--- a/src/migrations/common.rs
+++ b/src/migrations/common.rs
@@ -69,7 +69,11 @@ impl ToSql for PostgresRawValue {
     postgres::types::to_sql_checked!();
 }
 
-pub fn batch_touch_rows(db: &mut dyn Conn, table: &str, column: &str) -> anyhow::Result<()> {
+pub fn batch_touch_rows(
+    db: &mut dyn Conn,
+    table: &str,
+    column: Option<&str>,
+) -> anyhow::Result<()> {
     const BATCH_SIZE: u16 = 1000;
     let mut cursor: Option<PostgresRawValue> = None;
 
@@ -78,6 +82,13 @@ pub fn batch_touch_rows(db: &mut dyn Conn, table: &str, column: &str) -> anyhow:
         let mut params: Vec<&(dyn ToSql + Sync)> = Vec::new();
 
         let primary_key = get_primary_key_columns_for_table(db, table)?;
+
+        // If no column to touch is passed, we default to the first primary key column (just to make some "update")
+        let touched_column = match column {
+            Some(column) => column,
+            None => primary_key.first().unwrap(),
+        };
+
         let primary_key_columns = primary_key.join(", ");
 
         let primary_key_where = primary_key
@@ -120,8 +131,8 @@ pub fn batch_touch_rows(db: &mut dyn Conn, table: &str, column: &str) -> anyhow:
                 ORDER BY {primary_key_columns}
                 LIMIT {batch_size}
             ), update AS (
-                UPDATE public."{table}"
-                SET "{column}" = "{column}"
+                UPDATE public."{table}" "{table}"
+                SET "{touched_column}" = "{table}"."{touched_column}"
                 FROM rows
                 WHERE {primary_key_where}
                 RETURNING {returning_columns}
@@ -130,13 +141,7 @@ pub fn batch_touch_rows(db: &mut dyn Conn, table: &str, column: &str) -> anyhow:
                 FROM update
                 LIMIT 1
             "#,
-            table = table,
-            primary_key_columns = primary_key_columns,
-            cursor_where = cursor_where,
             batch_size = BATCH_SIZE,
-            column = column,
-            primary_key_where = primary_key_where,
-            returning_columns = returning_columns,
         );
         let last_value = db
             .query_with_params(&query, &params)?
diff --git a/tests/add_column.rs b/tests/add_column.rs
index 1286b48..795b355 100644
--- a/tests/add_column.rs
+++ b/tests/add_column.rs
@@ -260,3 +260,102 @@ fn add_column_with_default() {
 
     test.run();
 }
+
+#[test]
+fn add_column_with_complex_up() {
+    let mut test = Test::new("Add column complex");
+
+    test.first_migration(
+        r#"
+        name = "create_tables"
+
+        [[actions]]
+        type = "create_table"
+        name = "users"
+        primary_key = ["id"]
+
+            [[actions.columns]]
+            name = "id"
+            type = "INTEGER"
+
+            [[actions.columns]]
+            name = "email"
+            type = "TEXT"
+
+        [[actions]]
+        type = "create_table"
+        name = "profiles"
+        primary_key = ["user_id"]
+
+            [[actions.columns]]
+            name = "user_id"
+            type = "INTEGER"
+        "#,
+    );
+
+    test.second_migration(
+        r#"
+        name = "add_profiles_email_column"
+
+        [[actions]]
+        type = "add_column"
+        table = "profiles"
+
+            [actions.column]
+            name = "email"
+            type = "TEXT"
+            nullable = false
+
+            [actions.up]
+            table = "users"
+            value = "email"
+            where = "user_id = id"
+        "#,
+    );
+
+    test.after_first(|db| {
+        db.simple_query("INSERT INTO users (id, email) VALUES (1, 'test@example.com')")
+            .unwrap();
+        db.simple_query("INSERT INTO profiles (user_id) VALUES (1)")
+            .unwrap();
+    });
+
+    test.intermediate(|old_db, new_db| {
+        // Ensure email was backfilled on profiles
+        let email: String = new_db
+            .query(
+                "
+                SELECT email
+                FROM profiles
+                WHERE user_id = 1
+                ",
+                &[],
+            )
+            .unwrap()
+            .first()
+            .map(|row| row.get("email"))
+            .unwrap();
+        assert_eq!("test@example.com", email);
+
+        // Ensure email change in old schema is propagated to profiles table in new schema
+        old_db
+            .simple_query("UPDATE users SET email = 'test2@example.com' WHERE id = 1")
+            .unwrap();
+        let email: String = new_db
+            .query(
+                "
+                SELECT email
+                FROM profiles
+                WHERE user_id = 1
+                ",
+                &[],
+            )
+            .unwrap()
+            .first()
+            .map(|row| row.get("email"))
+            .unwrap();
+        assert_eq!("test2@example.com", email);
+    });
+
+    test.run();
+}