Add up option to create_table

This commit is contained in:
fabianlindfors 2023-11-22 21:58:06 +01:00
parent 42257be9cf
commit 6034dfebbd
3 changed files with 334 additions and 8 deletions

View File

@ -1,6 +1,9 @@
use std::collections::HashMap;
use super::{common::ForeignKey, Action, Column, MigrationContext};
use crate::{
db::{Conn, Transaction},
migrations::common,
schema::Schema,
};
use anyhow::Context;
@ -14,6 +17,21 @@ pub struct CreateTable {
#[serde(default)]
pub foreign_keys: Vec<ForeignKey>,
pub up: Option<Transformation>,
}
#[derive(Serialize, Deserialize, Debug)]
pub struct Transformation {
table: String,
values: HashMap<String, String>,
upsert_constraint: Option<String>,
}
impl CreateTable {
fn trigger_name(&self, ctx: &MigrationContext) -> String {
format!("{}_create_table_{}", ctx.prefix(), self.name)
}
}
#[typetag::serde(name = "create_table")]
@ -24,7 +42,7 @@ impl Action for CreateTable {
fn run(
&self,
_ctx: &MigrationContext,
ctx: &MigrationContext,
db: &mut dyn Conn,
schema: &Schema,
) -> anyhow::Result<()> {
@ -85,7 +103,7 @@ impl Action for CreateTable {
));
}
db.run(&format!(
let query = &format!(
r#"
CREATE TABLE "{name}" (
{definition}
@ -93,23 +111,118 @@ impl Action for CreateTable {
"#,
name = self.name,
definition = definition_rows.join(",\n"),
))
.context("failed to create table")?;
);
db.run(query).context("failed to create table")?;
if let Some(Transformation {
table: from_table,
values,
upsert_constraint,
}) = &self.up
{
let from_table = schema.get_table(db, &from_table)?;
let declarations: Vec<String> = from_table
.columns
.iter()
.map(|column| {
format!(
"{alias} public.{table}.{real_name}%TYPE := NEW.{real_name};",
table = from_table.real_name,
alias = column.name,
real_name = column.real_name,
)
})
.collect();
let (insert_columns, insert_values): (Vec<&str>, Vec<&str>) = values
.iter()
.map(|(k, v)| -> (&str, &str) { (k, v) }) // Force &String to &str
.unzip();
let update_set: Vec<String> = values
.iter()
.map(|(field, value)| format!("\"{field}\" = {value}"))
.collect();
// Constraint to check for conflicts. Defaults to the primary key constraint.
let conflict_constraint_name = match upsert_constraint {
Some(custom_constraint) => custom_constraint.clone(),
_ => format!("{table}_pkey", table = self.name),
};
// Add triggers to fill in values as they are inserted/updated
let query = format!(
r#"
CREATE OR REPLACE FUNCTION {trigger_name}()
RETURNS TRIGGER AS $$
#variable_conflict use_variable
BEGIN
IF NOT reshape.is_new_schema() THEN
DECLARE
{declarations}
BEGIN
INSERT INTO public."{changed_table_real}" ({columns})
VALUES ({values})
ON CONFLICT ON CONSTRAINT "{conflict_constraint_name}"
DO UPDATE SET
{updates};
END;
END IF;
RETURN NEW;
END
$$ language 'plpgsql';
DROP TRIGGER IF EXISTS "{trigger_name}" ON "{from_table_real}";
CREATE TRIGGER "{trigger_name}" BEFORE UPDATE OR INSERT ON "{from_table_real}" FOR EACH ROW EXECUTE PROCEDURE {trigger_name}();
"#,
changed_table_real = self.name,
from_table_real = from_table.real_name,
trigger_name = self.trigger_name(ctx),
declarations = declarations.join("\n"),
columns = insert_columns.join(", "),
values = insert_values.join(", "),
updates = update_set.join(",\n"),
);
db.run(&query).context("failed to create up trigger")?;
// Backfill values in batches by touching the from table
common::batch_touch_rows(db, &from_table.real_name, None)
.context("failed to batch update existing rows")?;
}
Ok(())
}
fn complete<'a>(
&self,
_ctx: &MigrationContext,
_db: &'a mut dyn Conn,
ctx: &MigrationContext,
db: &'a mut dyn Conn,
) -> anyhow::Result<Option<Transaction<'a>>> {
// Do nothing
// Remove triggers and procedures
let query = format!(
r#"
DROP FUNCTION IF EXISTS "{trigger_name}" CASCADE;
"#,
trigger_name = self.trigger_name(ctx),
);
db.run(&query).context("failed to drop up trigger")?;
Ok(None)
}
fn update_schema(&self, _ctx: &MigrationContext, _schema: &mut Schema) {}
fn abort(&self, _ctx: &MigrationContext, db: &mut dyn Conn) -> anyhow::Result<()> {
fn abort(&self, ctx: &MigrationContext, db: &mut dyn Conn) -> anyhow::Result<()> {
// Remove triggers and procedures
let query = format!(
r#"
DROP FUNCTION IF EXISTS "{trigger_name}" CASCADE;
"#,
trigger_name = self.trigger_name(ctx),
);
db.run(&query).context("failed to drop up trigger")?;
db.run(&format!(
r#"
DROP TABLE IF EXISTS "{name}"

View File

@ -122,6 +122,9 @@ impl Action for RemoveColumn {
})
.collect();
// TODO: If column is NOT NULL, remove the constraint and perform a NULL check in some other way.
// Otherwise, it's not possible to update the NOT NULL column from another table even in the same transaction.
// Either make the NULL check in here or maybe use a constraint trigger: https://www.postgresql.org/docs/9.0/sql-createconstraint.html.
let query = format!(
r#"
CREATE OR REPLACE FUNCTION {trigger_name}()

210
tests/complex.rs Normal file
View File

@ -0,0 +1,210 @@
mod common;
use common::Test;
#[test]
fn extract_relation_into_new_table() {
let mut test = Test::new("Extract relation into new table");
test.first_migration(
r#"
name = "create_tables"
[[actions]]
type = "create_table"
name = "accounts"
primary_key = ["id"]
[[actions.columns]]
name = "id"
type = "INTEGER"
[[actions]]
type = "create_table"
name = "users"
primary_key = ["id"]
[[actions.columns]]
name = "id"
type = "INTEGER"
[[actions.columns]]
name = "account_id"
type = "INTEGER"
[[actions.columns]]
name = "account_role"
type = "TEXT"
"#,
);
test.second_migration(
r#"
name = "add_account_user_connection"
[[actions]]
type = "create_table"
name = "user_account_connections"
primary_key = ["account_id", "user_id"]
[[actions.columns]]
name = "account_id"
type = "INTEGER"
[[actions.columns]]
name = "user_id"
type = "INTEGER"
[[actions.columns]]
name = "role"
type = "TEXT"
[actions.up]
table = "users"
values = { user_id = "id", account_id = "account_id", role = "UPPER(account_role)" }
where = "user_id = id"
[[actions]]
type = "remove_column"
table = "users"
column = "account_id"
[actions.down]
table = "user_account_connections"
value = "account_id"
where = "id = user_id"
[[actions]]
type = "remove_column"
table = "users"
column = "account_role"
[actions.down]
table = "user_account_connections"
value = "LOWER(role)"
where = "id = user_id"
"#,
);
test.after_first(|db| {
db.simple_query("INSERT INTO accounts (id) VALUES (1)")
.unwrap();
db.simple_query("INSERT INTO users (id, account_id, account_role) VALUES (1, 1, 'admin')")
.unwrap();
});
test.intermediate(|old_db, new_db| {
// Ensure connections was backfilled
let rows: Vec<(i32, i32, String)> = new_db
.query(
"
SELECT account_id, user_id, role
FROM user_account_connections
",
&[],
)
.unwrap()
.iter()
.map(|row| (row.get("account_id"), row.get("user_id"), row.get("role")))
.collect();
assert_eq!(1, rows.len());
let row = rows.first().unwrap();
assert_eq!(1, row.0);
assert_eq!(1, row.1);
assert_eq!("ADMIN", row.2);
// Ensure inserted user in old schema creates a new connection
old_db
.simple_query(
"INSERT INTO users (id, account_id, account_role) VALUES (2, 1, 'developer')",
)
.unwrap();
assert!(
new_db
.query(
"
SELECT account_id, user_id, role
FROM user_account_connections
WHERE account_id = 1 AND user_id = 2 AND role = 'DEVELOPER'
",
&[],
)
.unwrap()
.len()
== 1
);
// Ensure updated user role in old schema updates connection in new schema
old_db
.simple_query("UPDATE users SET account_role = 'admin' WHERE id = 2")
.unwrap();
assert!(
new_db
.query(
"
SELECT account_id, user_id, role
FROM user_account_connections
WHERE account_id = 1 AND user_id = 2 AND role = 'ADMIN'
",
&[],
)
.unwrap()
.len()
== 1
);
// Ensure updated connection in new schema updates old schema user
new_db
.simple_query(
"UPDATE user_account_connections SET role = 'DEVELOPER' WHERE account_id = 1 AND user_id = 2",
)
.unwrap();
assert!(
old_db
.query(
"
SELECT id
FROM users
WHERE id = 2 AND account_id = 1 AND account_role = 'developer'
",
&[],
)
.unwrap()
.len()
== 1
);
// Ensure insert of user with connection through new schema updates user in old schema
new_db
.simple_query(
r#"
BEGIN;
INSERT INTO users (id) VALUES (3);
INSERT INTO user_account_connections (user_id, account_id, role) VALUES (3, 1, 'DEVELOPER');
COMMIT;
"#,
)
.unwrap();
new_db
.simple_query(
"",
)
.unwrap();
assert!(
old_db
.query(
"
SELECT id
FROM users
WHERE id = 3 AND account_id = 1 AND account_role = 'developer'
",
&[],
)
.unwrap()
.len()
== 1
);
});
test.run();
}