There were a ton of style and whitespace issues. I've resolved as many
as I could. The implementation was also a bit all over the place and the
intent of the code was extremely unclear to me. I've tried to better
break things out into blocks based on the logical segments of the impl,
and what I think will need to be reused for moving future macros over to
pure procedural macro style.

I'd actually strongly prefer to have `row_pat` and `build_expr` be
methods on `Model` that return `syn::Pat` and `syn::Expr` respectively,
to show a bit more type safety in the struct vs tuple variant handling.
As it stands though, the difference in the amount of code required is
extremely large, and not worth it.
This commit is contained in:
Sean Griffin 2017-02-10 09:58:49 -05:00
parent e0302ac7bf
commit 1f5f8f3a8c
15 changed files with 137 additions and 89 deletions

View File

@ -10,7 +10,7 @@ repository = "https://github.com/diesel-rs/diesel/tree/master/diesel_codegen"
keywords = ["orm", "database", "postgres", "sql", "codegen"]
[dependencies]
syn = "0.10.3"
syn = { version = "0.11.4", features = ["aster"] }
quote = "0.3.12"
dotenv = { version = "0.8.0", optional = true }
diesel = { version = "0.10.0", default-features = false }
@ -28,3 +28,6 @@ default = ["dotenv"]
postgres = ["diesel_infer_schema/postgres", "diesel/postgres"]
sqlite = ["diesel_infer_schema/sqlite", "diesel/sqlite"]
mysql = ["diesel_infer_schema/mysql", "diesel/mysql"]
[[test]]
name = "tests"

View File

