From 6034dfebbd7b841daffad46439c6d35e3fde4b68 Mon Sep 17 00:00:00 2001 From: fabianlindfors Date: Wed, 22 Nov 2023 21:58:06 +0100 Subject: [PATCH] Add up option to create_table --- src/migrations/create_table.rs | 129 ++++++++++++++++++-- src/migrations/remove_column.rs | 3 + tests/complex.rs | 210 ++++++++++++++++++++++++++++++++ 3 files changed, 334 insertions(+), 8 deletions(-) create mode 100644 tests/complex.rs diff --git a/src/migrations/create_table.rs b/src/migrations/create_table.rs index 29ebe95..e3bc11d 100644 --- a/src/migrations/create_table.rs +++ b/src/migrations/create_table.rs @@ -1,6 +1,9 @@ +use std::collections::HashMap; + use super::{common::ForeignKey, Action, Column, MigrationContext}; use crate::{ db::{Conn, Transaction}, + migrations::common, schema::Schema, }; use anyhow::Context; @@ -14,6 +17,21 @@ pub struct CreateTable { #[serde(default)] pub foreign_keys: Vec<ForeignKey>, + + pub up: Option<Transformation>, +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct Transformation { + table: String, + values: HashMap<String, String>, + upsert_constraint: Option<String>, +} + +impl CreateTable { + fn trigger_name(&self, ctx: &MigrationContext) -> String { + format!("{}_create_table_{}", ctx.prefix(), self.name) + } } @@ -24,7 +42,7 @@ impl Action for CreateTable { fn run( &self, - _ctx: &MigrationContext, + ctx: &MigrationContext, db: &mut dyn Conn, schema: &Schema, ) -> anyhow::Result<()> { @@ -85,7 +103,7 @@ impl Action for CreateTable { )); } - db.run(&format!( + let query = &format!( r#" CREATE TABLE "{name}" ( {definition} @@ -93,23 +111,118 @@ impl Action for CreateTable { "#, name = self.name, definition = definition_rows.join(",\n"), - )) - .context("failed to create table")?; + ); + db.run(query).context("failed to create table")?; + + if let Some(Transformation { + table: from_table, + values, + upsert_constraint, + }) = &self.up + { + let from_table = schema.get_table(db, &from_table)?; + + let declarations: Vec<String> = from_table + .columns + 
.iter() + .map(|column| { + format!( + "{alias} public.{table}.{real_name}%TYPE := NEW.{real_name};", + table = from_table.real_name, + alias = column.name, + real_name = column.real_name, + ) + }) + .collect(); + + let (insert_columns, insert_values): (Vec<&str>, Vec<&str>) = values + .iter() + .map(|(k, v)| -> (&str, &str) { (k, v) }) // Force &String to &str + .unzip(); + + let update_set: Vec<String> = values + .iter() + .map(|(field, value)| format!("\"{field}\" = {value}")) + .collect(); + + // Constraint to check for conflicts. Defaults to the primary key constraint. + let conflict_constraint_name = match upsert_constraint { + Some(custom_constraint) => custom_constraint.clone(), + _ => format!("{table}_pkey", table = self.name), + }; + + // Add triggers to fill in values as they are inserted/updated + let query = format!( + r#" + CREATE OR REPLACE FUNCTION {trigger_name}() + RETURNS TRIGGER AS $$ + #variable_conflict use_variable + BEGIN + IF NOT reshape.is_new_schema() THEN + DECLARE + {declarations} + BEGIN + INSERT INTO public."{changed_table_real}" ({columns}) + VALUES ({values}) + ON CONFLICT ON CONSTRAINT "{conflict_constraint_name}" + DO UPDATE SET + {updates}; + END; + END IF; + RETURN NEW; + END + $$ language 'plpgsql'; + + DROP TRIGGER IF EXISTS "{trigger_name}" ON "{from_table_real}"; + CREATE TRIGGER "{trigger_name}" BEFORE UPDATE OR INSERT ON "{from_table_real}" FOR EACH ROW EXECUTE PROCEDURE {trigger_name}(); + "#, + changed_table_real = self.name, + from_table_real = from_table.real_name, + trigger_name = self.trigger_name(ctx), + declarations = declarations.join("\n"), + columns = insert_columns.join(", "), + values = insert_values.join(", "), + updates = update_set.join(",\n"), + ); + db.run(&query).context("failed to create up trigger")?; + + // Backfill values in batches by touching the from table + common::batch_touch_rows(db, &from_table.real_name, None) + .context("failed to batch update existing rows")?; + } + Ok(()) } fn complete<'a>( &self, 
- _ctx: &MigrationContext, - _db: &'a mut dyn Conn, + ctx: &MigrationContext, + db: &'a mut dyn Conn, ) -> anyhow::Result<Option<Transaction<'a>>> { - // Do nothing + // Remove triggers and procedures + let query = format!( + r#" + DROP FUNCTION IF EXISTS "{trigger_name}" CASCADE; + "#, + trigger_name = self.trigger_name(ctx), + ); + db.run(&query).context("failed to drop up trigger")?; + Ok(None) } fn update_schema(&self, _ctx: &MigrationContext, _schema: &mut Schema) {} - fn abort(&self, _ctx: &MigrationContext, db: &mut dyn Conn) -> anyhow::Result<()> { + fn abort(&self, ctx: &MigrationContext, db: &mut dyn Conn) -> anyhow::Result<()> { + // Remove triggers and procedures + let query = format!( + r#" + DROP FUNCTION IF EXISTS "{trigger_name}" CASCADE; + "#, + trigger_name = self.trigger_name(ctx), + ); + db.run(&query).context("failed to drop up trigger")?; + db.run(&format!( r#" DROP TABLE IF EXISTS "{name}" diff --git a/src/migrations/remove_column.rs b/src/migrations/remove_column.rs index e18d19e..ff09e46 100644 --- a/src/migrations/remove_column.rs +++ b/src/migrations/remove_column.rs @@ -122,6 +122,9 @@ impl Action for RemoveColumn { }) .collect(); + + // TODO: If column is NOT NULL, remove the constraint and perform a NULL check in some other way. + // Otherwise, it's not possible to update the NOT NULL column from another table even in the same transaction. + // Either make the NULL check in here or maybe use a constraint trigger: https://www.postgresql.org/docs/9.0/sql-createconstraint.html. 
let query = format!( r#" CREATE OR REPLACE FUNCTION {trigger_name}() diff --git a/tests/complex.rs b/tests/complex.rs new file mode 100644 index 0000000..09fc171 --- /dev/null +++ b/tests/complex.rs @@ -0,0 +1,210 @@ +mod common; +use common::Test; + +#[test] +fn extract_relation_into_new_table() { + let mut test = Test::new("Extract relation into new table"); + + test.first_migration( + r#" + name = "create_tables" + + [[actions]] + type = "create_table" + name = "accounts" + primary_key = ["id"] + + [[actions.columns]] + name = "id" + type = "INTEGER" + + [[actions]] + type = "create_table" + name = "users" + primary_key = ["id"] + + [[actions.columns]] + name = "id" + type = "INTEGER" + + [[actions.columns]] + name = "account_id" + type = "INTEGER" + + [[actions.columns]] + name = "account_role" + type = "TEXT" + "#, + ); + + test.second_migration( + r#" + name = "add_account_user_connection" + + [[actions]] + type = "create_table" + name = "user_account_connections" + primary_key = ["account_id", "user_id"] + + [[actions.columns]] + name = "account_id" + type = "INTEGER" + + [[actions.columns]] + name = "user_id" + type = "INTEGER" + + [[actions.columns]] + name = "role" + type = "TEXT" + + [actions.up] + table = "users" + values = { user_id = "id", account_id = "account_id", role = "UPPER(account_role)" } + where = "user_id = id" + + [[actions]] + type = "remove_column" + table = "users" + column = "account_id" + + [actions.down] + table = "user_account_connections" + value = "account_id" + where = "id = user_id" + + [[actions]] + type = "remove_column" + table = "users" + column = "account_role" + + [actions.down] + table = "user_account_connections" + value = "LOWER(role)" + where = "id = user_id" + "#, + ); + + test.after_first(|db| { + db.simple_query("INSERT INTO accounts (id) VALUES (1)") + .unwrap(); + db.simple_query("INSERT INTO users (id, account_id, account_role) VALUES (1, 1, 'admin')") + .unwrap(); + }); + + test.intermediate(|old_db, new_db| { + 
// Ensure connections was backfilled + let rows: Vec<(i32, i32, String)> = new_db + .query( + " + SELECT account_id, user_id, role + FROM user_account_connections + ", + &[], + ) + .unwrap() + .iter() + .map(|row| (row.get("account_id"), row.get("user_id"), row.get("role"))) + .collect(); + assert_eq!(1, rows.len()); + + let row = rows.first().unwrap(); + assert_eq!(1, row.0); + assert_eq!(1, row.1); + assert_eq!("ADMIN", row.2); + + // Ensure inserted user in old schema creates a new connection + old_db + .simple_query( + "INSERT INTO users (id, account_id, account_role) VALUES (2, 1, 'developer')", + ) + .unwrap(); + assert!( + new_db + .query( + " + SELECT account_id, user_id, role + FROM user_account_connections + WHERE account_id = 1 AND user_id = 2 AND role = 'DEVELOPER' + ", + &[], + ) + .unwrap() + .len() + == 1 + ); + + // Ensure updated user role in old schema updates connection in new schema + old_db + .simple_query("UPDATE users SET account_role = 'admin' WHERE id = 2") + .unwrap(); + assert!( + new_db + .query( + " + SELECT account_id, user_id, role + FROM user_account_connections + WHERE account_id = 1 AND user_id = 2 AND role = 'ADMIN' + ", + &[], + ) + .unwrap() + .len() + == 1 + ); + + // Ensure updated connection in new schema updates old schema user + new_db + .simple_query( + "UPDATE user_account_connections SET role = 'DEVELOPER' WHERE account_id = 1 AND user_id = 2", + ) + .unwrap(); + assert!( + old_db + .query( + " + SELECT id + FROM users + WHERE id = 2 AND account_id = 1 AND account_role = 'developer' + ", + &[], + ) + .unwrap() + .len() + == 1 + ); + + // Ensure insert of user with connection through new schema updates user in old schema + new_db + .simple_query( + r#" + BEGIN; + INSERT INTO users (id) VALUES (3); + INSERT INTO user_account_connections (user_id, account_id, role) VALUES (3, 1, 'DEVELOPER'); + COMMIT; + "#, + ) + .unwrap(); + new_db + .simple_query( + "", + ) + .unwrap(); + assert!( + old_db + .query( + " + SELECT id + 
FROM users + WHERE id = 3 AND account_id = 1 AND account_role = 'developer' + ", + &[], + ) + .unwrap() + .len() + == 1 + ); + }); + + test.run(); +}