Give diesel.toml the same capabilities as every config flag

This gives every existing flag that can be given to `diesel
print-schema` an equivalent entry in `diesel.toml`, with the exception
of `database-url` which should never be in this file.

There are more manual serde impls than I would have liked here.
Particularly, the impl for `Filtering` should have just been derived,
but the derived impl just blows up with "expected string for key
`fields`" -- I have no clue why, and it's really opaque what's going on.
This commit is contained in:
Sean Griffin 2018-04-11 11:24:54 -06:00
parent ae22270600
commit e772508c2f
17 changed files with 202 additions and 44 deletions

View File

@ -20,7 +20,7 @@ clap = "2.27"
clippy = { optional = true, version = "=0.0.185" } clippy = { optional = true, version = "=0.0.185" }
diesel = { version = "~1.2.0", default-features = false } diesel = { version = "~1.2.0", default-features = false }
dotenv = ">=0.8, <0.11" dotenv = ">=0.8, <0.11"
infer_schema_internals = "~1.2.0" infer_schema_internals = { version = "~1.2.0", features = ["serde"] }
migrations_internals = "~1.2.0" migrations_internals = "~1.2.0"
serde = { version = "1.0.0", features = ["derive"] } serde = { version = "1.0.0", features = ["derive"] }
toml = "0.4.6" toml = "0.4.6"

View File

@ -7,8 +7,10 @@ use std::path::PathBuf;
use toml; use toml;
use super::{find_project_root, handle_error}; use super::{find_project_root, handle_error};
use print_schema;
#[derive(Deserialize)] #[derive(Deserialize)]
#[serde(deny_unknown_fields)]
pub struct Config { pub struct Config {
#[serde(default)] #[serde(default)]
pub print_schema: PrintSchema, pub print_schema: PrintSchema,
@ -36,6 +38,20 @@ impl Config {
} }
#[derive(Default, Deserialize)] #[derive(Default, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct PrintSchema { pub struct PrintSchema {
#[serde(default)]
pub file: Option<PathBuf>, pub file: Option<PathBuf>,
#[serde(default)]
pub with_docs: bool,
#[serde(default)]
pub filter: print_schema::Filtering,
#[serde(default)]
pub schema: Option<String>,
}
impl PrintSchema {
pub fn schema_name(&self) -> Option<&str> {
self.schema.as_ref().map(|s| &**s)
}
} }

View File

