Updated borrowing for XLNet, integration tests

guillaume-be 2021-08-20 11:08:37 +02:00
parent 2baf659e9b
commit cb6bc34eb4
10 changed files with 136 additions and 159 deletions
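Every change below follows one pattern: forward_t methods that took owned Option<Tensor> arguments now borrow them as Option<&Tensor>, so callers keep ownership of their tensors across the forward pass. A minimal, self-contained sketch of what the new calling convention buys (plain tch code, no rust-bert model involved; the token ids are placeholders):

    use tch::Tensor;

    fn main() {
        let input_tensor = Tensor::of_slice(&[101i64, 2054, 2003, 103, 102]).unsqueeze(0);

        // Old convention: the call moved the tensor, so it could not be
        // touched again afterwards.
        //     model.forward_t(Some(input_tensor), None, None, None, None, false);

        // New convention: the call only borrows the tensor.
        let ids: Option<&Tensor> = Some(&input_tensor);
        let _ = ids;

        // input_tensor is still owned here and can be reused, e.g. for a
        // second forward pass or an assertion.
        assert_eq!(input_tensor.size(), &[1, 5]);
    }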

View File

@@ -67,13 +67,13 @@ fn main() -> anyhow::Result<()> {
     // Forward pass
     let model_output = no_grad(|| {
         bert_model.forward_t(
-            Some(input_tensor),
+            Some(&input_tensor),
+            None,
+            None,
             None,
             None,
             None,
             None,
-            &None,
-            &None,
            false,
        )
    });

View File

@@ -15,6 +15,7 @@
 use crate::bert::BertConfig;
 use crate::common::activations::Activation;
 use crate::common::dropout::Dropout;
+use crate::common::embeddings::get_shape_and_device_from_ids_embeddings_pair;
 use crate::electra::embeddings::ElectraEmbeddings;
 use crate::{bert::encoder::BertEncoder, common::activations::TensorFunction};
 use crate::{Config, RustBertError};
@@ -216,10 +217,10 @@ impl ElectraModel {
     /// let model_output = no_grad(|| {
     ///     electra_model
     ///         .forward_t(
-    ///             Some(input_tensor),
-    ///             Some(mask),
-    ///             Some(token_type_ids),
-    ///             Some(position_ids),
+    ///             Some(&input_tensor),
+    ///             Some(&mask),
+    ///             Some(&token_type_ids),
+    ///             Some(&position_ids),
     ///             None,
     ///             false,
     ///         )
@@ -228,33 +229,22 @@ impl ElectraModel {
     /// ```
     pub fn forward_t(
         &self,
-        input_ids: Option<Tensor>,
-        mask: Option<Tensor>,
-        token_type_ids: Option<Tensor>,
-        position_ids: Option<Tensor>,
-        input_embeds: Option<Tensor>,
+        input_ids: Option<&Tensor>,
+        mask: Option<&Tensor>,
+        token_type_ids: Option<&Tensor>,
+        position_ids: Option<&Tensor>,
+        input_embeds: Option<&Tensor>,
         train: bool,
     ) -> Result<ElectraModelOutput, RustBertError> {
-        let (input_shape, device) = match &input_ids {
-            Some(input_value) => match &input_embeds {
-                Some(_) => {
-                    return Err(RustBertError::ValueError(
-                        "Only one of input ids or input embeddings may be set".into(),
-                    ));
-                }
-                None => (input_value.size(), input_value.device()),
-            },
-            None => match &input_embeds {
-                Some(embeds) => (vec![embeds.size()[0], embeds.size()[1]], embeds.device()),
-                None => {
-                    return Err(RustBertError::ValueError(
-                        "At least one of input ids or input embeddings must be set".into(),
-                    ));
-                }
-            },
-        };
-
-        let mask = mask.unwrap_or_else(|| Tensor::ones(&input_shape, (Kind::Int64, device)));
+        let (input_shape, device) =
+            get_shape_and_device_from_ids_embeddings_pair(input_ids, input_embeds)?;
+
+        let calc_mask = if mask.is_none() {
+            Some(Tensor::ones(&input_shape, (Kind::Int64, device)))
+        } else {
+            None
+        };
+        let mask = mask.unwrap_or_else(|| calc_mask.as_ref().unwrap());
 
         let extended_attention_mask = match mask.dim() {
             3 => mask.unsqueeze(1),
@@ -590,10 +580,10 @@ impl ElectraForMaskedLM {
     ///
     /// let model_output = no_grad(|| {
     ///     electra_model.forward_t(
-    ///         Some(input_tensor),
-    ///         Some(mask),
-    ///         Some(token_type_ids),
-    ///         Some(position_ids),
+    ///         Some(&input_tensor),
+    ///         Some(&mask),
+    ///         Some(&token_type_ids),
+    ///         Some(&position_ids),
     ///         None,
     ///         false,
     ///     )
@@ -601,11 +591,11 @@ impl ElectraForMaskedLM {
     /// ```
     pub fn forward_t(
         &self,
-        input_ids: Option<Tensor>,
-        mask: Option<Tensor>,
-        token_type_ids: Option<Tensor>,
-        position_ids: Option<Tensor>,
-        input_embeds: Option<Tensor>,
+        input_ids: Option<&Tensor>,
+        mask: Option<&Tensor>,
+        token_type_ids: Option<&Tensor>,
+        position_ids: Option<&Tensor>,
+        input_embeds: Option<&Tensor>,
         train: bool,
     ) -> ElectraMaskedLMOutput {
         let base_model_output = self
@@ -717,21 +707,21 @@ impl ElectraDiscriminator {
     ///
     /// let model_output = no_grad(|| {
     ///     electra_model
-    ///         .forward_t(Some(input_tensor),
-    ///                    Some(mask),
-    ///                    Some(token_type_ids),
-    ///                    Some(position_ids),
+    ///         .forward_t(Some(&input_tensor),
+    ///                    Some(&mask),
+    ///                    Some(&token_type_ids),
+    ///                    Some(&position_ids),
     ///                    None,
     ///                    false)
     /// });
     /// ```
     pub fn forward_t(
         &self,
-        input_ids: Option<Tensor>,
-        mask: Option<Tensor>,
-        token_type_ids: Option<Tensor>,
-        position_ids: Option<Tensor>,
-        input_embeds: Option<Tensor>,
+        input_ids: Option<&Tensor>,
+        mask: Option<&Tensor>,
+        token_type_ids: Option<&Tensor>,
+        position_ids: Option<&Tensor>,
+        input_embeds: Option<&Tensor>,
         train: bool,
     ) -> ElectraDiscriminatorOutput {
         let base_model_output = self
@@ -858,21 +848,21 @@ impl ElectraForTokenClassification {
     ///
     /// let model_output = no_grad(|| {
     ///     electra_model
-    ///         .forward_t(Some(input_tensor),
-    ///                    Some(mask),
-    ///                    Some(token_type_ids),
-    ///                    Some(position_ids),
+    ///         .forward_t(Some(&input_tensor),
+    ///                    Some(&mask),
+    ///                    Some(&token_type_ids),
+    ///                    Some(&position_ids),
     ///                    None,
     ///                    false)
     /// });
     /// ```
     pub fn forward_t(
         &self,
-        input_ids: Option<Tensor>,
-        mask: Option<Tensor>,
-        token_type_ids: Option<Tensor>,
-        position_ids: Option<Tensor>,
-        input_embeds: Option<Tensor>,
+        input_ids: Option<&Tensor>,
+        mask: Option<&Tensor>,
+        token_type_ids: Option<&Tensor>,
+        position_ids: Option<&Tensor>,
+        input_embeds: Option<&Tensor>,
         train: bool,
     ) -> ElectraTokenClassificationOutput {
         let base_model_output = self
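The calc_mask dance above is the recurring trick that makes borrowed parameters work with computed defaults: the owned default lives in a local Option<Tensor> binding so a borrow of it has somewhere to point. A self-contained sketch of the pattern, assuming a CPU tensor and a fixed shape:

    use tch::{Device, Kind, Tensor};

    fn main() {
        let input_shape = [2i64, 8];
        let user_mask: Option<&Tensor> = None;

        // Owned default, materialised only when the caller did not supply
        // a mask; the binding keeps the tensor alive for the rest of the
        // scope.
        let calc_mask = if user_mask.is_none() {
            Some(Tensor::ones(&input_shape, (Kind::Int64, Device::Cpu)))
        } else {
            None
        };

        // Borrow either the caller's mask or the locally computed default.
        let mask: &Tensor = user_mask.unwrap_or_else(|| calc_mask.as_ref().unwrap());
        assert_eq!(mask.size(), &[2, 8]);
    }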

View File

@@ -13,6 +13,7 @@
 // limitations under the License.
 
 use crate::common::dropout::Dropout;
+use crate::common::embeddings::process_ids_embeddings_pair;
 use crate::electra::electra_model::ElectraConfig;
 use crate::RustBertError;
 use std::borrow::Borrow;
@@ -84,50 +85,40 @@ impl ElectraEmbeddings {
     pub fn forward_t(
         &self,
-        input_ids: Option<Tensor>,
-        token_type_ids: Option<Tensor>,
-        position_ids: Option<Tensor>,
-        input_embeds: Option<Tensor>,
+        input_ids: Option<&Tensor>,
+        token_type_ids: Option<&Tensor>,
+        position_ids: Option<&Tensor>,
+        input_embeds: Option<&Tensor>,
         train: bool,
     ) -> Result<Tensor, RustBertError> {
-        let (input_embeddings, input_shape) = match input_ids {
-            Some(input_value) => match input_embeds {
-                Some(_) => {
-                    return Err(RustBertError::ValueError(
-                        "Only one of input ids or input embeddings may be set".into(),
-                    ));
-                }
-                None => (
-                    input_value.apply_t(&self.word_embeddings, train),
-                    input_value.size(),
-                ),
-            },
-            None => match input_embeds {
-                Some(embeds) => {
-                    let size = vec![embeds.size()[0], embeds.size()[1]];
-                    (embeds, size)
-                }
-                None => {
-                    return Err(RustBertError::ValueError(
-                        "At least one of input ids or input embeddings must be set".into(),
-                    ));
-                }
-            },
-        };
+        let (calc_input_embeddings, input_shape, _) =
+            process_ids_embeddings_pair(input_ids, input_embeds, &self.word_embeddings)?;
 
-        let seq_length = input_embeddings.as_ref().size()[1].to_owned();
+        let input_embeddings =
+            input_embeds.unwrap_or_else(|| calc_input_embeddings.as_ref().unwrap());
+        let seq_length = input_embeddings.size()[1].to_owned();
 
-        let position_ids = match position_ids {
-            Some(value) => value,
-            None => Tensor::arange(seq_length, (Kind::Int64, input_embeddings.device()))
-                .unsqueeze(0)
-                .expand(&input_shape, true),
-        };
+        let calc_position_ids = if position_ids.is_none() {
+            Some(
+                Tensor::arange(seq_length, (Kind::Int64, input_embeddings.device()))
+                    .unsqueeze(0)
+                    .expand(&input_shape, true),
+            )
+        } else {
+            None
+        };
+        let position_ids = position_ids.unwrap_or_else(|| calc_position_ids.as_ref().unwrap());
 
-        let token_type_ids = match token_type_ids {
-            Some(value) => value,
-            None => Tensor::zeros(&input_shape, (Kind::Int64, input_embeddings.device())),
-        };
+        let calc_token_type_ids = if token_type_ids.is_none() {
+            Some(Tensor::zeros(
+                &input_shape,
+                (Kind::Int64, input_embeddings.device()),
+            ))
+        } else {
+            None
+        };
+        let token_type_ids =
+            token_type_ids.unwrap_or_else(|| calc_token_type_ids.as_ref().unwrap());
 
         let position_embeddings = position_ids.apply(&self.position_embeddings);
         let token_type_embeddings = token_type_ids.apply(&self.token_type_embeddings);
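For reference, a sketch of what a process_ids_embeddings_pair-style helper can look like. This is an illustration inferred from the call site above, not the crate's actual implementation: the error type is simplified to String, the train flag is omitted from the embedding lookup, and the third tuple element is assumed to be the device.

    use tch::{nn, Device, Tensor};

    fn process_ids_embeddings_pair(
        input_ids: Option<&Tensor>,
        input_embeds: Option<&Tensor>,
        word_embeddings: &nn::Embedding,
    ) -> Result<(Option<Tensor>, Vec<i64>, Device), String> {
        match (input_ids, input_embeds) {
            (Some(_), Some(_)) => {
                Err("Only one of input ids or input embeddings may be set".into())
            }
            // Ids provided: look them up, returning an owned tensor that the
            // caller keeps alive and borrows from (calc_input_embeddings above).
            (Some(ids), None) => Ok((Some(ids.apply(word_embeddings)), ids.size(), ids.device())),
            // Embeddings provided: nothing to compute, shape is (batch, seq).
            (None, Some(embeds)) => Ok((
                None,
                vec![embeds.size()[0], embeds.size()[1]],
                embeds.device(),
            )),
            (None, None) => {
                Err("At least one of input ids or input embeddings must be set".into())
            }
        }
    }

The caller then recovers a uniform borrow with input_embeds.unwrap_or_else(|| calc_input_embeddings.as_ref().unwrap()), exactly as in the hunk above.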

View File

@@ -545,12 +545,12 @@ impl TokenClassificationOption {
             Self::Longformer(ref model) => {
                 model
                     .forward_t(
-                        input_ids.as_ref(),
-                        mask.as_ref(),
+                        input_ids,
+                        mask,
                         None,
-                        token_type_ids.as_ref(),
-                        position_ids.as_ref(),
-                        input_embeds.as_ref(),
+                        token_type_ids,
+                        position_ids,
+                        input_embeds,
                         train,
                     )
                     .expect("Error in longformer forward_t")

View File

@@ -390,38 +390,34 @@ impl XLNetModel {
         perm_mask: Option<&Tensor>,
         target_mapping: Option<&Tensor>,
         token_type_ids: Option<&Tensor>,
-        input_embeds: Option<Tensor>,
+        input_embeds: Option<&Tensor>,
         train: bool,
     ) -> Result<XLNetModelOutput, RustBertError> {
-        let (word_emb_k, input_shape) = match input_ids {
-            Some(input_value) => match input_embeds {
-                Some(_) => {
-                    return Err(RustBertError::ValueError(
-                        "Only one of input ids or input embeddings may be set".into(),
-                    ));
-                }
-                None => {
-                    let size = input_value.size();
-                    (
-                        input_value
-                            .transpose(0, 1)
-                            .contiguous()
-                            .apply_t(&self.word_embeddings, train),
-                        vec![size[1], size[0]],
-                    )
-                }
-            },
-            None => match input_embeds {
-                Some(embeds) => {
-                    let size = vec![embeds.size()[1], embeds.size()[0]];
-                    (embeds.transpose(0, 1).contiguous(), size)
-                }
-                None => {
-                    return Err(RustBertError::ValueError(
-                        "At least one of input ids or input embeddings must be set".into(),
-                    ));
-                }
-            },
+        let (word_emb_k, input_shape) = match (input_ids, input_embeds) {
+            (Some(_), Some(_)) => {
+                return Err(RustBertError::ValueError(
+                    "Only one of input ids or input embeddings may be set".into(),
+                ));
+            }
+            (Some(input_value), None) => {
+                let size = input_value.size();
+                (
+                    input_value
+                        .transpose(0, 1)
+                        .contiguous()
+                        .apply_t(&self.word_embeddings, train),
+                    vec![size[1], size[0]],
+                )
+            }
+            (None, Some(embeds)) => {
+                let size = vec![embeds.size()[1], embeds.size()[0]];
+                (embeds.transpose(0, 1).contiguous(), size)
+            }
+            (None, None) => {
+                return Err(RustBertError::ValueError(
+                    "At least one of input ids or input embeddings must be set".into(),
+                ));
+            }
         };
 
         let token_type_ids =
let token_type_ids = let token_type_ids =
@@ -715,7 +711,7 @@ impl XLNetLMHeadModel {
         perm_mask: Option<&Tensor>,
         target_mapping: Option<&Tensor>,
         token_type_ids: Option<&Tensor>,
-        input_embeds: Option<Tensor>,
+        input_embeds: Option<&Tensor>,
         train: bool,
     ) -> Result<LMModelOutput, RustBertError> {
         let base_model_output = self.base_model.forward_t(
@@ -966,7 +962,7 @@ impl XLNetForSequenceClassification {
         perm_mask: Option<&Tensor>,
         target_mapping: Option<&Tensor>,
         token_type_ids: Option<&Tensor>,
-        input_embeds: Option<Tensor>,
+        input_embeds: Option<&Tensor>,
         train: bool,
     ) -> XLNetSequenceClassificationOutput {
         let base_model_output = self
@@ -1124,7 +1120,7 @@ impl XLNetForTokenClassification {
         perm_mask: Option<&Tensor>,
         target_mapping: Option<&Tensor>,
         token_type_ids: Option<&Tensor>,
-        input_embeds: Option<Tensor>,
+        input_embeds: Option<&Tensor>,
         train: bool,
     ) -> XLNetTokenClassificationOutput {
         let base_model_output = self
@@ -1273,7 +1269,7 @@ impl XLNetForMultipleChoice {
         perm_mask: Option<&Tensor>,
         target_mapping: Option<&Tensor>,
         token_type_ids: Option<&Tensor>,
-        input_embeds: Option<Tensor>,
+        input_embeds: Option<&Tensor>,
         train: bool,
     ) -> XLNetSequenceClassificationOutput {
         let (input_ids, num_choices) = match input_ids {
@@ -1305,7 +1301,7 @@ impl XLNetForMultipleChoice {
                 perm_mask,
                 target_mapping,
                 token_type_ids.as_ref(),
-                input_embeds,
+                input_embeds.as_ref(),
                 train,
             )
             .unwrap();
@@ -1444,7 +1440,7 @@ impl XLNetForQuestionAnswering {
         perm_mask: Option<&Tensor>,
         target_mapping: Option<&Tensor>,
         token_type_ids: Option<&Tensor>,
-        input_embeds: Option<Tensor>,
+        input_embeds: Option<&Tensor>,
         train: bool,
     ) -> XLNetQuestionAnsweringOutput {
         let base_model_output = self
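The XLNet rewrite also flattens the input validation: matching on the (input_ids, input_embeds) pair makes the four cases exhaustive and removes the nested Option matches. A reduced, self-contained sketch of that control flow (shape handling only; hypothetical validate_inputs name, String error for brevity):

    use tch::Tensor;

    fn validate_inputs(
        input_ids: Option<&Tensor>,
        input_embeds: Option<&Tensor>,
    ) -> Result<Vec<i64>, String> {
        match (input_ids, input_embeds) {
            (Some(_), Some(_)) => {
                Err("Only one of input ids or input embeddings may be set".into())
            }
            // XLNet internally works in (sequence_length, batch) layout,
            // hence the swapped dimensions.
            (Some(ids), None) => {
                let size = ids.size();
                Ok(vec![size[1], size[0]])
            }
            (None, Some(embeds)) => Ok(vec![embeds.size()[1], embeds.size()[0]]),
            (None, None) => {
                Err("At least one of input ids or input embeddings must be set".into())
            }
        }
    }

    fn main() {
        let ids = Tensor::of_slice(&[0i64; 12]).reshape(&[3, 4]); // (batch, seq)
        assert_eq!(validate_inputs(Some(&ids), None).unwrap(), vec![4, 3]);
        assert!(validate_inputs(None, None).is_err());
    }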

View File

@@ -62,7 +62,7 @@ fn albert_masked_lm() -> anyhow::Result<()> {
 
     // Forward pass
     let model_output =
-        no_grad(|| albert_model.forward_t(Some(input_tensor), None, None, None, None, false));
+        no_grad(|| albert_model.forward_t(Some(&input_tensor), None, None, None, None, false));
 
     // Print masked tokens
     let index_1 = model_output
@@ -135,7 +135,7 @@ fn albert_for_sequence_classification() -> anyhow::Result<()> {
 
     // Forward pass
     let model_output =
-        no_grad(|| albert_model.forward_t(Some(input_tensor), None, None, None, None, false));
+        no_grad(|| albert_model.forward_t(Some(&input_tensor), None, None, None, None, false));
 
     assert_eq!(model_output.logits.size(), &[2, 3]);
     assert_eq!(
@@ -199,7 +199,7 @@ fn albert_for_multiple_choice() -> anyhow::Result<()> {
 
     // Forward pass
     let model_output = no_grad(|| {
         albert_model
-            .forward_t(Some(input_tensor), None, None, None, None, false)
+            .forward_t(Some(&input_tensor), None, None, None, None, false)
             .unwrap()
     });
@@ -268,7 +268,7 @@ fn albert_for_token_classification() -> anyhow::Result<()> {
 
     // Forward pass
     let model_output =
-        no_grad(|| bert_model.forward_t(Some(input_tensor), None, None, None, None, false));
+        no_grad(|| bert_model.forward_t(Some(&input_tensor), None, None, None, None, false));
 
     assert_eq!(model_output.logits.size(), &[2, 12, 4]);
     assert_eq!(
@@ -329,7 +329,7 @@ fn albert_for_question_answering() -> anyhow::Result<()> {
 
     // Forward pass
     let model_output =
-        no_grad(|| albert_model.forward_t(Some(input_tensor), None, None, None, None, false));
+        no_grad(|| albert_model.forward_t(Some(&input_tensor), None, None, None, None, false));
 
     assert_eq!(model_output.start_logits.size(), &[2, 12]);
     assert_eq!(model_output.end_logits.size(), &[2, 12]);

View File

@@ -72,13 +72,13 @@ fn bert_masked_lm() -> anyhow::Result<()> {
     // Forward pass
     let model_output = no_grad(|| {
         bert_model.forward_t(
-            Some(input_tensor),
+            Some(&input_tensor),
+            None,
+            None,
             None,
             None,
             None,
             None,
-            &None,
-            &None,
             false,
         )
     });
@@ -152,7 +152,7 @@ fn bert_for_sequence_classification() -> anyhow::Result<()> {
 
     // Forward pass
     let model_output =
-        no_grad(|| bert_model.forward_t(Some(input_tensor), None, None, None, None, false));
+        no_grad(|| bert_model.forward_t(Some(&input_tensor), None, None, None, None, false));
 
     assert_eq!(model_output.logits.size(), &[2, 3]);
     assert_eq!(
@@ -212,7 +212,7 @@ fn bert_for_multiple_choice() -> anyhow::Result<()> {
         .unsqueeze(0);
 
     // Forward pass
-    let model_output = no_grad(|| bert_model.forward_t(input_tensor, None, None, None, false));
+    let model_output = no_grad(|| bert_model.forward_t(&input_tensor, None, None, None, false));
 
     assert_eq!(model_output.logits.size(), &[1, 2]);
     assert_eq!(
@@ -277,7 +277,7 @@ fn bert_for_token_classification() -> anyhow::Result<()> {
 
     // Forward pass
     let model_output =
-        no_grad(|| bert_model.forward_t(Some(input_tensor), None, None, None, None, false));
+        no_grad(|| bert_model.forward_t(Some(&input_tensor), None, None, None, None, false));
 
     assert_eq!(model_output.logits.size(), &[2, 11, 4]);
     assert_eq!(
@@ -336,7 +336,7 @@ fn bert_for_question_answering() -> anyhow::Result<()> {
 
     // Forward pass
     let model_output =
-        no_grad(|| bert_model.forward_t(Some(input_tensor), None, None, None, None, false));
+        no_grad(|| bert_model.forward_t(Some(&input_tensor), None, None, None, None, false));
 
     assert_eq!(model_output.start_logits.size(), &[2, 11]);
     assert_eq!(model_output.end_logits.size(), &[2, 11]);
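The test updates are mechanical: every Some(input_tensor) becomes Some(&input_tensor), and bare input_tensor becomes &input_tensor for the multiple-choice models. The practical benefit in tests is that the input survives the forward pass; a stand-in sketch, where a closure plays the role of a model's forward_t:

    use tch::{no_grad, Kind, Tensor};

    fn main() {
        let input_tensor = Tensor::of_slice(&[101i64, 2054, 2003, 103, 102]).unsqueeze(0);

        // Stand-in for e.g.
        //     bert_model.forward_t(Some(&input_tensor), None, None, None, None, false)
        // -- the input is only borrowed inside the no_grad closure.
        let model_output = no_grad(|| {
            let ids: Option<&Tensor> = Some(&input_tensor);
            ids.unwrap().to_kind(Kind::Float).mean(Kind::Float)
        });

        // Because the forward pass borrowed rather than consumed the
        // tensor, it is still available for further assertions.
        assert_eq!(input_tensor.size(), &[1, 5]);
        let _ = model_output;
    }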

View File

@@ -96,7 +96,7 @@ fn distilbert_masked_lm() -> anyhow::Result<()> {
     // Forward pass
     let model_output = no_grad(|| {
         distil_bert_model
-            .forward_t(Some(input_tensor), None, None, false)
+            .forward_t(Some(&input_tensor), None, None, false)
             .unwrap()
     });
@@ -167,7 +167,7 @@ fn distilbert_for_question_answering() -> anyhow::Result<()> {
     // Forward pass
     let model_output = no_grad(|| {
         distil_bert_model
-            .forward_t(Some(input_tensor), None, None, false)
+            .forward_t(Some(&input_tensor), None, None, false)
             .unwrap()
     });
@@ -238,7 +238,7 @@ fn distilbert_for_token_classification() -> anyhow::Result<()> {
     // Forward pass
     let model_output = no_grad(|| {
         distil_bert_model
-            .forward_t(Some(input_tensor), None, None, false)
+            .forward_t(Some(&input_tensor), None, None, false)
             .unwrap()
     });

View File

@@ -59,7 +59,7 @@ fn electra_masked_lm() -> anyhow::Result<()> {
 
     // Forward pass
     let model_output =
-        no_grad(|| electra_model.forward_t(Some(input_tensor), None, None, None, None, false));
+        no_grad(|| electra_model.forward_t(Some(&input_tensor), None, None, None, None, false));
 
     // Decode output
     let index_1 = model_output
@@ -138,7 +138,7 @@ fn electra_discriminator() -> anyhow::Result<()> {
 
     // Forward pass
     let model_output =
-        no_grad(|| electra_model.forward_t(Some(input_tensor), None, None, None, None, false));
+        no_grad(|| electra_model.forward_t(Some(&input_tensor), None, None, None, None, false));
 
     // Validate model predictions
     let expected_probabilities = vec![

View File

@@ -82,13 +82,13 @@ fn roberta_masked_lm() -> anyhow::Result<()> {
     // Forward pass
     let model_output = no_grad(|| {
         roberta_model.forward_t(
-            Some(input_tensor),
+            Some(&input_tensor),
+            None,
+            None,
             None,
             None,
             None,
             None,
-            &None,
-            &None,
             false,
         )
     });
@@ -172,7 +172,7 @@ fn roberta_for_sequence_classification() -> anyhow::Result<()> {
 
     // Forward pass
     let model_output =
-        no_grad(|| roberta_model.forward_t(Some(input_tensor), None, None, None, None, false));
+        no_grad(|| roberta_model.forward_t(Some(&input_tensor), None, None, None, None, false));
 
     assert_eq!(model_output.logits.size(), &[2, 3]);
     assert_eq!(
@@ -242,7 +242,7 @@ fn roberta_for_multiple_choice() -> anyhow::Result<()> {
         .unsqueeze(0);
 
     // Forward pass
-    let model_output = no_grad(|| roberta_model.forward_t(input_tensor, None, None, None, false));
+    let model_output = no_grad(|| roberta_model.forward_t(&input_tensor, None, None, None, false));
 
     assert_eq!(model_output.logits.size(), &[1, 2]);
     assert_eq!(
@@ -317,7 +317,7 @@ fn roberta_for_token_classification() -> anyhow::Result<()> {
 
     // Forward pass
     let model_output =
-        no_grad(|| roberta_model.forward_t(Some(input_tensor), None, None, None, None, false));
+        no_grad(|| roberta_model.forward_t(Some(&input_tensor), None, None, None, None, false));
 
     assert_eq!(model_output.logits.size(), &[2, 9, 4]);
     assert_eq!(