From cb6bc34eb40ab7d1978c20ce7b0d067669430296 Mon Sep 17 00:00:00 2001
From: guillaume-be
Date: Fri, 20 Aug 2021 11:08:37 +0200
Subject: [PATCH] Updated borrowing for XLNet, integration tests

---
 examples/masked_language_model_bert.rs |   6 +-
 src/electra/electra_model.rs           | 100 +++++++++++--------------
 src/electra/embeddings.rs              |  65 +++++++---------
 src/pipelines/token_classification.rs  |  10 +--
 src/xlnet/xlnet_model.rs               |  68 ++++++++---------
 tests/albert.rs                        |  10 +--
 tests/bert.rs                          |  14 ++--
 tests/distilbert.rs                    |   6 +-
 tests/electra.rs                       |   4 +-
 tests/roberta.rs                       |  12 +--
 10 files changed, 136 insertions(+), 159 deletions(-)

diff --git a/examples/masked_language_model_bert.rs b/examples/masked_language_model_bert.rs
index 0158d87..166ea43 100644
--- a/examples/masked_language_model_bert.rs
+++ b/examples/masked_language_model_bert.rs
@@ -67,13 +67,13 @@ fn main() -> anyhow::Result<()> {
     // Forward pass
     let model_output = no_grad(|| {
         bert_model.forward_t(
-            Some(input_tensor),
+            Some(&input_tensor),
+            None,
+            None,
             None,
             None,
             None,
             None,
-            &None,
-            &None,
             false,
         )
     });
diff --git a/src/electra/electra_model.rs b/src/electra/electra_model.rs
index d36acc0..6b7408a 100644
--- a/src/electra/electra_model.rs
+++ b/src/electra/electra_model.rs
@@ -15,6 +15,7 @@
 use crate::bert::BertConfig;
 use crate::common::activations::Activation;
 use crate::common::dropout::Dropout;
+use crate::common::embeddings::get_shape_and_device_from_ids_embeddings_pair;
 use crate::electra::embeddings::ElectraEmbeddings;
 use crate::{bert::encoder::BertEncoder, common::activations::TensorFunction};
 use crate::{Config, RustBertError};
@@ -216,10 +217,10 @@ impl ElectraModel {
     /// let model_output = no_grad(|| {
     ///     electra_model
     ///         .forward_t(
-    ///             Some(input_tensor),
-    ///             Some(mask),
-    ///             Some(token_type_ids),
-    ///             Some(position_ids),
+    ///             Some(&input_tensor),
+    ///             Some(&mask),
+    ///             Some(&token_type_ids),
+    ///             Some(&position_ids),
     ///             None,
     ///             false,
     ///         )
@@ -228,33 +229,22 @@
     /// ```
     pub fn forward_t(
         &self,
-        input_ids: Option<Tensor>,
-        mask: Option<Tensor>,
-        token_type_ids: Option<Tensor>,
-        position_ids: Option<Tensor>,
-        input_embeds: Option<Tensor>,
+        input_ids: Option<&Tensor>,
+        mask: Option<&Tensor>,
+        token_type_ids: Option<&Tensor>,
+        position_ids: Option<&Tensor>,
+        input_embeds: Option<&Tensor>,
         train: bool,
     ) -> Result<ElectraModelOutput, RustBertError> {
-        let (input_shape, device) = match &input_ids {
-            Some(input_value) => match &input_embeds {
-                Some(_) => {
-                    return Err(RustBertError::ValueError(
-                        "Only one of input ids or input embeddings may be set".into(),
-                    ));
-                }
-                None => (input_value.size(), input_value.device()),
-            },
-            None => match &input_embeds {
-                Some(embeds) => (vec![embeds.size()[0], embeds.size()[1]], embeds.device()),
-                None => {
-                    return Err(RustBertError::ValueError(
-                        "At least one of input ids or input embeddings must be set".into(),
-                    ));
-                }
-            },
-        };
+        let (input_shape, device) =
+            get_shape_and_device_from_ids_embeddings_pair(input_ids, input_embeds)?;

-        let mask = mask.unwrap_or_else(|| Tensor::ones(&input_shape, (Kind::Int64, device)));
+        let calc_mask = if mask.is_none() {
+            Some(Tensor::ones(&input_shape, (Kind::Int64, device)))
+        } else {
+            None
+        };
+        let mask = mask.unwrap_or_else(|| calc_mask.as_ref().unwrap());

         let extended_attention_mask = match mask.dim() {
             3 => mask.unsqueeze(1),
@@ -590,10 +580,10 @@ impl ElectraForMaskedLM {
     ///
     /// let model_output = no_grad(|| {
     ///     electra_model.forward_t(
-    ///         Some(input_tensor),
-    ///         Some(mask),
-    ///         Some(token_type_ids),
-    ///         Some(position_ids),
+    ///         Some(&input_tensor),
+    ///         Some(&mask),
+    ///         Some(&token_type_ids),
+    ///         Some(&position_ids),
     ///         None,
    ///         false,
     ///     )
@@ -601,11 +591,11 @@
     /// ```
     pub fn forward_t(
         &self,
-        input_ids: Option<Tensor>,
-        mask: Option<Tensor>,
-        token_type_ids: Option<Tensor>,
-        position_ids: Option<Tensor>,
-        input_embeds: Option<Tensor>,
+        input_ids: Option<&Tensor>,
+        mask: Option<&Tensor>,
+        token_type_ids: Option<&Tensor>,
+        position_ids: Option<&Tensor>,
+        input_embeds: Option<&Tensor>,
         train: bool,
     ) -> ElectraMaskedLMOutput {
         let base_model_output = self
@@ -717,21 +707,21 @@ impl ElectraDiscriminator {
     ///
     /// let model_output = no_grad(|| {
     ///     electra_model
-    ///         .forward_t(Some(input_tensor),
-    ///                    Some(mask),
-    ///                    Some(token_type_ids),
-    ///                    Some(position_ids),
+    ///         .forward_t(Some(&input_tensor),
+    ///                    Some(&mask),
+    ///                    Some(&token_type_ids),
+    ///                    Some(&position_ids),
     ///                    None,
     ///                    false)
     /// });
     /// ```
     pub fn forward_t(
         &self,
-        input_ids: Option<Tensor>,
-        mask: Option<Tensor>,
-        token_type_ids: Option<Tensor>,
-        position_ids: Option<Tensor>,
-        input_embeds: Option<Tensor>,
+        input_ids: Option<&Tensor>,
+        mask: Option<&Tensor>,
+        token_type_ids: Option<&Tensor>,
+        position_ids: Option<&Tensor>,
+        input_embeds: Option<&Tensor>,
         train: bool,
     ) -> ElectraDiscriminatorOutput {
         let base_model_output = self
@@ -858,21 +848,21 @@ impl ElectraForTokenClassification {
     ///
     /// let model_output = no_grad(|| {
     ///     electra_model
-    ///         .forward_t(Some(input_tensor),
-    ///                    Some(mask),
-    ///                    Some(token_type_ids),
-    ///                    Some(position_ids),
+    ///         .forward_t(Some(&input_tensor),
+    ///                    Some(&mask),
+    ///                    Some(&token_type_ids),
+    ///                    Some(&position_ids),
     ///                    None,
     ///                    false)
     /// });
     /// ```
     pub fn forward_t(
         &self,
-        input_ids: Option<Tensor>,
-        mask: Option<Tensor>,
-        token_type_ids: Option<Tensor>,
-        position_ids: Option<Tensor>,
-        input_embeds: Option<Tensor>,
+        input_ids: Option<&Tensor>,
+        mask: Option<&Tensor>,
+        token_type_ids: Option<&Tensor>,
+        position_ids: Option<&Tensor>,
+        input_embeds: Option<&Tensor>,
         train: bool,
     ) -> ElectraTokenClassificationOutput {
         let base_model_output = self
diff --git a/src/electra/embeddings.rs b/src/electra/embeddings.rs
index e16a81e..4fd75d3 100644
--- a/src/electra/embeddings.rs
+++ b/src/electra/embeddings.rs
@@ -13,6 +13,7 @@
 // limitations under the License.

 use crate::common::dropout::Dropout;
+use crate::common::embeddings::process_ids_embeddings_pair;
 use crate::electra::electra_model::ElectraConfig;
 use crate::RustBertError;
 use std::borrow::Borrow;
@@ -84,50 +85,40 @@ impl ElectraEmbeddings {
     pub fn forward_t(
         &self,
-        input_ids: Option<Tensor>,
-        token_type_ids: Option<Tensor>,
-        position_ids: Option<Tensor>,
-        input_embeds: Option<Tensor>,
+        input_ids: Option<&Tensor>,
+        token_type_ids: Option<&Tensor>,
+        position_ids: Option<&Tensor>,
+        input_embeds: Option<&Tensor>,
         train: bool,
     ) -> Result<Tensor, RustBertError> {
-        let (input_embeddings, input_shape) = match input_ids {
-            Some(input_value) => match input_embeds {
-                Some(_) => {
-                    return Err(RustBertError::ValueError(
-                        "Only one of input ids or input embeddings may be set".into(),
-                    ));
-                }
-                None => (
-                    input_value.apply_t(&self.word_embeddings, train),
-                    input_value.size(),
-                ),
-            },
-            None => match input_embeds {
-                Some(embeds) => {
-                    let size = vec![embeds.size()[0], embeds.size()[1]];
-                    (embeds, size)
-                }
-                None => {
-                    return Err(RustBertError::ValueError(
-                        "At least one of input ids or input embeddings must be set".into(),
-                    ));
-                }
-            },
-        };
+        let (calc_input_embeddings, input_shape, _) =
+            process_ids_embeddings_pair(input_ids, input_embeds, &self.word_embeddings)?;

-        let seq_length = input_embeddings.as_ref().size()[1].to_owned();
+        let input_embeddings =
+            input_embeds.unwrap_or_else(|| calc_input_embeddings.as_ref().unwrap());
+        let seq_length = input_embeddings.size()[1].to_owned();

-        let position_ids = match position_ids {
-            Some(value) => value,
-            None => Tensor::arange(seq_length, (Kind::Int64, input_embeddings.device()))
-                .unsqueeze(0)
-                .expand(&input_shape, true),
+        let calc_position_ids = if position_ids.is_none() {
+            Some(
+                Tensor::arange(seq_length, (Kind::Int64, input_embeddings.device()))
+                    .unsqueeze(0)
+                    .expand(&input_shape, true),
+            )
+        } else {
+            None
         };
+        let position_ids = position_ids.unwrap_or_else(|| calc_position_ids.as_ref().unwrap());

-        let token_type_ids = match token_type_ids {
-            Some(value) => value,
-            None => Tensor::zeros(&input_shape, (Kind::Int64, input_embeddings.device())),
+        let calc_token_type_ids = if token_type_ids.is_none() {
+            Some(Tensor::zeros(
+                &input_shape,
+                (Kind::Int64, input_embeddings.device()),
+            ))
+        } else {
+            None
         };
+        let token_type_ids =
+            token_type_ids.unwrap_or_else(|| calc_token_type_ids.as_ref().unwrap());

         let position_embeddings = position_ids.apply(&self.position_embeddings);
         let token_type_embeddings = token_type_ids.apply(&self.token_type_embeddings);
diff --git a/src/pipelines/token_classification.rs b/src/pipelines/token_classification.rs
index e75fb1e..f712f9c 100644
--- a/src/pipelines/token_classification.rs
+++ b/src/pipelines/token_classification.rs
@@ -545,12 +545,12 @@ impl TokenClassificationOption {
             Self::Longformer(ref model) => {
                 model
                     .forward_t(
-                        input_ids.as_ref(),
-                        mask.as_ref(),
+                        input_ids,
+                        mask,
                         None,
-                        token_type_ids.as_ref(),
-                        position_ids.as_ref(),
-                        input_embeds.as_ref(),
+                        token_type_ids,
+                        position_ids,
+                        input_embeds,
                         train,
                     )
                     .expect("Error in longformer forward_t")
diff --git a/src/xlnet/xlnet_model.rs b/src/xlnet/xlnet_model.rs
index 8fab78a..17f95a1 100644
--- a/src/xlnet/xlnet_model.rs
+++ b/src/xlnet/xlnet_model.rs
@@ -390,38 +390,34 @@ impl XLNetModel {
         perm_mask: Option<&Tensor>,
         target_mapping: Option<&Tensor>,
         token_type_ids: Option<&Tensor>,
-        input_embeds: Option<Tensor>,
+        input_embeds: Option<&Tensor>,
         train: bool,
     ) -> Result<XLNetModelOutput, RustBertError> {
-        let (word_emb_k, input_shape) = match input_ids {
-            Some(input_value) => match input_embeds {
-                Some(_) => {
-                    return Err(RustBertError::ValueError(
-                        "Only one of input ids or input embeddings may be set".into(),
-                    ));
-                }
-                None => {
-                    let size = input_value.size();
-                    (
-                        input_value
-                            .transpose(0, 1)
-                            .contiguous()
-                            .apply_t(&self.word_embeddings, train),
-                        vec![size[1], size[0]],
-                    )
-                }
-            },
-            None => match input_embeds {
-                Some(embeds) => {
-                    let size = vec![embeds.size()[1], embeds.size()[0]];
-                    (embeds.transpose(0, 1).contiguous(), size)
-                }
-                None => {
-                    return Err(RustBertError::ValueError(
-                        "At least one of input ids or input embeddings must be set".into(),
-                    ));
-                }
-            },
+        let (word_emb_k, input_shape) = match (input_ids, input_embeds) {
+            (Some(_), Some(_)) => {
+                return Err(RustBertError::ValueError(
+                    "Only one of input ids or input embeddings may be set".into(),
+                ));
+            }
+            (Some(input_value), None) => {
+                let size = input_value.size();
+                (
+                    input_value
+                        .transpose(0, 1)
+                        .contiguous()
+                        .apply_t(&self.word_embeddings, train),
+                    vec![size[1], size[0]],
+                )
+            }
+            (None, Some(embeds)) => {
+                let size = vec![embeds.size()[1], embeds.size()[0]];
+                (embeds.transpose(0, 1).contiguous(), size)
+            }
+            (None, None) => {
+                return Err(RustBertError::ValueError(
+                    "At least one of input ids or input embeddings must be set".into(),
+                ));
+            }
         };

         let token_type_ids =
@@ -715,7 +711,7 @@
         perm_mask: Option<&Tensor>,
         target_mapping: Option<&Tensor>,
         token_type_ids: Option<&Tensor>,
-        input_embeds: Option<Tensor>,
+        input_embeds: Option<&Tensor>,
         train: bool,
     ) -> Result<LMModelOutput, RustBertError> {
         let base_model_output = self.base_model.forward_t(
@@ -966,7 +962,7 @@
         perm_mask: Option<&Tensor>,
         target_mapping: Option<&Tensor>,
         token_type_ids: Option<&Tensor>,
-        input_embeds: Option<Tensor>,
+        input_embeds: Option<&Tensor>,
         train: bool,
     ) -> XLNetSequenceClassificationOutput {
         let base_model_output = self
@@ -1124,7 +1120,7 @@
         perm_mask: Option<&Tensor>,
         target_mapping: Option<&Tensor>,
         token_type_ids: Option<&Tensor>,
-        input_embeds: Option<Tensor>,
+        input_embeds: Option<&Tensor>,
         train: bool,
     ) -> XLNetTokenClassificationOutput {
         let base_model_output = self
@@ -1273,7 +1269,7 @@
         perm_mask: Option<&Tensor>,
         target_mapping: Option<&Tensor>,
         token_type_ids: Option<&Tensor>,
-        input_embeds: Option<Tensor>,
+        input_embeds: Option<&Tensor>,
         train: bool,
     ) -> XLNetSequenceClassificationOutput {
         let (input_ids, num_choices) = match input_ids {
@@ -1305,7 +1301,7 @@
                 perm_mask,
                 target_mapping,
                 token_type_ids.as_ref(),
-                input_embeds,
+                input_embeds.as_ref(),
                 train,
             )
             .unwrap();
@@ -1444,7 +1440,7 @@
         perm_mask: Option<&Tensor>,
         target_mapping: Option<&Tensor>,
         token_type_ids: Option<&Tensor>,
-        input_embeds: Option<Tensor>,
+        input_embeds: Option<&Tensor>,
         train: bool,
     ) -> XLNetQuestionAnsweringOutput {
         let base_model_output = self
diff --git a/tests/albert.rs b/tests/albert.rs
index d8cbf7e..0b5b58f 100644
--- a/tests/albert.rs
+++ b/tests/albert.rs
@@ -62,7 +62,7 @@ fn albert_masked_lm() -> anyhow::Result<()> {

     // Forward pass
     let model_output =
-        no_grad(|| albert_model.forward_t(Some(input_tensor), None, None, None, None, false));
+        no_grad(|| albert_model.forward_t(Some(&input_tensor), None, None, None, None, false));

     // Print masked tokens
     let index_1 = model_output
@@ -135,7 +135,7 @@ fn albert_for_sequence_classification() -> anyhow::Result<()> {

     // Forward pass
     let model_output =
-        no_grad(|| albert_model.forward_t(Some(input_tensor), None, None, None, None, false));
+        no_grad(|| albert_model.forward_t(Some(&input_tensor), None, None, None, None, false));

     assert_eq!(model_output.logits.size(), &[2, 3]);
     assert_eq!(
@@ -199,7 +199,7 @@ fn albert_for_multiple_choice() -> anyhow::Result<()> {

     // Forward pass
     let model_output = no_grad(|| {
         albert_model
-            .forward_t(Some(input_tensor), None, None, None, None, false)
+            .forward_t(Some(&input_tensor), None, None, None, None, false)
             .unwrap()
     });
@@ -268,7 +268,7 @@ fn albert_for_token_classification() -> anyhow::Result<()> {

     // Forward pass
     let model_output =
-        no_grad(|| bert_model.forward_t(Some(input_tensor), None, None, None, None, false));
+        no_grad(|| bert_model.forward_t(Some(&input_tensor), None, None, None, None, false));

     assert_eq!(model_output.logits.size(), &[2, 12, 4]);
     assert_eq!(
@@ -329,7 +329,7 @@ fn albert_for_question_answering() -> anyhow::Result<()> {

     // Forward pass
     let model_output =
-        no_grad(|| albert_model.forward_t(Some(input_tensor), None, None, None, None, false));
+        no_grad(|| albert_model.forward_t(Some(&input_tensor), None, None, None, None, false));

     assert_eq!(model_output.start_logits.size(), &[2, 12]);
     assert_eq!(model_output.end_logits.size(), &[2, 12]);
diff --git a/tests/bert.rs b/tests/bert.rs
index 32c63e5..d40197f 100644
--- a/tests/bert.rs
+++ b/tests/bert.rs
@@ -72,13 +72,13 @@ fn bert_masked_lm() -> anyhow::Result<()> {
     // Forward pass
     let model_output = no_grad(|| {
         bert_model.forward_t(
-            Some(input_tensor),
+            Some(&input_tensor),
+            None,
+            None,
             None,
             None,
             None,
             None,
-            &None,
-            &None,
             false,
         )
     });
@@ -152,7 +152,7 @@ fn bert_for_sequence_classification() -> anyhow::Result<()> {

     // Forward pass
     let model_output =
-        no_grad(|| bert_model.forward_t(Some(input_tensor), None, None, None, None, false));
+        no_grad(|| bert_model.forward_t(Some(&input_tensor), None, None, None, None, false));

     assert_eq!(model_output.logits.size(), &[2, 3]);
     assert_eq!(
@@ -212,7 +212,7 @@ fn bert_for_multiple_choice() -> anyhow::Result<()> {
         .unsqueeze(0);

     // Forward pass
-    let model_output = no_grad(|| bert_model.forward_t(input_tensor, None, None, None, false));
+    let model_output = no_grad(|| bert_model.forward_t(&input_tensor, None, None, None, false));

     assert_eq!(model_output.logits.size(), &[1, 2]);
     assert_eq!(
@@ -277,7 +277,7 @@ fn bert_for_token_classification() -> anyhow::Result<()> {

     // Forward pass
     let model_output =
-        no_grad(|| bert_model.forward_t(Some(input_tensor), None, None, None, None, false));
+        no_grad(|| bert_model.forward_t(Some(&input_tensor), None, None, None, None, false));

     assert_eq!(model_output.logits.size(), &[2, 11, 4]);
     assert_eq!(
@@ -336,7 +336,7 @@ fn bert_for_question_answering() -> anyhow::Result<()> {

     // Forward pass
     let model_output =
-        no_grad(|| bert_model.forward_t(Some(input_tensor), None, None, None, None, false));
+        no_grad(|| bert_model.forward_t(Some(&input_tensor), None, None, None, None, false));

     assert_eq!(model_output.start_logits.size(), &[2, 11]);
     assert_eq!(model_output.end_logits.size(), &[2, 11]);
diff --git a/tests/distilbert.rs b/tests/distilbert.rs
index e52dfcf..c986b51 100644
--- a/tests/distilbert.rs
+++ b/tests/distilbert.rs
@@ -96,7 +96,7 @@ fn distilbert_masked_lm() -> anyhow::Result<()> {

     // Forward pass
     let model_output = no_grad(|| {
         distil_bert_model
-            .forward_t(Some(input_tensor), None, None, false)
+            .forward_t(Some(&input_tensor), None, None, false)
             .unwrap()
     });
@@ -167,7 +167,7 @@ fn distilbert_for_question_answering() -> anyhow::Result<()> {

     // Forward pass
     let model_output = no_grad(|| {
         distil_bert_model
-            .forward_t(Some(input_tensor), None, None, false)
+            .forward_t(Some(&input_tensor), None, None, false)
             .unwrap()
     });
@@ -238,7 +238,7 @@ fn distilbert_for_token_classification() -> anyhow::Result<()> {

     // Forward pass
     let model_output = no_grad(|| {
         distil_bert_model
-            .forward_t(Some(input_tensor), None, None, false)
+            .forward_t(Some(&input_tensor), None, None, false)
             .unwrap()
     });
diff --git a/tests/electra.rs b/tests/electra.rs
index 374506d..17b137f 100644
--- a/tests/electra.rs
+++ b/tests/electra.rs
@@ -59,7 +59,7 @@ fn electra_masked_lm() -> anyhow::Result<()> {

     // Forward pass
     let model_output =
-        no_grad(|| electra_model.forward_t(Some(input_tensor), None, None, None, None, false));
+        no_grad(|| electra_model.forward_t(Some(&input_tensor), None, None, None, None, false));

     // Decode output
     let index_1 = model_output
@@ -138,7 +138,7 @@ fn electra_discriminator() -> anyhow::Result<()> {

     // Forward pass
     let model_output =
-        no_grad(|| electra_model.forward_t(Some(input_tensor), None, None, None, None, false));
+        no_grad(|| electra_model.forward_t(Some(&input_tensor), None, None, None, None, false));

     // Validate model predictions
     let expected_probabilities = vec![
diff --git a/tests/roberta.rs b/tests/roberta.rs
index 41f812d..b8fb7db 100644
--- a/tests/roberta.rs
+++ b/tests/roberta.rs
@@ -82,13 +82,13 @@ fn roberta_masked_lm() -> anyhow::Result<()> {
     // Forward pass
     let model_output = no_grad(|| {
         roberta_model.forward_t(
-            Some(input_tensor),
+            Some(&input_tensor),
+            None,
+            None,
             None,
             None,
             None,
             None,
-            &None,
-            &None,
             false,
         )
     });
@@ -172,7 +172,7 @@ fn roberta_for_sequence_classification() -> anyhow::Result<()> {

     // Forward pass
     let model_output =
-        no_grad(|| roberta_model.forward_t(Some(input_tensor), None, None, None, None, false));
+        no_grad(|| roberta_model.forward_t(Some(&input_tensor), None, None, None, None, false));

     assert_eq!(model_output.logits.size(), &[2, 3]);
     assert_eq!(
@@ -242,7 +242,7 @@ fn roberta_for_multiple_choice() -> anyhow::Result<()> {
         .unsqueeze(0);

     // Forward pass
-    let model_output = no_grad(|| roberta_model.forward_t(input_tensor, None, None, None, false));
+    let model_output = no_grad(|| roberta_model.forward_t(&input_tensor, None, None, None, false));

     assert_eq!(model_output.logits.size(), &[1, 2]);
     assert_eq!(
@@ -317,7 +317,7 @@ fn roberta_for_token_classification() -> anyhow::Result<()> {

     // Forward pass
     let model_output =
-        no_grad(|| roberta_model.forward_t(Some(input_tensor), None, None, None, None, false));
+        no_grad(|| roberta_model.forward_t(Some(&input_tensor), None, None, None, None, false));

     assert_eq!(model_output.logits.size(), &[2, 9, 4]);
    assert_eq!(
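
The following is a minimal, self-contained sketch of the two idioms this patch applies throughout. It is illustrative only, not rust-bert code: the helper names (shape_from_pair, resolve_mask) are invented for this example, the error type is a plain String rather than RustBertError, the device is hardcoded to CPU, and it assumes the 2021-era tch API (Tensor::of_slice).

use tch::{Device, Kind, Tensor};

// Idiom 1: validate the mutually exclusive ids/embeddings pair with a single
// match over the tuple, as in the rewritten XLNetModel::forward_t above.
fn shape_from_pair(
    input_ids: Option<&Tensor>,
    input_embeds: Option<&Tensor>,
) -> Result<Vec<i64>, String> {
    match (input_ids, input_embeds) {
        (Some(_), Some(_)) => {
            Err("Only one of input ids or input embeddings may be set".into())
        }
        (Some(ids), None) => Ok(ids.size()),
        (None, Some(embeds)) => Ok(vec![embeds.size()[0], embeds.size()[1]]),
        (None, None) => {
            Err("At least one of input ids or input embeddings must be set".into())
        }
    }
}

// Idiom 2: accept `Option<&Tensor>` yet still fall back to a computed default.
// The owned default lives in a local `Option<Tensor>` (`calc_mask`), so the
// shadowed `mask: &Tensor` can borrow from either source and stays valid for
// the rest of the function.
fn resolve_mask(mask: Option<&Tensor>, input_shape: &[i64]) -> Tensor {
    let calc_mask = if mask.is_none() {
        Some(Tensor::ones(input_shape, (Kind::Int64, Device::Cpu)))
    } else {
        None
    };
    let mask = mask.unwrap_or_else(|| calc_mask.as_ref().unwrap());
    // Hand back a shallow clone only so this sketch has something to return;
    // in the patched functions the borrow is consumed in place.
    mask.shallow_clone()
}

fn main() {
    let ids = Tensor::of_slice(&[101i64, 2054, 102]).unsqueeze(0);
    let shape = shape_from_pair(Some(&ids), None).unwrap();
    let mask = resolve_mask(None, &shape);
    assert_eq!(mask.size(), shape); // the default mask matches the input shape
}

This pair of idioms is what lets every forward_t signature change from owned Option<Tensor> to borrowed Option<&Tensor>, so callers (see the updated tests) can pass Some(&input_tensor) without giving up ownership of the tensor.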