@ -12,7 +12,7 @@ pub fn derive_as_changeset(item: syn::MacroInput) -> quote::Tokens {
let table_name = model.table_name();
let struct_ty = &model.ty;
let mut lifetimes = item.generics.lifetimes;
let attrs = model.attrs.into_iter()
let attrs = model.attrs.as_slice().iter()
.filter(|a| a.column_name != Some(syn::Ident::new("id")))
.collect::<Vec<_>>();

View File

@ -8,7 +8,7 @@ pub fn derive_associations(input: syn::MacroInput) -> quote::Tokens {
let mut derived_associations = Vec::new();
let model = t!(Model::from_item(&input, "Associations"));
for attr in &input.attrs {
for attr in input.attrs.as_slice() {
if attr.name() == "has_many" {
let options = t!(build_association_options(attr, "has_many"));
derived_associations.push(expand_has_many(&model, options))
@ -30,7 +30,7 @@ fn expand_belongs_to(model: &Model, options: AssociationOptions) -> quote::Token
let foreign_key_name = options.foreign_key_name.unwrap_or_else(||
to_foreign_key(&parent_struct.as_ref()));
let child_table_name = model.table_name();
let fields = &model.attrs;
let fields = model.attrs.as_slice();
quote!(BelongsTo! {
(
@ -48,7 +48,7 @@ fn expand_has_many(model: &Model, options: AssociationOptions) -> quote::Tokens
let child_table_name = options.name;
let foreign_key_name = options.foreign_key_name.unwrap_or_else(||
to_foreign_key(&model.name.as_ref()));
let fields = &model.attrs;
let fields = model.attrs.as_slice();
quote!(HasMany! {
(

View File

@ -1,6 +1,8 @@
use quote;
use syn;
use std::borrow::Cow;
use util::*;
#[derive(Debug)]
@ -8,10 +10,11 @@ pub struct Attr {
pub column_name: Option<syn::Ident>,
pub field_name: Option<syn::Ident>,
pub ty: syn::Ty,
field_position: usize,
}
impl Attr {
pub fn from_struct_field(field: &syn::Field) -> Self {
pub fn from_struct_field((index, field): (usize, &syn::Field)) -> Self {
let field_name = field.ident.clone();
let column_name = ident_value_of_attr_with_name(&field.attrs, "column_name")
.map(Clone::clone)
@ -22,6 +25,14 @@ impl Attr {
column_name: column_name,
field_name: field_name,
ty: ty,
field_position: index,
}
}
pub fn name_for_pattern(&self) -> Cow<syn::Ident> {
match self.field_name {
Some(ref name) => Cow::Borrowed(name),
None => Cow::Owned(format!("field_{}", self.field_position).into()),
}
}

View File

@ -9,7 +9,7 @@ pub fn derive_identifiable(item: syn::MacroInput) -> Tokens {
let struct_ty = &model.ty;
let lifetimes = model.generics.lifetimes;
let primary_key_names = model.primary_key_names;
let fields = model.attrs;
let fields = model.attrs.as_slice();
for pk in &primary_key_names {
if !fields.iter().any(|f| f.field_name.as_ref() == Some(pk)) {
panic!("Could not find a field named `{}` on `{}`", pk, &model.name);

View File

@ -19,7 +19,7 @@ pub fn derive_insertable(item: syn::MacroInput) -> quote::Tokens {
let struct_ty = &model.ty;
let table_name = &model.table_name();
let lifetimes = model.generics.lifetimes;
let fields = model.attrs;
let fields = model.attrs.as_slice();
quote!(impl_Insertable! {
(

View File

@ -1,5 +1,4 @@
#![deny(warnings)]
#![recursion_limit="1024"]
macro_rules! t {
($expr:expr) => {

View File

@ -3,10 +3,9 @@ use syn;
use attr::Attr;
use util::*;
#[derive(Debug)]
pub struct Model {
pub ty: syn::Ty,
pub attrs: Vec<Attr>,
pub attrs: ModelAttrs,
pub name: syn::Ident,
pub generics: syn::Generics,
pub primary_key_names: Vec<syn::Ident>,
@ -15,12 +14,11 @@ pub struct Model {
impl Model {
pub fn from_item(item: &syn::MacroInput, derived_from: &str) -> Result<Self, String> {
let fields = match item.body {
let attrs = match item.body {
syn::Body::Enum(..) => return Err(format!(
"#[derive({})] cannot be used with enums", derived_from)),
syn::Body::Struct(ref fields) => fields.fields(),
syn::Body::Struct(ref fields) => ModelAttrs::from_struct_body(fields),
};
let attrs = fields.into_iter().map(Attr::from_struct_field).collect();
let ty = struct_ty(item.ident.clone(), &item.generics);
let name = item.ident.clone();
let generics = item.generics.clone();
@ -49,6 +47,10 @@ impl Model {
pub fn has_table_name_annotation(&self) -> bool {
self.table_name_from_annotation.is_some()
}
pub fn is_tuple_struct(&self) -> bool {
self.attrs.is_tuple()
}
}
pub fn infer_association_name(name: &str) -> String {
@ -73,6 +75,40 @@ fn infer_table_name(name: &str) -> String {
result
}
pub enum ModelAttrs {
Struct(Vec<Attr>),
Tuple(Vec<Attr>),
}
impl ModelAttrs {
fn from_struct_body(body: &syn::VariantData) -> Self {
let attrs = body.fields()
.into_iter()
.enumerate()
.map(Attr::from_struct_field)
.collect();
match *body {
syn::VariantData::Struct(_) => ModelAttrs::Struct(attrs),
_ => ModelAttrs::Tuple(attrs),
}
}
pub fn as_slice(&self) -> &[Attr] {
match *self {
ModelAttrs::Struct(ref attrs) => &*attrs,
ModelAttrs::Tuple(ref attrs) => &*attrs,
}
}
pub fn is_tuple(&self) -> bool {
match *self {
ModelAttrs::Struct(_) => false,
ModelAttrs::Tuple(_) => true,
}
}
}
#[test]
fn infer_table_name_pluralizes_and_downcases() {
assert_eq!("foos", &infer_table_name("Foo"));

View File

@ -1,59 +1,60 @@
use quote::Tokens;
use syn;
use attr::Attr;
use model::Model;
use util::wrap_item_in_const;
pub fn derive_queryable(item: syn::MacroInput) -> Tokens {
let model = t!(Model::from_item(&item, "Queryable"));
let struct_name = &model.name;
let ty_params = &model.generics.ty_params;
let lifetimes = &model.generics.lifetimes;
let struct_ty = if !ty_params.is_empty() || !lifetimes.is_empty(){
quote!(#struct_name<#(#lifetimes,)* #(#ty_params,)*>)
} else {
quote!(#struct_name)
};
let row_ty = model.attrs.iter().map(|a| &a.ty);
let field_names = model.attrs.iter().enumerate().map(|(counter, a)| a.field_name.clone()
.unwrap_or_else(||{
syn::Ident::from(format!("t_{}", counter))
}));
let generics = syn::aster::from_generics(model.generics.clone())
.ty_param_id("__DB")
.ty_param_id("__ST")
.build();
let struct_ty = &model.ty;
let row_ty = model.attrs.as_slice().iter().map(|a| &a.ty);
let row_ty = quote!((#(#row_ty,)*));
let build_expr = build_expr_for_model(&model);
let field_names = model.attrs.as_slice().iter().map(Attr::name_for_pattern);
let row_pat = quote!((#(#field_names,)*));
let field_names = model.attrs.iter().enumerate().map(|(counter, a)|
a.field_name.clone().map(|name| {
quote!(#name:#name)
})
.unwrap_or_else(||{
let r = syn::Ident::from(format!("t_{}", counter));
quote!(#r)
}));
let build_expr = if model.attrs[0].field_name.is_some(){
quote!(#struct_name {#(#field_names,)*})
} else {
quote!(#struct_name (#(#field_names,)*))
};
let dummy_const = syn::Ident::new(format!("_IMPL_QUERYABLE_FOR_{}", struct_name));
let model_name_uppercase = model.name.as_ref().to_uppercase();
let dummy_const = format!("_IMPL_QUERYABLE_FOR_{}", model_name_uppercase).into();
quote!(
#[allow(non_upper_case_globals, unused_attributes, unused_qualifications)]
const #dummy_const: () = {
extern crate diesel as _diesel;
#[automatically_derived]
impl<#(#lifetimes,)* #(#ty_params,)* __DB, __ST> _diesel::Queryable<__ST, __DB> for #struct_ty where
__DB: _diesel::backend::Backend + _diesel::types::HasSqlType<__ST>,
#row_ty: _diesel::types::FromSqlRow<__ST, __DB>,
{
type Row = #row_ty;
fn build(row: Self::Row) -> Self {
let #row_pat = row;
#build_expr
}
}
};)
wrap_item_in_const(dummy_const, quote!(
impl#generics diesel::Queryable<__ST, __DB> for #struct_ty where
__DB: diesel::backend::Backend + diesel::types::HasSqlType<__ST>,
#row_ty: diesel::types::FromSqlRow<__ST, __DB>,
{
type Row = #row_ty;
fn build(#row_pat: Self::Row) -> Self {
#build_expr
}
}
))
}
fn build_expr_for_model(model: &Model) -> Tokens {
let struct_name = &model.name;
let field_names = model.attrs.as_slice().iter().map(Attr::name_for_pattern);
let field_assignments = field_names.map(|field_name| {
if model.is_tuple_struct() {
quote!(#field_name)
} else {
quote!(#field_name: #field_name)
}
});
if model.is_tuple_struct() {
quote!(#struct_name(#(#field_assignments),*))
} else {
quote!(#struct_name {
#(#field_assignments,)*
})
}
}

View File

@ -1,3 +1,4 @@
use quote::Tokens;
use syn::*;
use ast_builder::ty_ident;
@ -141,3 +142,12 @@ pub fn get_optional_option<'a>(
options.iter().find(|a| a.name() == option_name)
.map(|a| str_value_of_meta_item(a, option_name))
}
pub fn wrap_item_in_const(const_name: Ident, item: Tokens) -> Tokens {
quote! {
const #const_name: () = {
extern crate diesel;
#item
};
}
}

View File

@ -1,16 +1,9 @@
#[macro_use]
extern crate cfg_if;
#[macro_use]
extern crate diesel_codegen;
extern crate diesel;
mod test_helpers;
use diesel::expression::dsl::sql;
use diesel::prelude::*;
use test_helpers::connection;
use diesel::*;
use diesel::types::Integer;
use test_helpers::connection;
#[test]
fn named_struct_definition() {
#[derive(Debug, Clone, Copy, PartialEq, Eq, Queryable)]
@ -20,18 +13,17 @@ fn named_struct_definition() {
}
let conn = connection();
let data = ::diesel::select(sql::<(Integer, Integer)>("1, 2")).get_result(&conn);
let data = select(sql::<(Integer, Integer)>("1, 2")).get_result(&conn);
assert_eq!(Ok(MyStruct { foo: 1, bar: 2 }), data);
}
#[test]
fn tuple_struct() {
#[derive(Debug, Clone, Copy, PartialEq, Eq, Queryable)]
struct MyStruct(#[column_name(foo)] i32, #[column_name(bar)] i32);
let conn = connection();
let data = ::diesel::select(sql::<(Integer, Integer)>("1, 2")).get_result(&conn);
let data = select(sql::<(Integer, Integer)>("1, 2")).get_result(&conn);
assert_eq!(Ok(MyStruct(1, 2)), data);
}
@ -41,6 +33,6 @@ fn tuple_struct_without_column_name_annotations() {
struct MyStruct(i32, i32);
let conn = connection();
let data = ::diesel::select(sql::<(Integer, Integer)>("1, 2")).get_result(&conn);
let data = select(sql::<(Integer, Integer)>("1, 2")).get_result(&conn);
assert_eq!(Ok(MyStruct(1, 2)), data);
}

View File

@ -46,17 +46,12 @@ cfg_if! {
pub fn connection_no_transaction() -> TestConnection {
dotenv().ok();
let database_url = env::var("MYSQL_DATABASE_URL")
let database_url = env::var("MYSQL_UNIT_TEST_DATABASE_URL")
.or_else(|_| env::var("DATABASE_URL"))
.expect("DATABASE_URL must be set in order to run tests");
MysqlConnection::establish(&database_url).unwrap()
}
} else {
pub type TestConnection = ();
pub fn connection() -> TestConnection {
panic!("At least one backend must be enabled to run tests")
}
// FIXME: https://github.com/rust-lang/rfcs/pull/1695
// compile_error!("At least one backend must be enabled to run tests");
}

View File

@ -0,0 +1,6 @@
#[macro_use] extern crate cfg_if;
#[macro_use] extern crate diesel_codegen;
extern crate diesel;
mod queryable;
mod test_helpers;

View File

@ -11,7 +11,7 @@ keywords = ["orm", "database", "postgres", "postgresql", "sql"]
[dependencies]
diesel = { version = "0.10.0", default-features = false }
syn = "0.10.3"
syn = "0.11.4"
quote = "0.3.1"
[dev-dependencies]

View File

@ -2,7 +2,7 @@ use std::error::Error;
use diesel::*;
use diesel::expression::dsl::sql;
use diesel::sqlite::SqliteConnection;
use diesel::sqlite::{Sqlite, SqliteConnection};
use table_data::TableData;
use super::data_structures::*;
@ -60,14 +60,10 @@ struct FullTableInfo {
primary_key: bool,
}
impl<ST> Queryable<ST, ::diesel::sqlite::Sqlite> for FullTableInfo where
::diesel::sqlite::Sqlite: ::diesel::types::HasSqlType<ST>,
(i32, String, String, bool, Option<String>, bool) : ::diesel::types::FromSqlRow<ST, ::diesel::sqlite::Sqlite> {
impl Queryable<pragma_table_info::SqlType, Sqlite> for FullTableInfo {
type Row = (i32, String, String, bool, Option<String>, bool);
fn build(row: Self::Row) -> Self{
fn build(row: Self::Row) -> Self {
FullTableInfo{
_cid: row.0,
name: row.1,
@ -79,7 +75,6 @@ impl<ST> Queryable<ST, ::diesel::sqlite::Sqlite> for FullTableInfo where
}
}
pub fn get_primary_keys(conn: &SqliteConnection, table: &TableData) -> QueryResult<Vec<String>> {
let query = format!("PRAGMA TABLE_INFO('{}')", &table.name);
let results = try!(sql::<pragma_table_info::SqlType>(&query)