Allow #[derive(Identifiable)] to work with composite primary keys

This is a tad bit magic in the macro piece. We're expecting the primary
keys *not to* have a trailing comma, so we can rely on the fact that
`(T)` is equivalent to `T`, but if there's more than one element it
would be a tuple.

Beyond that everything was quite straightforward with the ground work
that we've laid.

Fixes #42.
This commit is contained in:
Sean Griffin 2016-12-08 11:08:47 -05:00
parent 9819ed44b5
commit edd145e2f8
11 changed files with 259 additions and 92 deletions

View File

@ -44,6 +44,11 @@ for Rust libraries in [RFC #1105](https://github.com/rust-lang/rfcs/blob/master/
[exists]: http://docs.diesel.rs/diesel/expression/dsl/fn.sql.html
* `#[derive(Identifiable)]` can be used with structs that have primary keys
other than `id`, as well as structs with composite primary keys. You can now
annotate the struct with `#[primary_key(nonstandard)]` or `#[primary_key(foo,
bar)]`.
### Changed
* All macros with the same name as traits we can derive (e.g. `Queryable!`) have

View File

@ -65,12 +65,12 @@ macro_rules! impl_Identifiable {
body = $($body:tt)*
),
found_option_with_name = primary_key,
value = ($primary_key_name:ident),
value = $primary_key_names:tt,
) => {
impl_Identifiable! {
(
table_name = $table_name,
primary_key_name = $primary_key_name,
primary_key_names = $primary_key_names,
)
$($body)*
}
@ -83,13 +83,14 @@ macro_rules! impl_Identifiable {
struct_ty = $struct_ty:ty,
lifetimes = ($($lifetimes:tt),*),
),
primary_key_field = {
found_fields_with_field_names,
fields = [$({
field_name: $field_name:ident,
column_name: $column_name:ident,
field_ty: $field_ty:ty,
field_kind: $field_kind:ident,
$($rest:tt)*
},
})*],
) => {
impl<$($lifetimes),*> $crate::associations::HasTable for $struct_ty {
type Table = $table_name::table;
@ -100,48 +101,28 @@ macro_rules! impl_Identifiable {
}
impl<'ident $(,$lifetimes)*> $crate::associations::Identifiable for &'ident $struct_ty {
type Id = &'ident $field_ty;
type Id = ($(&'ident $field_ty),*);
fn id(self) -> Self::Id {
&self.$field_name
($(&self.$field_name),*)
}
}
};
// Search for the primary key field and continue
// Search for the primary key fields and continue
(
(
table_name = $table_name:ident,
primary_key_name = $primary_key_name:ident,
primary_key_names = $primary_key_names:tt,
$($args:tt)*
),
fields = [{
field_name: $field_name:ident,
$($rest:tt)*
} $($fields:tt)*],
fields = $fields:tt,
) => {
static_cond! {
if $primary_key_name == $field_name {
impl_Identifiable! {
(
table_name = $table_name,
$($args)*
),
primary_key_field = {
field_name: $field_name,
$($rest)*
},
}
} else {
impl_Identifiable! {
(
table_name = $table_name,
primary_key_name = $primary_key_name,
$($args)*
),
fields = [$($fields)*],
}
}
__diesel_fields_with_field_names! {
(table_name = $table_name, $($args)*),
callback = impl_Identifiable,
targets = $primary_key_names,
fields = $fields,
}
};
@ -345,3 +326,32 @@ fn derive_identifiable_with_non_standard_pk_given_before_table_name() {
assert_eq!(&"hi", foo1.id());
assert_eq!(&"there", foo2.id());
}
#[test]
fn derive_identifiable_with_composite_pk() {
use associations::Identifiable;
#[allow(missing_debug_implementations, missing_copy_implementations, dead_code)]
struct Foo {
id: i32,
foo_id: i32,
bar_id: i32,
foo: i32,
}
impl_Identifiable! {
#[primary_key(foo_id, bar_id)]
#[table_name(bars)]
struct Foo {
id: i32,
foo_id: i32,
bar_id: i32,
foo: i32,
}
}
let foo1 = Foo { id: 1, foo_id: 2, bar_id: 3, foo: 4 };
let foo2 = Foo { id: 5, foo_id: 6, bar_id: 7, foo: 8 };
assert_eq!((&2, &3), foo1.id());
assert_eq!((&6, &7), foo2.id());
}

View File

@ -333,6 +333,105 @@ macro_rules! __diesel_field_with_column_name {
};
}
#[doc(hidden)]
#[macro_export]
macro_rules! __diesel_field_with_field_name {
(
$headers:tt,
callback = $callback:ident,
target = $target_field_name:ident,
fields = [$({
field_name: $field_name:ident,
$($field_info:tt)*
})*],
) => {
$(
static_cond! {
if $target_field_name == $field_name {
$callback! {
$headers,
found_field_with_field_name = $field_name,
field = {
field_name: $field_name,
$($field_info)*
},
}
}
}
)*
};
}
#[doc(hidden)]
#[macro_export]
macro_rules! __diesel_fields_with_field_names {
// Entrypoint, start search
(
$headers:tt,
callback = $callback:ident,
targets = ($target_field_name:ident $(,$rest:ident)*),
fields = $fields:tt,
) => {
__diesel_field_with_field_name! {
(
targets = ($($rest),*),
fields = $fields,
headers = $headers,
callback = $callback,
found_fields = [],
),
callback = __diesel_fields_with_field_names,
target = $target_field_name,
fields = $fields,
}
};
// Found field, more to search for
(
(
targets = ($target_field_name:ident $(,$rest:ident)*),
fields = $fields:tt,
headers = $headers:tt,
callback = $callback:ident,
found_fields = [$($found_fields:tt)*],
),
found_field_with_field_name = $ignore:tt,
field = $field:tt,
) => {
__diesel_field_with_field_name! {
(
targets = ($($rest),*),
fields = $fields,
headers = $headers,
callback = $callback,
found_fields = [$($found_fields)* $field],
),
callback = __diesel_fields_with_field_names,
target = $target_field_name,
fields = $fields,
}
};
// Found field, no more to search for
(
(
targets = (),
fields = $fields:tt,
headers = $headers:tt,
callback = $callback:ident,
found_fields = [$($found_fields:tt)*],
),
found_field_with_field_name = $ignore:tt,
field = $field:tt,
) => {
$callback! {
$headers,
found_fields_with_field_names,
fields = [$($found_fields)* $field],
}
}
}
#[doc(hidden)]
#[macro_export]
macro_rules! __diesel_find_option_with_name {

View File

@ -8,16 +8,18 @@ pub fn derive_identifiable(item: syn::MacroInput) -> Tokens {
let table_name = model.table_name();
let struct_ty = &model.ty;
let lifetimes = model.generics.lifetimes;
let primary_key_name = model.primary_key_name;
let primary_key_names = model.primary_key_names;
let fields = model.attrs;
if !fields.iter().any(|f| f.field_name.as_ref() == Some(&primary_key_name)) {
panic!("Could not find a field named `{}` on `{}`", primary_key_name, &model.name);
for pk in &primary_key_names {
if !fields.iter().any(|f| f.field_name.as_ref() == Some(pk)) {
panic!("Could not find a field named `{}` on `{}`", pk, &model.name);
}
}
quote!(impl_Identifiable! {
(
table_name = #table_name,
primary_key_name = #primary_key_name,
primary_key_names = (#(#primary_key_names),*),
struct_ty = #struct_ty,
lifetimes = (#(#lifetimes),*),
),

View File

@ -8,7 +8,7 @@ pub struct Model {
pub attrs: Vec<Attr>,
pub name: syn::Ident,
pub generics: syn::Generics,
pub primary_key_name: syn::Ident,
pub primary_key_names: Vec<syn::Ident>,
table_name_from_annotation: Option<syn::Ident>,
}
@ -23,9 +23,9 @@ impl Model {
let ty = struct_ty(item.ident.clone(), &item.generics);
let name = item.ident.clone();
let generics = item.generics.clone();
let primary_key_name = ident_value_of_attr_with_name(&item.attrs, "primary_key")
.map(Clone::clone)
.unwrap_or(syn::Ident::new("id"));
let primary_key_names = list_value_of_attr_with_name(&item.attrs, "primary_key")
.map(|v| v.into_iter().map(Clone::clone).collect())
.unwrap_or_else(|| vec![syn::Ident::new("id")]);
let table_name_from_annotation = str_value_of_attr_with_name(
&item.attrs, "table_name").map(syn::Ident::new);
@ -34,7 +34,7 @@ impl Model {
attrs: attrs,
name: name,
generics: generics,
primary_key_name: primary_key_name,
primary_key_names: primary_key_names,
table_name_from_annotation: table_name_from_annotation,
})
}

View File

@ -35,7 +35,19 @@ pub fn ident_value_of_attr_with_name<'a>(
attrs: &'a [Attribute],
name: &str,
) -> Option<&'a Ident> {
attr_with_name(attrs, name).map(|attr| single_arg_value_of_attr(attr, name))
list_value_of_attr_with_name(attrs, name).map(|idents| {
if idents.len() != 1 {
panic!(r#"`{}` must be in the form `#[{}(something)]`"#, name, name);
}
idents[0]
})
}
pub fn list_value_of_attr_with_name<'a>(
attrs: &'a [Attribute],
name: &str,
) -> Option<Vec<&'a Ident>> {
attr_with_name(attrs, name).map(|attr| list_value_of_attr(attr, name))
}
pub fn attr_with_name<'a>(
@ -56,19 +68,15 @@ pub fn str_value_of_meta_item<'a>(item: &'a MetaItem, name: &str) -> &'a str {
}
}
fn single_arg_value_of_attr<'a>(attr: &'a Attribute, name: &str) -> &'a Ident {
let usage_err = || panic!(r#"`{}` must be in the form `#[{}(something)]`"#, name, name);
fn list_value_of_attr<'a>(attr: &'a Attribute, name: &str) -> Vec<&'a Ident> {
match attr.value {
MetaItem::List(_, ref items) => {
if items.len() != 1 {
return usage_err();
}
match items[0] {
items.iter().map(|item| match *item {
NestedMetaItem::MetaItem(MetaItem::Word(ref name)) => name,
_ => usage_err(),
}
_ => panic!(r#"`{}` must be in the form `#[{}(something)]`"#, name, name),
}).collect()
}
_ => usage_err(),
_ => panic!(r#"`{}` must be in the form `#[{}(something)]`"#, name, name),
}
}

View File

@ -1,9 +1,10 @@
use syntax::ast;
use syntax::codemap::Span;
use syntax::ext::base::{Annotatable, ExtCtxt};
use syntax::parse::token;
use model::Model;
use util::lifetime_list_tokens;
use util::{lifetime_list_tokens, comma_delimited_tokens};
pub fn expand_derive_identifiable(
cx: &mut ExtCtxt,
@ -16,20 +17,25 @@ pub fn expand_derive_identifiable(
let table_name = model.table_name();
let struct_ty = &model.ty;
let lifetimes = lifetime_list_tokens(&model.generics.lifetimes, span);
let primary_key_name = model.primary_key_name;
let primary_key_names = model.primary_key_names();
let fields = model.field_tokens_for_stable_macro(cx);
if model.attr_named(primary_key_name).is_some() {
push(Annotatable::Item(quote_item!(cx, impl_Identifiable! {
(
table_name = $table_name,
primary_key_name = $primary_key_name,
struct_ty = $struct_ty,
lifetimes = ($lifetimes),
),
fields = [$fields],
}).unwrap()));
} else {
cx.span_err(span, &format!("Could not find a field named `{}` on `{}`", primary_key_name, model.name));
for name in primary_key_names {
if model.attr_named(*name).is_none() {
cx.span_err(span, &format!("Could not find a field named `{}` on `{}`", name, model.name));
return;
}
}
let primary_key_names = comma_delimited_tokens(
primary_key_names.into_iter().map(|n| token::Ident(*n)), span);
push(Annotatable::Item(quote_item!(cx, impl_Identifiable! {
(
table_name = $table_name,
primary_key_names = ($primary_key_names),
struct_ty = $struct_ty,
lifetimes = ($lifetimes),
),
fields = [$fields],
}).unwrap()));
}
}

View File

@ -13,7 +13,7 @@ pub struct Model {
pub attrs: Vec<Attr>,
pub name: ast::Ident,
pub generics: ast::Generics,
pub primary_key_name: ast::Ident,
pub primary_key_names: Vec<ast::Ident>,
table_name_from_annotation: Option<ast::Ident>,
}
@ -26,9 +26,9 @@ impl Model {
if let Annotatable::Item(ref item) = *annotatable {
let table_name_from_annotation =
str_value_of_attr_with_name(cx, &item.attrs, "table_name");
let primary_key_name =
ident_value_of_attr_with_name(cx, &item.attrs, "primary_key")
.unwrap_or(str_to_ident("id"));
let primary_key_names =
list_value_of_attr_with_name(cx, &item.attrs, "primary_key")
.unwrap_or_else(|| vec![str_to_ident("id")]);
Attr::from_item(cx, item).map(|(generics, attrs)| {
let ty = struct_ty(cx, span, item.ident, &generics);
Model {
@ -36,7 +36,7 @@ impl Model {
attrs: attrs,
name: item.ident,
generics: generics,
primary_key_name: primary_key_name,
primary_key_names: primary_key_names,
table_name_from_annotation: table_name_from_annotation,
}
})
@ -45,8 +45,8 @@ impl Model {
}
}
pub fn primary_key_name(&self) -> ast::Ident {
self.primary_key_name
pub fn primary_key_names(&self) -> &[ast::Ident] {
&self.primary_key_names
}
pub fn table_name(&self) -> ast::Ident {

View File

@ -99,9 +99,9 @@ fn changeset_impl(
let struct_ty = &model.ty;
let lifetimes = lifetime_list_tokens(&model.generics.lifetimes, span);
let pk = model.primary_key_name();
let pks = model.primary_key_names();
let fields = model.attrs.iter()
.filter(|a| a.column_name.name != pk.name)
.filter(|a| pks.iter().all(|pk| a.column_name.name != pk.name))
.map(|a| a.to_stable_macro_tokens(cx))
.collect::<Vec<_>>();

View File

@ -21,28 +21,27 @@ fn str_value_of_attr(
})
}
fn single_arg_value_of_attr(
fn list_value_of_attr(
cx: &mut ExtCtxt,
attr: &ast::Attribute,
name: &str,
) -> Option<ast::Ident> {
let usage_err = || {
cx.span_err(attr.span(),
&format!(r#"`{}` must be in the form `#[{}(something)]`"#, name, name));
None
};
// FIXME: This can be cleaned up with slice patterns
) -> Vec<ast::Ident> {
match attr.node.value.node {
ast::MetaItemKind::List(_, ref items) => {
if items.len() != 1 {
return usage_err();
}
match items[0].word() {
items.iter().filter_map(|item| match item.word() {
Some(word) => Some(str_to_ident(&word.name())),
_ => usage_err(),
}
_ => {
cx.span_err(attr.span(),
&format!(r#"`{}` must be in the form `#[{}(something)]`"#, name, name));
None
}
}).collect()
}
_ => usage_err(),
_ => {
cx.span_err(attr.span(),
&format!(r#"`{}` must be in the form `#[{}(something)]`"#, name, name));
Vec::new()
},
}
}
@ -63,7 +62,24 @@ pub fn ident_value_of_attr_with_name(
) -> Option<ast::Ident> {
attrs.iter()
.find(|a| a.check_name(name))
.and_then(|a| single_arg_value_of_attr(cx, &a, name))
.map(|a| {
let list = list_value_of_attr(cx, &a, name);
if list.len() != 1 {
cx.span_err(a.span(),
&format!(r#"`{}` must be in the form `#[{}(something)]`"#, name, name));
}
list[0]
})
}
pub fn list_value_of_attr_with_name(
cx: &mut ExtCtxt,
attrs: &[ast::Attribute],
name: &str,
) -> Option<Vec<ast::Ident>> {
attrs.iter()
.find(|a| a.check_name(name))
.map(|a| list_value_of_attr(cx, &a, name))
}
const KNOWN_ATTRIBUTES: &'static [&'static str] = &[

View File

@ -194,3 +194,24 @@ fn derive_identifiable_with_non_standard_pk() {
// Fails to compile if wrong table is generated.
let _: posts::table = Foo::<'static>::table();
}
#[test]
fn derive_identifiable_with_composite_pk() {
use diesel::associations::Identifiable;
#[derive(Identifiable)]
#[primary_key(foo_id, bar_id)]
#[table_name="posts"]
#[allow(dead_code)]
struct Foo {
id: i32,
foo_id: i32,
bar_id: i32,
foo: i32,
}
let foo1 = Foo { id: 1, foo_id: 2, bar_id: 3, foo: 4 };
let foo2 = Foo { id: 5, foo_id: 6, bar_id: 7, foo: 8 };
assert_eq!((&2, &3), foo1.id());
assert_eq!((&6, &7), foo2.id());
}