Use real table and column names for foreign keys in create_table

This commit is contained in:
fabianlindfors 2022-01-21 17:07:46 +01:00
parent 4e0e112eb2
commit b6f1a0c6a4
2 changed files with 28 additions and 14 deletions

View File

@ -1,4 +1,4 @@
use super::{Action, Column, MigrationContext};
use super::{common::ForeignKey, Action, Column, MigrationContext};
use crate::{
db::{Conn, Transaction},
schema::Schema,
@ -16,13 +16,6 @@ pub struct CreateTable {
pub foreign_keys: Vec<ForeignKey>,
}
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct ForeignKey {
pub columns: Vec<String>,
pub referenced_table: String,
pub referenced_columns: Vec<String>,
}
#[typetag::serde(name = "create_table")]
impl Action for CreateTable {
fn describe(&self) -> String {
@ -33,7 +26,7 @@ impl Action for CreateTable {
&self,
_ctx: &MigrationContext,
db: &mut dyn Conn,
_schema: &Schema,
schema: &Schema,
) -> anyhow::Result<()> {
let mut definition_rows: Vec<String> = self
.columns
@ -75,9 +68,10 @@ impl Action for CreateTable {
.iter()
.map(|col| format!("\"{}\"", col))
.collect();
let referenced_columns: Vec<String> = foreign_key
.referenced_columns
.iter()
let referenced_table = schema.get_table(db, &foreign_key.referenced_table)?;
let referenced_columns: Vec<String> = referenced_table
.real_column_names(&foreign_key.referenced_columns)
.map(|col| format!("\"{}\"", col))
.collect();
@ -86,7 +80,7 @@ impl Action for CreateTable {
FOREIGN KEY ({columns}) REFERENCES "{table}" ({referenced_columns})
"#,
columns = columns.join(", "),
table = foreign_key.referenced_table,
table = referenced_table.real_name,
referenced_columns = referenced_columns.join(", "),
));
}

View File

@ -1,5 +1,8 @@
use crate::db::Conn;
use std::collections::{HashMap, HashSet};
use std::{
collections::{HashMap, HashSet},
path::Iter,
};
// Schema tracks changes made to tables and columns during a migration.
// These changes are not applied until the migration is completed but
@ -286,3 +289,20 @@ impl Schema {
Ok(table)
}
}
impl Table {
pub fn real_column_names<'a>(
&'a self,
columns: &'a [String],
) -> impl Iterator<Item = &'a String> {
columns.iter().map(|name| {
self.get_column(name)
.map(|col| &col.real_name)
.unwrap_or(name)
})
}
fn get_column(&self, name: &str) -> Option<&Column> {
self.columns.iter().find(|column| column.name == name)
}
}