diff --git a/tests/add_column.rs b/tests/add_column.rs
index cee0b74..70a3ab0 100644
--- a/tests/add_column.rs
+++ b/tests/add_column.rs
@@ -4,76 +4,59 @@ mod common;
 
 #[test]
 fn add_column() {
-    let (mut reshape, mut old_db, mut new_db) = common::setup();
+    let mut test = common::Test::new("Add column");
 
-    let create_users_table = Migration::new("create_user_table", None).with_action(
-        CreateTableBuilder::default()
-            .name("users")
-            .primary_key(vec!["id".to_string()])
-            .columns(vec![
-                ColumnBuilder::default()
-                    .name("id")
-                    .data_type("INTEGER")
-                    .build()
-                    .unwrap(),
-                ColumnBuilder::default()
-                    .name("name")
-                    .data_type("TEXT")
-                    .build()
-                    .unwrap(),
-            ])
-            .build()
-            .unwrap(),
+    test.first_migration(
+        Migration::new("create_user_table", None).with_action(
+            CreateTableBuilder::default()
+                .name("users")
+                .primary_key(vec!["id".to_string()])
+                .columns(vec![
+                    ColumnBuilder::default()
+                        .name("id")
+                        .data_type("INTEGER")
+                        .build()
+                        .unwrap(),
+                    ColumnBuilder::default()
+                        .name("name")
+                        .data_type("TEXT")
+                        .build()
+                        .unwrap(),
+                ])
+                .build()
+                .unwrap(),
+        ),
     );
 
-    let add_first_last_name_columns = Migration::new("add_first_and_last_name_columns", None)
-        .with_action(AddColumn {
-            table: "users".to_string(),
-            column: Column {
-                name: "first".to_string(),
-                data_type: "TEXT".to_string(),
-                nullable: false,
-                default: None,
-                generated: None,
-            },
-            up: Some("(STRING_TO_ARRAY(name, ' '))[1]".to_string()),
-        })
-        .with_action(AddColumn {
-            table: "users".to_string(),
-            column: Column {
-                name: "last".to_string(),
-                data_type: "TEXT".to_string(),
-                nullable: false,
-                default: None,
-                generated: None,
-            },
-            up: Some("(STRING_TO_ARRAY(name, ' '))[2]".to_string()),
-        });
+    test.second_migration(
+        Migration::new("add_first_and_last_name_columns", None)
+            .with_action(AddColumn {
+                table: "users".to_string(),
+                column: Column {
+                    name: "first".to_string(),
+                    data_type: "TEXT".to_string(),
+                    nullable: false,
+                    default: None,
+                    generated: None,
+                },
+                up: Some("(STRING_TO_ARRAY(name, ' '))[1]".to_string()),
+            })
+            .with_action(AddColumn {
+                table: "users".to_string(),
+                column: Column {
+                    name: "last".to_string(),
+                    data_type: "TEXT".to_string(),
+                    nullable: false,
+                    default: None,
+                    generated: None,
+                },
+                up: Some("(STRING_TO_ARRAY(name, ' '))[2]".to_string()),
+            }),
+    );
 
-    let first_migrations = vec![create_users_table.clone()];
-    let second_migrations = vec![
-        create_users_table.clone(),
-        add_first_last_name_columns.clone(),
-    ];
-
-    // Run first migration, should automatically finish
-    reshape.migrate(first_migrations.clone()).unwrap();
-
-    // Update search paths
-    old_db
-        .simple_query(&reshape::schema_query_for_migration(
-            &first_migrations.last().unwrap().name,
-        ))
-        .unwrap();
-    new_db
-        .simple_query(&reshape::schema_query_for_migration(
-            &first_migrations.last().unwrap().name,
-        ))
-        .unwrap();
-
-    // Insert some test users
-    new_db
-        .simple_query(
+    test.after_first(|db| {
+        // Insert some test users
+        db.simple_query(
             "
             INSERT INTO users (id, name)
             VALUES
                 (1, 'John Doe'),
@@ -81,39 +64,53 @@ fn add_column() {
             ",
         )
         .unwrap();
+    });
 
-    // Run second migration
-    reshape.migrate(second_migrations.clone()).unwrap();
-    new_db
-        .simple_query(&reshape::schema_query_for_migration(
-            &second_migrations.last().unwrap().name,
-        ))
-        .unwrap();
+    test.intermediate(|old_db, new_db| {
+        // Check that the existing users have the new columns populated
+        let expected = vec![("John", "Doe"), ("Jane", "Doe")];
+        assert!(new_db
+            .query("SELECT first, last FROM users ORDER BY id", &[],)
+            .unwrap()
+            .iter()
+            .map(|row| (row.get("first"), row.get("last")))
+            .eq(expected));
 
-    // Check that the existing users have the new columns populated
-    let expected = vec![("John", "Doe"), ("Jane", "Doe")];
-    assert!(new_db
-        .query("SELECT first, last FROM users ORDER BY id", &[],)
-        .unwrap()
-        .iter()
-        .map(|row| (row.get("first"), row.get("last")))
-        .eq(expected));
+        // Insert data using old schema and make sure the new columns are populated
+        old_db
+            .simple_query("INSERT INTO users (id, name) VALUES (3, 'Test Testsson')")
+            .unwrap();
+        let (first_name, last_name): (String, String) = new_db
+            .query_one("SELECT first, last from users WHERE id = 3", &[])
+            .map(|row| (row.get("first"), row.get("last")))
+            .unwrap();
+        assert_eq!(
+            ("Test", "Testsson"),
+            (first_name.as_ref(), last_name.as_ref())
+        );
+    });
 
-    // Insert data using old schema and make sure the new columns are populated
-    old_db
-        .simple_query("INSERT INTO users (id, name) VALUES (3, 'Test Testsson')")
-        .unwrap();
-    let (first_name, last_name): (String, String) = new_db
-        .query_one("SELECT first, last from users WHERE id = 3", &[])
-        .map(|row| (row.get("first"), row.get("last")))
-        .unwrap();
-    assert_eq!(
-        ("Test", "Testsson"),
-        (first_name.as_ref(), last_name.as_ref())
-    );
+    test.after_completion(|db| {
+        let expected = vec![("John", "Doe"), ("Jane", "Doe"), ("Test", "Testsson")];
+        assert!(db
+            .query("SELECT first, last FROM users ORDER BY id", &[],)
+            .unwrap()
+            .iter()
+            .map(|row| (row.get("first"), row.get("last")))
+            .eq(expected));
+    });
 
-    reshape.complete().unwrap();
-    common::assert_cleaned_up(&mut new_db);
+    test.after_abort(|db| {
+        let expected = vec![("John Doe"), ("Jane Doe"), ("Test Testsson")];
+        assert!(db
+            .query("SELECT name FROM users ORDER BY id", &[],)
+            .unwrap()
+            .iter()
+            .map(|row| row.get::<'_, _, String>("name"))
+            .eq(expected));
+    });
+
+    test.run()
 }
 
 #[test]
diff --git a/tests/common.rs b/tests/common.rs
index 5555dc4..38951ca 100644
--- a/tests/common.rs
+++ b/tests/common.rs
@@ -1,5 +1,202 @@
+use colored::Colorize;
 use postgres::{Client, NoTls};
-use reshape::Reshape;
+use reshape::{migrations::Migration, Reshape};
+
+pub struct Test<'a> {
+    name: &'a str,
+    reshape: Reshape,
+    old_db: Client,
+    new_db: Client,
+
+    first_migration: Option<Migration>,
+    second_migration: Option<Migration>,
+
+    after_first_fn: Option<fn(&mut Client) -> ()>,
+    intermediate_fn: Option<fn(&mut Client, &mut Client) -> ()>,
+    after_completion_fn: Option<fn(&mut Client) -> ()>,
+    after_abort_fn: Option<fn(&mut Client) -> ()>,
+}
+
+impl Test<'_> {
+    pub fn new<'a>(name: &'a str) -> Test<'a> {
+        let connection_string = std::env::var("POSTGRES_CONNECTION_STRING")
+            .unwrap_or("postgres://postgres:postgres@localhost/reshape_test".to_string());
+
+        let old_db = Client::connect(&connection_string, NoTls).unwrap();
+        let new_db = Client::connect(&connection_string, NoTls).unwrap();
+
+        let reshape = Reshape::new(&connection_string).unwrap();
+
+        Test {
+            name,
+            reshape,
+            old_db,
+            new_db,
+            first_migration: None,
+            second_migration: None,
+            after_first_fn: None,
+            intermediate_fn: None,
+            after_completion_fn: None,
+            after_abort_fn: None,
+        }
+    }
+
+    pub fn first_migration(&mut self, migration: Migration) -> &mut Self {
+        self.first_migration = Some(migration);
+        self
+    }
+
+    pub fn second_migration(&mut self, migration: Migration) -> &mut Self {
+        self.second_migration = Some(migration);
+        self
+    }
+
+    pub fn after_first(&mut self, f: fn(&mut Client) -> ()) -> &mut Self {
+        self.after_first_fn = Some(f);
+        self
+    }
+
+    pub fn intermediate(&mut self, f: fn(&mut Client, &mut Client) -> ()) -> &mut Self {
+        self.intermediate_fn = Some(f);
+        self
+    }
+
+    pub fn after_completion(&mut self, f: fn(&mut Client) -> ()) -> &mut Self {
+        self.after_completion_fn = Some(f);
+        self
+    }
+
+    pub fn after_abort(&mut self, f: fn(&mut Client) -> ()) -> &mut Self {
+        self.after_abort_fn = Some(f);
+        self
+    }
+}
+
+enum RunType {
+    Simple,
+    Completion,
+    Abort,
+}
+
+impl Test<'_> {
+    pub fn run(&mut self) {
+        if self.second_migration.is_some() {
+            // Run to completion
+            print_heading(&format!("Test completion: {}", self.name));
+            self.run_internal(RunType::Completion);
+
+            // Run and abort
+            print_heading(&format!("Test abort: {}", self.name));
+            self.run_internal(RunType::Abort);
+        } else {
+            print_heading(&format!("Test: {}", self.name));
+            self.run_internal(RunType::Simple);
+        }
+    }
+
+    fn run_internal(&mut self, run_type: RunType) {
+        print_subheading("Clearing database");
+        self.reshape.remove().unwrap();
+
+        // Apply first migration, will automatically complete
+        print_subheading("Applying first migration");
+        let first_migration = self
+            .first_migration
+            .as_ref()
+            .expect("no starting migration set");
+        self.reshape.migrate(vec![first_migration.clone()]).unwrap();
+
+        // Update search path
+        self.old_db
+            .simple_query(&reshape::schema_query_for_migration(&first_migration.name))
+            .unwrap();
+
+        // Run setup function
+        if let Some(after_first_fn) = self.after_first_fn {
+            print_subheading("Running setup and first checks");
+            after_first_fn(&mut self.old_db);
+            print_success();
+        }
+
+        // Apply second migration
+        if let Some(second_migration) = &self.second_migration {
+            print_subheading("Applying second migration");
+            self.reshape
+                .migrate(vec![first_migration.clone(), second_migration.clone()])
+                .unwrap();
+
+            // Update search path
+            self.new_db
+                .simple_query(&reshape::schema_query_for_migration(&second_migration.name))
+                .unwrap();
+
+            if let Some(intermediate_fn) = self.intermediate_fn {
+                print_subheading("Running intermediate checks");
+                intermediate_fn(&mut self.old_db, &mut self.new_db);
+                print_success();
+            }
+
+            match run_type {
+                RunType::Completion => {
+                    print_subheading("Completing");
+                    self.reshape.complete().unwrap();
+
+                    if let Some(after_completion_fn) = self.after_completion_fn {
+                        print_subheading("Running post-completion checks");
+                        after_completion_fn(&mut self.new_db);
+                        print_success();
+                    }
+                }
+                RunType::Abort => {
+                    print_subheading("Aborting");
+                    self.reshape.abort().unwrap();
+
+                    if let Some(after_abort_fn) = self.after_abort_fn {
+                        print_subheading("Running post-abort checks");
+                        after_abort_fn(&mut self.old_db);
+                        print_success();
+                    }
+                }
+                _ => {}
+            }
+        }
+
+        print_subheading("Checking cleanup");
+        assert_cleaned_up(&mut self.new_db);
+        print_success();
+    }
+}
+
+fn print_heading(text: &str) {
+    let delimiter = std::iter::repeat("=").take(50).collect::<String>();
+
+    println!();
+    println!();
+    println!("{}", delimiter.blue().bold());
+    println!("{}", add_spacer(text, "=").blue().bold());
+    println!("{}", delimiter.blue().bold());
+}
+
+fn print_subheading(text: &str) {
+    println!();
+    println!("{}", add_spacer(text, "=").blue());
+}
+
+fn print_success() {
+    println!("{}", add_spacer("Success", "=").green());
+}
+
+fn add_spacer(text: &str, char: &str) -> String {
+    const TARGET_WIDTH: usize = 50;
+    let num_of_chars = (TARGET_WIDTH - text.len() - 2) / 2;
+    let spacer = std::iter::repeat(char)
+        .take(num_of_chars)
+        .collect::<String>();
+
+    let extra = if text.len() % 2 == 0 { "" } else { char };
+
+    format!("{spacer} {text} {spacer}{extra}", spacer = spacer)
+}
 pub fn setup() -> (Reshape, Client, Client) {
     let connection_string = std::env::var("POSTGRES_CONNECTION_STRING")
diff --git a/tests/create_table.rs b/tests/create_table.rs
index 628a571..bae1356 100644
--- a/tests/create_table.rs
+++ b/tests/create_table.rs
@@ -4,113 +4,115 @@ mod common;
 
 #[test]
 fn create_table() {
-    let (mut reshape, mut db, _) = common::setup();
+    let mut test = common::Test::new("Create table");
 
-    let create_table_migration = Migration::new("create_users_table", None).with_action(
-        CreateTableBuilder::default()
-            .name("users")
-            .primary_key(vec!["id".to_string()])
-            .columns(vec![
-                ColumnBuilder::default()
-                    .name("id")
-                    .data_type("INTEGER")
-                    .generated("ALWAYS AS IDENTITY")
-                    .build()
-                    .unwrap(),
-                ColumnBuilder::default()
-                    .name("name")
-                    .data_type("TEXT")
-                    .build()
-                    .unwrap(),
-                ColumnBuilder::default()
-                    .name("created_at")
-                    .data_type("TIMESTAMP")
-                    .nullable(false)
-                    .default_value("NOW()")
-                    .build()
-                    .unwrap(),
-            ])
-            .build()
-            .unwrap(),
+    test.first_migration(
+        Migration::new("create_users_table", None).with_action(
+            CreateTableBuilder::default()
+                .name("users")
+                .primary_key(vec!["id".to_string()])
+                .columns(vec![
+                    ColumnBuilder::default()
+                        .name("id")
+                        .data_type("INTEGER")
+                        .generated("ALWAYS AS IDENTITY")
+                        .build()
+                        .unwrap(),
+                    ColumnBuilder::default()
+                        .name("name")
+                        .data_type("TEXT")
+                        .build()
+                        .unwrap(),
+                    ColumnBuilder::default()
+                        .name("created_at")
+                        .data_type("TIMESTAMP")
+                        .nullable(false)
+                        .default_value("NOW()")
+                        .build()
+                        .unwrap(),
+                ])
+                .build()
+                .unwrap(),
+        ),
     );
 
-    reshape
-        .migrate(vec![create_table_migration.clone()])
-        .unwrap();
+    test.after_first(|db| {
+        // Ensure table was created
+        let result = db
+            .query_opt(
+                "
+                SELECT table_name
+                FROM information_schema.tables
+                WHERE table_name = 'users' AND table_schema = 'public'",
+                &[],
+            )
+            .unwrap();
+        assert!(result.is_some());
 
-    // Ensure table was created
-    let result = db
-        .query_opt(
-            "
-            SELECT table_name
-            FROM information_schema.tables
-            WHERE table_name = 'users' AND table_schema = 'public'",
-            &[],
-        )
-        .unwrap();
-    assert!(result.is_some());
+        // Ensure table has the right columns
+        let result = db
+            .query(
+                "
+                SELECT column_name, column_default, is_nullable, data_type
+                FROM information_schema.columns
+                WHERE table_name = 'users' AND table_schema = 'public'
+                ORDER BY ordinal_position",
+                &[],
+            )
+            .unwrap();
 
-    // Ensure table has the right columns
-    let result = db
-        .query(
-            "
-            SELECT column_name, column_default, is_nullable, data_type
-            FROM information_schema.columns
-            WHERE table_name = 'users' AND table_schema = 'public'
-            ORDER BY ordinal_position",
-            &[],
-        )
-        .unwrap();
+        // id column
+        let id_row = &result[0];
+        assert_eq!("id", id_row.get::<_, String>("column_name"));
+        assert!(id_row.get::<_, Option<String>>("column_default").is_none());
+        assert_eq!("NO", id_row.get::<_, String>("is_nullable"));
+        assert_eq!("integer", id_row.get::<_, String>("data_type"));
 
-    // id column
-    let id_row = &result[0];
-    assert_eq!("id", id_row.get::<_, String>("column_name"));
-    assert!(id_row.get::<_, Option<String>>("column_default").is_none());
-    assert_eq!("NO", id_row.get::<_, String>("is_nullable"));
-    assert_eq!("integer", id_row.get::<_, String>("data_type"));
+        // name column
+        let name_row = &result[1];
+        assert_eq!("name", name_row.get::<_, String>("column_name"));
+        assert!(name_row
+            .get::<_, Option<String>>("column_default")
+            .is_none());
+        assert_eq!("YES", name_row.get::<_, String>("is_nullable"));
+        assert_eq!("text", name_row.get::<_, String>("data_type"));
 
-    // name column
-    let name_row = &result[1];
-    assert_eq!("name", name_row.get::<_, String>("column_name"));
-    assert!(name_row
-        .get::<_, Option<String>>("column_default")
-        .is_none());
-    assert_eq!("YES", name_row.get::<_, String>("is_nullable"));
-    assert_eq!("text", name_row.get::<_, String>("data_type"));
+        // created_at column
+        let created_at_column = &result[2];
+        assert_eq!(
+            "created_at",
+            created_at_column.get::<_, String>("column_name")
+        );
+        assert!(created_at_column
+            .get::<_, Option<String>>("column_default")
+            .is_some());
+        assert_eq!("NO", created_at_column.get::<_, String>("is_nullable"));
+        assert_eq!(
+            "timestamp without time zone",
+            created_at_column.get::<_, String>("data_type")
+        );
 
-    // created_at column
-    let created_at_column = &result[2];
-    assert_eq!(
-        "created_at",
-        created_at_column.get::<_, String>("column_name")
-    );
-    assert!(created_at_column
-        .get::<_, Option<String>>("column_default")
-        .is_some());
-    assert_eq!("NO", created_at_column.get::<_, String>("is_nullable"));
-    assert_eq!(
-        "timestamp without time zone",
-        created_at_column.get::<_, String>("data_type")
-    );
+        // Ensure the primary key has the right columns
+        let primary_key_columns: Vec<String> = db
+            .query(
+                "
+                SELECT a.attname AS column
+                FROM pg_index i
+                JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey)
+                JOIN pg_class t ON t.oid = i.indrelid
+                WHERE t.relname = 'users' AND i.indisprimary
+                ",
+                &[],
+            )
+            .unwrap()
+            .iter()
+            .map(|row| row.get("column"))
+            .collect();
 
-    // Ensure the primary key has the right columns
-    let primary_key_columns: Vec<String> = db
-        .query(
-            "
-            SELECT a.attname AS column
-            FROM pg_index i
-            JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey)
-            WHERE i.indrelid = 'users'::regclass AND i.indisprimary
-            ",
-            &[],
-        )
-        .unwrap()
-        .iter()
-        .map(|row| row.get("column"))
-        .collect();
+        assert_eq!(vec!["id"], primary_key_columns);
+    });
 
-    assert_eq!(vec!["id"], primary_key_columns);
-    common::assert_cleaned_up(&mut db);
+    test.run();
 }
 
 #[test]