Merge pull request #1632 from diesel-rs/sg-cli-config-flag-replacements

Give diesel.toml the same capabilities as every CLI flag
Sean Griffin 2018-04-12 11:23:49 -06:00 committed by GitHub
commit 518d7c1fac
17 changed files with 202 additions and 44 deletions

View File

@ -20,7 +20,7 @@ clap = "2.27"
clippy = { optional = true, version = "=0.0.185" }
diesel = { version = "~1.2.0", default-features = false }
dotenv = ">=0.8, <0.11"
infer_schema_internals = "~1.2.0"
infer_schema_internals = { version = "~1.2.0", features = ["serde"] }
migrations_internals = "~1.2.0"
serde = { version = "1.0.0", features = ["derive"] }
toml = "0.4.6"

View File

@ -7,8 +7,10 @@ use std::path::PathBuf;
use toml;
use super::{find_project_root, handle_error};
use print_schema;
#[derive(Deserialize)]
#[serde(deny_unknown_fields)]
pub struct Config {
#[serde(default)]
pub print_schema: PrintSchema,
@ -36,6 +38,20 @@ impl Config {
}
#[derive(Default, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct PrintSchema {
#[serde(default)]
pub file: Option<PathBuf>,
#[serde(default)]
pub with_docs: bool,
#[serde(default)]
pub filter: print_schema::Filtering,
#[serde(default)]
pub schema: Option<String>,
}
impl PrintSchema {
pub fn schema_name(&self) -> Option<&str> {
self.schema.as_ref().map(|s| &**s)
}
}
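
As a minimal sketch (not part of this diff), the [print_schema] table of a diesel.toml maps onto the PrintSchema struct above through the toml and serde dependencies added in Cargo.toml. Every omitted key falls back to its #[serde(default)] value, and the filter key relies on the custom Deserialize impl for Filtering shown further down; the field names come from this diff, while the function name is hypothetical.

extern crate toml;

fn load_example_config() -> Result<Config, toml::de::Error> {
    let source = r#"
        [print_schema]
        file = "src/schema.rs"
        with_docs = true
        schema = "custom_schema"
        filter = { whitelist = ["users1"] }
    "#;
    // `Config` is the struct defined in this file; `deny_unknown_fields`
    // turns a typo such as `with_doc` into a hard error instead of a silent no-op.
    let config: Config = toml::from_str(source)?;
    assert_eq!(config.print_schema.schema_name(), Some("custom_schema"));
    assert!(config.print_schema.with_docs);
    Ok(config)
}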

View File

@ -56,7 +56,7 @@ fn main() {
("setup", Some(matches)) => run_setup_command(matches),
("database", Some(matches)) => run_database_command(matches),
("bash-completion", Some(matches)) => generate_bash_completion_command(matches),
("print-schema", Some(matches)) => run_infer_schema(matches),
("print-schema", Some(matches)) => run_infer_schema(matches).unwrap_or_else(handle_error),
_ => unreachable!("The cli parser should prevent reaching here"),
}
}
@ -313,18 +313,22 @@ fn convert_absolute_path_to_relative(target_path: &Path, mut current_path: &Path
result.join(target_path.strip_prefix(current_path).unwrap())
}
fn run_infer_schema(matches: &ArgMatches) {
fn run_infer_schema(matches: &ArgMatches) -> Result<(), Box<Error>> {
use infer_schema_internals::TableName;
use print_schema::*;
let database_url = database::database_url(matches);
let schema_name = matches.value_of("schema");
let mut config = Config::read(matches)?.print_schema;
if let Some(schema_name) = matches.value_of("schema") {
config.schema = Some(String::from(schema_name))
}
let filter = matches
.values_of("table-name")
.unwrap_or_default()
.map(|table_name| {
if let Some(schema) = schema_name {
if let Some(schema) = config.schema_name() {
TableName::new(table_name, schema)
} else {
table_name.parse().unwrap()
@ -332,34 +336,30 @@ fn run_infer_schema(matches: &ArgMatches) {
})
.collect();
let filter = if matches.is_present("whitelist") {
Filtering::Whitelist(filter)
if matches.is_present("whitelist") {
config.filter = Filtering::Whitelist(filter)
} else if matches.is_present("blacklist") {
Filtering::Blacklist(filter)
} else {
Filtering::None
};
config.filter = Filtering::Blacklist(filter)
}
let _ = run_print_schema(
&database_url,
schema_name,
&filter,
matches.is_present("with-docs"),
).map_err(handle_error::<_, ()>);
if matches.is_present("with-docs") {
config.with_docs = true;
}
run_print_schema(&database_url, &config)?;
Ok(())
}
fn regenerate_schema_if_file_specified(matches: &ArgMatches) -> Result<(), Box<Error>> {
use print_schema::*;
let config = Config::read(matches)?;
if let Some(path) = config.print_schema.file {
if let Some(ref path) = config.print_schema.file {
if let Some(parent) = path.parent() {
fs::create_dir_all(parent)?;
}
let database_url = database::database_url(matches);
let mut file = fs::File::create(path)?;
print_schema::output_schema(&database_url, None, &Filtering::None, false, &mut file)?;
print_schema::output_schema(&database_url, &config.print_schema, &mut file)?;
}
Ok(())
}
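
The pattern above is worth calling out: run_infer_schema now starts from the values read out of diesel.toml and only overwrites a field when the corresponding flag was actually passed, so CLI flags win over the config file. A minimal sketch of that precedence rule, with a hypothetical helper name:

fn effective_schema(cli_schema: Option<&str>, config_schema: Option<String>) -> Option<String> {
    // An explicitly passed --schema flag takes precedence; otherwise the
    // value from [print_schema] (if any) is used.
    cli_schema.map(String::from).or(config_schema)
}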

View File

@ -1,7 +1,11 @@
use config;
use infer_schema_internals::*;
use std::error::Error;
use std::fmt::{self, Display, Formatter, Write};
use std::io::{self, stdout};
use serde::de::{self, MapAccess, Visitor};
use serde::{Deserialize, Deserializer};
pub enum Filtering {
Whitelist(Vec<TableName>),
@ -9,6 +13,12 @@ pub enum Filtering {
None,
}
impl Default for Filtering {
fn default() -> Self {
Filtering::None
}
}
impl Filtering {
pub fn should_ignore_table(&self, name: &TableName) -> bool {
use self::Filtering::*;
@ -23,31 +33,21 @@ impl Filtering {
pub fn run_print_schema(
database_url: &str,
schema_name: Option<&str>,
filtering: &Filtering,
include_docs: bool,
config: &config::PrintSchema,
) -> Result<(), Box<Error>> {
output_schema(
database_url,
schema_name,
filtering,
include_docs,
&mut stdout(),
)
output_schema(database_url, config, &mut stdout())
}
pub fn output_schema<W: io::Write>(
database_url: &str,
schema_name: Option<&str>,
filtering: &Filtering,
include_docs: bool,
config: &config::PrintSchema,
out: &mut W,
) -> Result<(), Box<Error>> {
let table_names = load_table_names(database_url, schema_name)?
let table_names = load_table_names(database_url, config.schema_name())?
.into_iter()
.filter(|t| !filtering.should_ignore_table(t))
.filter(|t| !config.filter.should_ignore_table(t))
.collect::<Vec<_>>();
let foreign_keys = load_foreign_key_constraints(database_url, schema_name)?;
let foreign_keys = load_foreign_key_constraints(database_url, config.schema_name())?;
let foreign_keys =
remove_unsafe_foreign_keys_for_codegen(database_url, &foreign_keys, &table_names);
let table_data = table_names
@ -57,10 +57,10 @@ pub fn output_schema<W: io::Write>(
let definitions = TableDefinitions {
tables: table_data,
fk_constraints: foreign_keys,
include_docs,
include_docs: config.with_docs,
};
if let Some(schema_name) = schema_name {
if let Some(schema_name) = config.schema_name() {
write!(out, "{}", ModuleDefinition(schema_name, definitions))?;
} else {
write!(out, "{}", definitions)?;
@ -255,3 +255,53 @@ impl<'a, 'b: 'a> Write for PadAdapter<'a, 'b> {
Ok(())
}
}
impl<'de> Deserialize<'de> for Filtering {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
struct FilteringVisitor;
impl<'de> Visitor<'de> for FilteringVisitor {
type Value = Filtering;
fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.write_str("either a whitelist or a blacklist")
}
fn visit_map<V>(self, mut map: V) -> Result<Self::Value, V::Error>
where
V: MapAccess<'de>,
{
let mut whitelist = None;
let mut blacklist = None;
while let Some((key, value)) = map.next_entry()? {
match key {
"whitelist" => {
if whitelist.is_some() {
return Err(de::Error::duplicate_field("whitelist"));
}
whitelist = Some(value);
}
"blacklist" => {
if blacklist.is_some() {
return Err(de::Error::duplicate_field("blacklist"));
}
blacklist = Some(value);
}
_ => return Err(de::Error::unknown_field(key, &["whitelist", "blacklist"])),
}
}
match (whitelist, blacklist) {
(Some(_), Some(_)) => Err(de::Error::duplicate_field("blacklist")),
(Some(w), None) => Ok(Filtering::Whitelist(w)),
(None, Some(b)) => Ok(Filtering::Blacklist(b)),
(None, None) => Ok(Filtering::None),
}
}
}
deserializer.deserialize_map(FilteringVisitor)
}
}
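
A minimal sketch (not part of this diff) of what this visitor accepts and rejects, assuming the toml dependency already used by diesel_cli and infer_schema_internals built with its new "serde" feature so that TableName deserializes; the function name is hypothetical.

fn filtering_examples() {
    // A single key selects the corresponding variant.
    let _whitelist: Filtering = toml::from_str(r#"whitelist = ["users1"]"#).unwrap();
    let _blacklist: Filtering = toml::from_str(r#"blacklist = ["users1", "posts"]"#).unwrap();
    // Supplying both keys at once hits the (Some(_), Some(_)) arm and fails.
    let both = toml::from_str::<Filtering>("whitelist = [\"users1\"]\nblacklist = [\"posts\"]");
    assert!(both.is_err());
}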

View File

@ -80,26 +80,48 @@ fn test_print_schema(test_name: &str, args: Vec<&str>) {
let test_path = Path::new(env!("CARGO_MANIFEST_DIR"))
.join("tests")
.join("print_schema")
.join(test_name)
.join(BACKEND);
.join(test_name);
let backend_path = test_path.join(BACKEND);
let p = project(test_name).build();
let db = database(&p.database_url());
p.command("setup").run();
let schema = read_file(&test_path.join("schema.sql"));
let schema = read_file(&backend_path.join("schema.sql"));
db.execute(&schema);
let result = p.command("print-schema").args(args).run();
assert!(result.is_success(), "Result was unsuccessful {:?}", result);
let expected = read_file(&test_path.join("expected.rs"));
let expected = read_file(&backend_path.join("expected.rs"));
assert_diff!(&expected, result.stdout(), "\n", 0);
test_print_schema_config(test_name, &test_path, schema, expected);
}
fn test_print_schema_config(test_name: &str, test_path: &Path, schema: String, expected: String) {
let config = read_file(&test_path.join("diesel.toml"));
let p = project(&format!("{}_config", test_name))
.file("diesel.toml", &config)
.build();
p.command("setup").run();
p.create_migration("12345_create_schema", &schema, "");
let result = p.command("migration").arg("run").run();
assert!(result.is_success(), "Result was unsuccessful {:?}", result);
let schema = p.file_contents("src/schema.rs");
assert_diff!(&expected, &schema, "\n", 0);
let result = p.command("print-schema").run();
assert!(result.is_success(), "Result was unsuccessful {:?}", result);
assert_diff!(&expected, result.stdout(), "\n", 0);
}
fn read_file(path: &Path) -> String {
println!("{}", path.display());
let mut file = File::open(path).expect(&format!("Could not open {}", path.display()));
let mut string = String::new();
file.read_to_string(&mut string).unwrap();

View File

@ -0,0 +1,4 @@
[print_schema]
file = "src/schema.rs"
with_docs = true
filter = { blacklist = ["users1"] }

View File

@ -0,0 +1,3 @@
[print_schema]
file = "src/schema.rs"
with_docs = true

View File

@ -0,0 +1,3 @@
[print_schema]
file = "src/schema.rs"
with_docs = true

View File

@ -0,0 +1,3 @@
[print_schema]
file = "src/schema.rs"
with_docs = true

View File

@ -0,0 +1,3 @@
[print_schema]
file = "src/schema.rs"
with_docs = true

View File

@ -0,0 +1,2 @@
[print_schema]
file = "src/schema.rs"

View File

@ -0,0 +1,4 @@
[print_schema]
file = "src/schema.rs"
with_docs = true
schema = "custom_schema"

View File

@ -0,0 +1,4 @@
[print_schema]
file = "src/schema.rs"
with_docs = true
schema = "custom_schema"

View File

@ -0,0 +1,4 @@
[print_schema]
file = "src/schema.rs"
with_docs = true
filter = { whitelist = ["users1"] }

View File

@ -0,0 +1,3 @@
[print_schema]
file = "src/schema.rs"
with_docs = true

View File

@ -12,6 +12,7 @@ keywords = ["orm", "database", "postgres", "postgresql", "sql"]
[dependencies]
diesel = { version = "~1.2.0", default-features = false }
clippy = { optional = true, version = "=0.0.185" }
serde = { version = "1.0.0", optional = true }
[dev-dependencies]
dotenv = ">=0.8, <0.11"

View File

@ -82,3 +82,39 @@ pub struct TableData {
pub column_data: Vec<ColumnDefinition>,
pub docs: String,
}
#[cfg(feature = "serde")]
mod serde_impls {
extern crate serde;
use self::serde::de::Visitor;
use self::serde::{de, Deserialize, Deserializer};
use std::fmt;
use super::TableName;
impl<'de> Deserialize<'de> for TableName {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
struct TableNameVisitor;
impl<'de> Visitor<'de> for TableNameVisitor {
type Value = TableName;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("A valid table name")
}
fn visit_str<E>(self, value: &str) -> Result<TableName, E>
where
E: de::Error,
{
value.parse().map_err(|_| unreachable!())
}
}
deserializer.deserialize_string(TableNameVisitor)
}
}
}
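
A minimal sketch (not part of this diff) of this impl in use: with the "serde" feature enabled, a list of table-name strings deserializes through the visitor above, which defers to TableName's existing FromStr (treated as infallible here, hence the map_err to unreachable!). The wrapper struct, the derive usage, and pulling in the toml crate for the example are illustrative assumptions.

#[derive(Deserialize)]
struct TableList {
    // Hypothetical wrapper, since a bare TOML document cannot be an array.
    tables: Vec<TableName>,
}

fn table_name_example() {
    let parsed: TableList = toml::from_str(r#"tables = ["users1", "comments"]"#).unwrap();
    assert_eq!(parsed.tables.len(), 2);
}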