@ -56,7 +56,7 @@ fn main() {
("setup", Some(matches)) => run_setup_command(matches), ("setup", Some(matches)) => run_setup_command(matches),
("database", Some(matches)) => run_database_command(matches), ("database", Some(matches)) => run_database_command(matches),
("bash-completion", Some(matches)) => generate_bash_completion_command(matches), ("bash-completion", Some(matches)) => generate_bash_completion_command(matches),
("print-schema", Some(matches)) => run_infer_schema(matches), ("print-schema", Some(matches)) => run_infer_schema(matches).unwrap_or_else(handle_error),
_ => unreachable!("The cli parser should prevent reaching here"), _ => unreachable!("The cli parser should prevent reaching here"),
} }
} }
@ -313,18 +313,22 @@ fn convert_absolute_path_to_relative(target_path: &Path, mut current_path: &Path
result.join(target_path.strip_prefix(current_path).unwrap()) result.join(target_path.strip_prefix(current_path).unwrap())
} }
fn run_infer_schema(matches: &ArgMatches) { fn run_infer_schema(matches: &ArgMatches) -> Result<(), Box<Error>> {
use infer_schema_internals::TableName; use infer_schema_internals::TableName;
use print_schema::*; use print_schema::*;
let database_url = database::database_url(matches); let database_url = database::database_url(matches);
let schema_name = matches.value_of("schema"); let mut config = Config::read(matches)?.print_schema;
if let Some(schema_name) = matches.value_of("schema") {
config.schema = Some(String::from(schema_name))
}
let filter = matches let filter = matches
.values_of("table-name") .values_of("table-name")
.unwrap_or_default() .unwrap_or_default()
.map(|table_name| { .map(|table_name| {
if let Some(schema) = schema_name { if let Some(schema) = config.schema_name() {
TableName::new(table_name, schema) TableName::new(table_name, schema)
} else { } else {
table_name.parse().unwrap() table_name.parse().unwrap()
@ -332,34 +336,30 @@ fn run_infer_schema(matches: &ArgMatches) {
}) })
.collect(); .collect();
let filter = if matches.is_present("whitelist") { if matches.is_present("whitelist") {
Filtering::Whitelist(filter) config.filter = Filtering::Whitelist(filter)
} else if matches.is_present("blacklist") { } else if matches.is_present("blacklist") {
Filtering::Blacklist(filter) config.filter = Filtering::Blacklist(filter)
} else { }
Filtering::None
};
let _ = run_print_schema( if matches.is_present("with-docs") {
&database_url, config.with_docs = true;
schema_name, }
&filter,
matches.is_present("with-docs"), run_print_schema(&database_url, &config)?;
).map_err(handle_error::<_, ()>); Ok(())
} }
fn regenerate_schema_if_file_specified(matches: &ArgMatches) -> Result<(), Box<Error>> { fn regenerate_schema_if_file_specified(matches: &ArgMatches) -> Result<(), Box<Error>> {
use print_schema::*;
let config = Config::read(matches)?; let config = Config::read(matches)?;
if let Some(path) = config.print_schema.file { if let Some(ref path) = config.print_schema.file {
if let Some(parent) = path.parent() { if let Some(parent) = path.parent() {
fs::create_dir_all(parent)?; fs::create_dir_all(parent)?;
} }
let database_url = database::database_url(matches); let database_url = database::database_url(matches);
let mut file = fs::File::create(path)?; let mut file = fs::File::create(path)?;
print_schema::output_schema(&database_url, None, &Filtering::None, false, &mut file)?; print_schema::output_schema(&database_url, &config.print_schema, &mut file)?;
} }
Ok(()) Ok(())
} }

View File

@ -1,7 +1,11 @@
use config;
use infer_schema_internals::*; use infer_schema_internals::*;
use std::error::Error; use std::error::Error;
use std::fmt::{self, Display, Formatter, Write}; use std::fmt::{self, Display, Formatter, Write};
use std::io::{self, stdout}; use std::io::{self, stdout};
use serde::de::{self, MapAccess, Visitor};
use serde::{Deserialize, Deserializer};
pub enum Filtering { pub enum Filtering {
Whitelist(Vec<TableName>), Whitelist(Vec<TableName>),
@ -9,6 +13,12 @@ pub enum Filtering {
None, None,
} }
impl Default for Filtering {
fn default() -> Self {
Filtering::None
}
}
impl Filtering { impl Filtering {
pub fn should_ignore_table(&self, name: &TableName) -> bool { pub fn should_ignore_table(&self, name: &TableName) -> bool {
use self::Filtering::*; use self::Filtering::*;
@ -23,31 +33,21 @@ impl Filtering {
pub fn run_print_schema( pub fn run_print_schema(
database_url: &str, database_url: &str,
schema_name: Option<&str>, config: &config::PrintSchema,
filtering: &Filtering,
include_docs: bool,
) -> Result<(), Box<Error>> { ) -> Result<(), Box<Error>> {
output_schema( output_schema(database_url, config, &mut stdout())
database_url,
schema_name,
filtering,
include_docs,
&mut stdout(),
)
} }
pub fn output_schema<W: io::Write>( pub fn output_schema<W: io::Write>(
database_url: &str, database_url: &str,
schema_name: Option<&str>, config: &config::PrintSchema,
filtering: &Filtering,
include_docs: bool,
out: &mut W, out: &mut W,
) -> Result<(), Box<Error>> { ) -> Result<(), Box<Error>> {
let table_names = load_table_names(database_url, schema_name)? let table_names = load_table_names(database_url, config.schema_name())?
.into_iter() .into_iter()
.filter(|t| !filtering.should_ignore_table(t)) .filter(|t| !config.filter.should_ignore_table(t))
.collect::<Vec<_>>(); .collect::<Vec<_>>();
let foreign_keys = load_foreign_key_constraints(database_url, schema_name)?; let foreign_keys = load_foreign_key_constraints(database_url, config.schema_name())?;
let foreign_keys = let foreign_keys =
remove_unsafe_foreign_keys_for_codegen(database_url, &foreign_keys, &table_names); remove_unsafe_foreign_keys_for_codegen(database_url, &foreign_keys, &table_names);
let table_data = table_names let table_data = table_names
@ -57,10 +57,10 @@ pub fn output_schema<W: io::Write>(
let definitions = TableDefinitions { let definitions = TableDefinitions {
tables: table_data, tables: table_data,
fk_constraints: foreign_keys, fk_constraints: foreign_keys,
include_docs, include_docs: config.with_docs,
}; };
if let Some(schema_name) = schema_name { if let Some(schema_name) = config.schema_name() {
write!(out, "{}", ModuleDefinition(schema_name, definitions))?; write!(out, "{}", ModuleDefinition(schema_name, definitions))?;
} else { } else {
write!(out, "{}", definitions)?; write!(out, "{}", definitions)?;
@ -255,3 +255,53 @@ impl<'a, 'b: 'a> Write for PadAdapter<'a, 'b> {
Ok(()) Ok(())
} }
} }
impl<'de> Deserialize<'de> for Filtering {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
struct FilteringVisitor;
impl<'de> Visitor<'de> for FilteringVisitor {
type Value = Filtering;
fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.write_str("either a whitelist or a blacklist")
}
fn visit_map<V>(self, mut map: V) -> Result<Self::Value, V::Error>
where
V: MapAccess<'de>,
{
let mut whitelist = None;
let mut blacklist = None;
while let Some((key, value)) = map.next_entry()? {
match key {
"whitelist" => {
if whitelist.is_some() {
return Err(de::Error::duplicate_field("whitelist"));
}
whitelist = Some(value);
}
"blacklist" => {
if blacklist.is_some() {
return Err(de::Error::duplicate_field("blacklist"));
}
blacklist = Some(value);
}
_ => return Err(de::Error::unknown_field(key, &["whitelist", "blacklist"])),
}
}
match (whitelist, blacklist) {
(Some(_), Some(_)) => Err(de::Error::duplicate_field("blacklist")),
(Some(w), None) => Ok(Filtering::Whitelist(w)),
(None, Some(b)) => Ok(Filtering::Blacklist(b)),
(None, None) => Ok(Filtering::None),
}
}
}
deserializer.deserialize_map(FilteringVisitor)
}
}

View File

@ -80,26 +80,48 @@ fn test_print_schema(test_name: &str, args: Vec<&str>) {
let test_path = Path::new(env!("CARGO_MANIFEST_DIR")) let test_path = Path::new(env!("CARGO_MANIFEST_DIR"))
.join("tests") .join("tests")
.join("print_schema") .join("print_schema")
.join(test_name) .join(test_name);
.join(BACKEND); let backend_path = test_path.join(BACKEND);
let p = project(test_name).build(); let p = project(test_name).build();
let db = database(&p.database_url()); let db = database(&p.database_url());
p.command("setup").run(); p.command("setup").run();
let schema = read_file(&test_path.join("schema.sql")); let schema = read_file(&backend_path.join("schema.sql"));
db.execute(&schema); db.execute(&schema);
let result = p.command("print-schema").args(args).run(); let result = p.command("print-schema").args(args).run();
assert!(result.is_success(), "Result was unsuccessful {:?}", result); assert!(result.is_success(), "Result was unsuccessful {:?}", result);
let expected = read_file(&test_path.join("expected.rs")); let expected = read_file(&backend_path.join("expected.rs"));
assert_diff!(&expected, result.stdout(), "\n", 0);
test_print_schema_config(test_name, &test_path, schema, expected);
}
fn test_print_schema_config(test_name: &str, test_path: &Path, schema: String, expected: String) {
let config = read_file(&test_path.join("diesel.toml"));
let p = project(&format!("{}_config", test_name))
.file("diesel.toml", &config)
.build();
p.command("setup").run();
p.create_migration("12345_create_schema", &schema, "");
let result = p.command("migration").arg("run").run();
assert!(result.is_success(), "Result was unsuccessful {:?}", result);
let schema = p.file_contents("src/schema.rs");
assert_diff!(&expected, &schema, "\n", 0);
let result = p.command("print-schema").run();
assert!(result.is_success(), "Result was unsuccessful {:?}", result);
assert_diff!(&expected, result.stdout(), "\n", 0); assert_diff!(&expected, result.stdout(), "\n", 0);
} }
fn read_file(path: &Path) -> String { fn read_file(path: &Path) -> String {
println!("{}", path.display());
let mut file = File::open(path).expect(&format!("Could not open {}", path.display())); let mut file = File::open(path).expect(&format!("Could not open {}", path.display()));
let mut string = String::new(); let mut string = String::new();
file.read_to_string(&mut string).unwrap(); file.read_to_string(&mut string).unwrap();

View File

@ -0,0 +1,4 @@
[print_schema]
file = "src/schema.rs"
with_docs = true
filter = { blacklist = ["users1"] }

View File

@ -0,0 +1,3 @@
[print_schema]
file = "src/schema.rs"
with_docs = true

View File

@ -0,0 +1,3 @@
[print_schema]
file = "src/schema.rs"
with_docs = true

View File

@ -0,0 +1,3 @@
[print_schema]
file = "src/schema.rs"
with_docs = true

View File

@ -0,0 +1,3 @@
[print_schema]
file = "src/schema.rs"
with_docs = true

View File

@ -0,0 +1,2 @@
[print_schema]
file = "src/schema.rs"

View File

@ -0,0 +1,4 @@
[print_schema]
file = "src/schema.rs"
with_docs = true
schema = "custom_schema"

View File

@ -0,0 +1,4 @@
[print_schema]
file = "src/schema.rs"
with_docs = true
schema = "custom_schema"

View File

@ -0,0 +1,4 @@
[print_schema]
file = "src/schema.rs"
with_docs = true
filter = { whitelist = ["users1"] }

View File

@ -0,0 +1,3 @@
[print_schema]
file = "src/schema.rs"
with_docs = true

View File

@ -12,6 +12,7 @@ keywords = ["orm", "database", "postgres", "postgresql", "sql"]
[dependencies] [dependencies]
diesel = { version = "~1.2.0", default-features = false } diesel = { version = "~1.2.0", default-features = false }
clippy = { optional = true, version = "=0.0.185" } clippy = { optional = true, version = "=0.0.185" }
serde = { version = "1.0.0", optional = true }
[dev-dependencies] [dev-dependencies]
dotenv = ">=0.8, <0.11" dotenv = ">=0.8, <0.11"

View File

@ -82,3 +82,39 @@ pub struct TableData {
pub column_data: Vec<ColumnDefinition>, pub column_data: Vec<ColumnDefinition>,
pub docs: String, pub docs: String,
} }
#[cfg(feature = "serde")]
mod serde_impls {
extern crate serde;
use self::serde::de::Visitor;
use self::serde::{de, Deserialize, Deserializer};
use std::fmt;
use super::TableName;
impl<'de> Deserialize<'de> for TableName {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
struct TableNameVisitor;
impl<'de> Visitor<'de> for TableNameVisitor {
type Value = TableName;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("A valid table name")
}
fn visit_str<E>(self, value: &str) -> Result<TableName, E>
where
E: de::Error,
{
value.parse().map_err(|_| unreachable!())
}
}
deserializer.deserialize_string(TableNameVisitor)
}
}
}