mirror of
https://github.com/diesel-rs/diesel.git
synced 2024-10-04 09:39:24 +03:00
Give diesel.toml the same capabilities as every config flag
This gives every existing flag that can be given to `diesel print-schema` an equivalent entry in `diesel.toml`, with the exception of `database-url` which should never be in this file. There are more manual serde impls than I would have liked here. Particularly, the impl for `Filtering` should have just been derived, but the derived impl just blows up with "expected string for key `fields`" -- I have no clue why, and it's really opaque what's going on.
This commit is contained in:
parent
ae22270600
commit
e772508c2f
@ -20,7 +20,7 @@ clap = "2.27"
|
||||
clippy = { optional = true, version = "=0.0.185" }
|
||||
diesel = { version = "~1.2.0", default-features = false }
|
||||
dotenv = ">=0.8, <0.11"
|
||||
infer_schema_internals = "~1.2.0"
|
||||
infer_schema_internals = { version = "~1.2.0", features = ["serde"] }
|
||||
migrations_internals = "~1.2.0"
|
||||
serde = { version = "1.0.0", features = ["derive"] }
|
||||
toml = "0.4.6"
|
||||
|
@ -7,8 +7,10 @@ use std::path::PathBuf;
|
||||
use toml;
|
||||
|
||||
use super::{find_project_root, handle_error};
|
||||
use print_schema;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
#[serde(deny_unknown_fields)]
|
||||
pub struct Config {
|
||||
#[serde(default)]
|
||||
pub print_schema: PrintSchema,
|
||||
@ -36,6 +38,20 @@ impl Config {
|
||||
}
|
||||
|
||||
#[derive(Default, Deserialize)]
|
||||
#[serde(deny_unknown_fields)]
|
||||
pub struct PrintSchema {
|
||||
#[serde(default)]
|
||||
pub file: Option<PathBuf>,
|
||||
#[serde(default)]
|
||||
pub with_docs: bool,
|
||||
#[serde(default)]
|
||||
pub filter: print_schema::Filtering,
|
||||
#[serde(default)]
|
||||
pub schema: Option<String>,
|
||||
}
|
||||
|
||||
impl PrintSchema {
|
||||
pub fn schema_name(&self) -> Option<&str> {
|
||||
self.schema.as_ref().map(|s| &**s)
|
||||
}
|
||||
}
|
||||
|
@ -56,7 +56,7 @@ fn main() {
|
||||
("setup", Some(matches)) => run_setup_command(matches),
|
||||
("database", Some(matches)) => run_database_command(matches),
|
||||
("bash-completion", Some(matches)) => generate_bash_completion_command(matches),
|
||||
("print-schema", Some(matches)) => run_infer_schema(matches),
|
||||
("print-schema", Some(matches)) => run_infer_schema(matches).unwrap_or_else(handle_error),
|
||||
_ => unreachable!("The cli parser should prevent reaching here"),
|
||||
}
|
||||
}
|
||||
@ -313,18 +313,22 @@ fn convert_absolute_path_to_relative(target_path: &Path, mut current_path: &Path
|
||||
result.join(target_path.strip_prefix(current_path).unwrap())
|
||||
}
|
||||
|
||||
fn run_infer_schema(matches: &ArgMatches) {
|
||||
fn run_infer_schema(matches: &ArgMatches) -> Result<(), Box<Error>> {
|
||||
use infer_schema_internals::TableName;
|
||||
use print_schema::*;
|
||||
|
||||
let database_url = database::database_url(matches);
|
||||
let schema_name = matches.value_of("schema");
|
||||
let mut config = Config::read(matches)?.print_schema;
|
||||
|
||||
if let Some(schema_name) = matches.value_of("schema") {
|
||||
config.schema = Some(String::from(schema_name))
|
||||
}
|
||||
|
||||
let filter = matches
|
||||
.values_of("table-name")
|
||||
.unwrap_or_default()
|
||||
.map(|table_name| {
|
||||
if let Some(schema) = schema_name {
|
||||
if let Some(schema) = config.schema_name() {
|
||||
TableName::new(table_name, schema)
|
||||
} else {
|
||||
table_name.parse().unwrap()
|
||||
@ -332,34 +336,30 @@ fn run_infer_schema(matches: &ArgMatches) {
|
||||
})
|
||||
.collect();
|
||||
|
||||
let filter = if matches.is_present("whitelist") {
|
||||
Filtering::Whitelist(filter)
|
||||
if matches.is_present("whitelist") {
|
||||
config.filter = Filtering::Whitelist(filter)
|
||||
} else if matches.is_present("blacklist") {
|
||||
Filtering::Blacklist(filter)
|
||||
} else {
|
||||
Filtering::None
|
||||
};
|
||||
config.filter = Filtering::Blacklist(filter)
|
||||
}
|
||||
|
||||
let _ = run_print_schema(
|
||||
&database_url,
|
||||
schema_name,
|
||||
&filter,
|
||||
matches.is_present("with-docs"),
|
||||
).map_err(handle_error::<_, ()>);
|
||||
if matches.is_present("with-docs") {
|
||||
config.with_docs = true;
|
||||
}
|
||||
|
||||
run_print_schema(&database_url, &config)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn regenerate_schema_if_file_specified(matches: &ArgMatches) -> Result<(), Box<Error>> {
|
||||
use print_schema::*;
|
||||
|
||||
let config = Config::read(matches)?;
|
||||
if let Some(path) = config.print_schema.file {
|
||||
if let Some(ref path) = config.print_schema.file {
|
||||
if let Some(parent) = path.parent() {
|
||||
fs::create_dir_all(parent)?;
|
||||
}
|
||||
|
||||
let database_url = database::database_url(matches);
|
||||
let mut file = fs::File::create(path)?;
|
||||
print_schema::output_schema(&database_url, None, &Filtering::None, false, &mut file)?;
|
||||
print_schema::output_schema(&database_url, &config.print_schema, &mut file)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
@ -1,7 +1,11 @@
|
||||
use config;
|
||||
|
||||
use infer_schema_internals::*;
|
||||
use std::error::Error;
|
||||
use std::fmt::{self, Display, Formatter, Write};
|
||||
use std::io::{self, stdout};
|
||||
use serde::de::{self, MapAccess, Visitor};
|
||||
use serde::{Deserialize, Deserializer};
|
||||
|
||||
pub enum Filtering {
|
||||
Whitelist(Vec<TableName>),
|
||||
@ -9,6 +13,12 @@ pub enum Filtering {
|
||||
None,
|
||||
}
|
||||
|
||||
impl Default for Filtering {
|
||||
fn default() -> Self {
|
||||
Filtering::None
|
||||
}
|
||||
}
|
||||
|
||||
impl Filtering {
|
||||
pub fn should_ignore_table(&self, name: &TableName) -> bool {
|
||||
use self::Filtering::*;
|
||||
@ -23,31 +33,21 @@ impl Filtering {
|
||||
|
||||
pub fn run_print_schema(
|
||||
database_url: &str,
|
||||
schema_name: Option<&str>,
|
||||
filtering: &Filtering,
|
||||
include_docs: bool,
|
||||
config: &config::PrintSchema,
|
||||
) -> Result<(), Box<Error>> {
|
||||
output_schema(
|
||||
database_url,
|
||||
schema_name,
|
||||
filtering,
|
||||
include_docs,
|
||||
&mut stdout(),
|
||||
)
|
||||
output_schema(database_url, config, &mut stdout())
|
||||
}
|
||||
|
||||
pub fn output_schema<W: io::Write>(
|
||||
database_url: &str,
|
||||
schema_name: Option<&str>,
|
||||
filtering: &Filtering,
|
||||
include_docs: bool,
|
||||
config: &config::PrintSchema,
|
||||
out: &mut W,
|
||||
) -> Result<(), Box<Error>> {
|
||||
let table_names = load_table_names(database_url, schema_name)?
|
||||
let table_names = load_table_names(database_url, config.schema_name())?
|
||||
.into_iter()
|
||||
.filter(|t| !filtering.should_ignore_table(t))
|
||||
.filter(|t| !config.filter.should_ignore_table(t))
|
||||
.collect::<Vec<_>>();
|
||||
let foreign_keys = load_foreign_key_constraints(database_url, schema_name)?;
|
||||
let foreign_keys = load_foreign_key_constraints(database_url, config.schema_name())?;
|
||||
let foreign_keys =
|
||||
remove_unsafe_foreign_keys_for_codegen(database_url, &foreign_keys, &table_names);
|
||||
let table_data = table_names
|
||||
@ -57,10 +57,10 @@ pub fn output_schema<W: io::Write>(
|
||||
let definitions = TableDefinitions {
|
||||
tables: table_data,
|
||||
fk_constraints: foreign_keys,
|
||||
include_docs,
|
||||
include_docs: config.with_docs,
|
||||
};
|
||||
|
||||
if let Some(schema_name) = schema_name {
|
||||
if let Some(schema_name) = config.schema_name() {
|
||||
write!(out, "{}", ModuleDefinition(schema_name, definitions))?;
|
||||
} else {
|
||||
write!(out, "{}", definitions)?;
|
||||
@ -255,3 +255,53 @@ impl<'a, 'b: 'a> Write for PadAdapter<'a, 'b> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> Deserialize<'de> for Filtering {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
struct FilteringVisitor;
|
||||
|
||||
impl<'de> Visitor<'de> for FilteringVisitor {
|
||||
type Value = Filtering;
|
||||
|
||||
fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
f.write_str("either a whitelist or a blacklist")
|
||||
}
|
||||
|
||||
fn visit_map<V>(self, mut map: V) -> Result<Self::Value, V::Error>
|
||||
where
|
||||
V: MapAccess<'de>,
|
||||
{
|
||||
let mut whitelist = None;
|
||||
let mut blacklist = None;
|
||||
while let Some((key, value)) = map.next_entry()? {
|
||||
match key {
|
||||
"whitelist" => {
|
||||
if whitelist.is_some() {
|
||||
return Err(de::Error::duplicate_field("whitelist"));
|
||||
}
|
||||
whitelist = Some(value);
|
||||
}
|
||||
"blacklist" => {
|
||||
if blacklist.is_some() {
|
||||
return Err(de::Error::duplicate_field("blacklist"));
|
||||
}
|
||||
blacklist = Some(value);
|
||||
}
|
||||
_ => return Err(de::Error::unknown_field(key, &["whitelist", "blacklist"])),
|
||||
}
|
||||
}
|
||||
match (whitelist, blacklist) {
|
||||
(Some(_), Some(_)) => Err(de::Error::duplicate_field("blacklist")),
|
||||
(Some(w), None) => Ok(Filtering::Whitelist(w)),
|
||||
(None, Some(b)) => Ok(Filtering::Blacklist(b)),
|
||||
(None, None) => Ok(Filtering::None),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
deserializer.deserialize_map(FilteringVisitor)
|
||||
}
|
||||
}
|
||||
|
@ -80,26 +80,48 @@ fn test_print_schema(test_name: &str, args: Vec<&str>) {
|
||||
let test_path = Path::new(env!("CARGO_MANIFEST_DIR"))
|
||||
.join("tests")
|
||||
.join("print_schema")
|
||||
.join(test_name)
|
||||
.join(BACKEND);
|
||||
.join(test_name);
|
||||
let backend_path = test_path.join(BACKEND);
|
||||
let p = project(test_name).build();
|
||||
let db = database(&p.database_url());
|
||||
|
||||
p.command("setup").run();
|
||||
|
||||
let schema = read_file(&test_path.join("schema.sql"));
|
||||
let schema = read_file(&backend_path.join("schema.sql"));
|
||||
db.execute(&schema);
|
||||
|
||||
let result = p.command("print-schema").args(args).run();
|
||||
|
||||
assert!(result.is_success(), "Result was unsuccessful {:?}", result);
|
||||
let expected = read_file(&test_path.join("expected.rs"));
|
||||
let expected = read_file(&backend_path.join("expected.rs"));
|
||||
|
||||
assert_diff!(&expected, result.stdout(), "\n", 0);
|
||||
|
||||
test_print_schema_config(test_name, &test_path, schema, expected);
|
||||
}
|
||||
|
||||
fn test_print_schema_config(test_name: &str, test_path: &Path, schema: String, expected: String) {
|
||||
let config = read_file(&test_path.join("diesel.toml"));
|
||||
let p = project(&format!("{}_config", test_name))
|
||||
.file("diesel.toml", &config)
|
||||
.build();
|
||||
|
||||
p.command("setup").run();
|
||||
p.create_migration("12345_create_schema", &schema, "");
|
||||
|
||||
let result = p.command("migration").arg("run").run();
|
||||
assert!(result.is_success(), "Result was unsuccessful {:?}", result);
|
||||
|
||||
let schema = p.file_contents("src/schema.rs");
|
||||
assert_diff!(&expected, &schema, "\n", 0);
|
||||
|
||||
let result = p.command("print-schema").run();
|
||||
assert!(result.is_success(), "Result was unsuccessful {:?}", result);
|
||||
|
||||
assert_diff!(&expected, result.stdout(), "\n", 0);
|
||||
}
|
||||
|
||||
fn read_file(path: &Path) -> String {
|
||||
println!("{}", path.display());
|
||||
let mut file = File::open(path).expect(&format!("Could not open {}", path.display()));
|
||||
let mut string = String::new();
|
||||
file.read_to_string(&mut string).unwrap();
|
||||
|
@ -0,0 +1,4 @@
|
||||
[print_schema]
|
||||
file = "src/schema.rs"
|
||||
with_docs = true
|
||||
filter = { blacklist = ["users1"] }
|
@ -0,0 +1,3 @@
|
||||
[print_schema]
|
||||
file = "src/schema.rs"
|
||||
with_docs = true
|
@ -0,0 +1,3 @@
|
||||
[print_schema]
|
||||
file = "src/schema.rs"
|
||||
with_docs = true
|
@ -0,0 +1,3 @@
|
||||
[print_schema]
|
||||
file = "src/schema.rs"
|
||||
with_docs = true
|
@ -0,0 +1,3 @@
|
||||
[print_schema]
|
||||
file = "src/schema.rs"
|
||||
with_docs = true
|
@ -0,0 +1,2 @@
|
||||
[print_schema]
|
||||
file = "src/schema.rs"
|
@ -0,0 +1,4 @@
|
||||
[print_schema]
|
||||
file = "src/schema.rs"
|
||||
with_docs = true
|
||||
schema = "custom_schema"
|
@ -0,0 +1,4 @@
|
||||
[print_schema]
|
||||
file = "src/schema.rs"
|
||||
with_docs = true
|
||||
schema = "custom_schema"
|
@ -0,0 +1,4 @@
|
||||
[print_schema]
|
||||
file = "src/schema.rs"
|
||||
with_docs = true
|
||||
filter = { whitelist = ["users1"] }
|
@ -0,0 +1,3 @@
|
||||
[print_schema]
|
||||
file = "src/schema.rs"
|
||||
with_docs = true
|
@ -12,6 +12,7 @@ keywords = ["orm", "database", "postgres", "postgresql", "sql"]
|
||||
[dependencies]
|
||||
diesel = { version = "~1.2.0", default-features = false }
|
||||
clippy = { optional = true, version = "=0.0.185" }
|
||||
serde = { version = "1.0.0", optional = true }
|
||||
|
||||
[dev-dependencies]
|
||||
dotenv = ">=0.8, <0.11"
|
||||
|
@ -82,3 +82,39 @@ pub struct TableData {
|
||||
pub column_data: Vec<ColumnDefinition>,
|
||||
pub docs: String,
|
||||
}
|
||||
|
||||
#[cfg(feature = "serde")]
|
||||
mod serde_impls {
|
||||
extern crate serde;
|
||||
|
||||
use self::serde::de::Visitor;
|
||||
use self::serde::{de, Deserialize, Deserializer};
|
||||
use std::fmt;
|
||||
use super::TableName;
|
||||
|
||||
impl<'de> Deserialize<'de> for TableName {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
struct TableNameVisitor;
|
||||
|
||||
impl<'de> Visitor<'de> for TableNameVisitor {
|
||||
type Value = TableName;
|
||||
|
||||
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
|
||||
formatter.write_str("A valid table name")
|
||||
}
|
||||
|
||||
fn visit_str<E>(self, value: &str) -> Result<TableName, E>
|
||||
where
|
||||
E: de::Error,
|
||||
{
|
||||
value.parse().map_err(|_| unreachable!())
|
||||
}
|
||||
}
|
||||
|
||||
deserializer.deserialize_string(TableNameVisitor)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user