Updated borrowing for XLNet, integration tests

guillaume-be 2021-08-20 11:08:37 +02:00
parent 2baf659e9b
commit cb6bc34eb4
10 changed files with 136 additions and 159 deletions


@@ -67,13 +67,13 @@ fn main() -> anyhow::Result<()> {
     // Forward pass
     let model_output = no_grad(|| {
         bert_model.forward_t(
-            Some(input_tensor),
+            Some(&input_tensor),
             None,
             None,
             None,
             None,
-            None,
-            None,
+            &None,
+            &None,
             false,
         )
     });
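The change above is the pattern applied throughout this commit: `forward_t` now borrows its inputs (`Option<&Tensor>`, `&Option<Tensor>`) instead of taking them by value, so the caller keeps ownership of its tensors. A minimal standalone sketch of the difference (plain Rust with a `Vec<i64>` stand-in, not the rust-bert API; `forward` is illustrative only):

```rust
// Borrowing sketch: the function takes Option<&T> rather than Option<T>,
// so the caller still owns `input` afterwards and can pass it again.
fn forward(input: Option<&Vec<i64>>) -> usize {
    input.map(|t| t.len()).unwrap_or(0)
}

fn main() {
    let input = vec![101, 2023, 102];
    let first = forward(Some(&input)); // borrowed, not moved
    let second = forward(Some(&input)); // reusable without a clone
    assert_eq!(first, second);
}
```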


@@ -15,6 +15,7 @@
 use crate::bert::BertConfig;
 use crate::common::activations::Activation;
 use crate::common::dropout::Dropout;
+use crate::common::embeddings::get_shape_and_device_from_ids_embeddings_pair;
 use crate::electra::embeddings::ElectraEmbeddings;
 use crate::{bert::encoder::BertEncoder, common::activations::TensorFunction};
 use crate::{Config, RustBertError};
@@ -216,10 +217,10 @@ impl ElectraModel {
     /// let model_output = no_grad(|| {
     ///     electra_model
     ///         .forward_t(
-    ///             Some(input_tensor),
-    ///             Some(mask),
-    ///             Some(token_type_ids),
-    ///             Some(position_ids),
+    ///             Some(&input_tensor),
+    ///             Some(&mask),
+    ///             Some(&token_type_ids),
+    ///             Some(&position_ids),
     ///             None,
     ///             false,
     ///         )
@@ -228,33 +229,22 @@ impl ElectraModel {
     /// ```
     pub fn forward_t(
         &self,
-        input_ids: Option<Tensor>,
-        mask: Option<Tensor>,
-        token_type_ids: Option<Tensor>,
-        position_ids: Option<Tensor>,
-        input_embeds: Option<Tensor>,
+        input_ids: Option<&Tensor>,
+        mask: Option<&Tensor>,
+        token_type_ids: Option<&Tensor>,
+        position_ids: Option<&Tensor>,
+        input_embeds: Option<&Tensor>,
         train: bool,
     ) -> Result<ElectraModelOutput, RustBertError> {
-        let (input_shape, device) = match &input_ids {
-            Some(input_value) => match &input_embeds {
-                Some(_) => {
-                    return Err(RustBertError::ValueError(
-                        "Only one of input ids or input embeddings may be set".into(),
-                    ));
-                }
-                None => (input_value.size(), input_value.device()),
-            },
-            None => match &input_embeds {
-                Some(embeds) => (vec![embeds.size()[0], embeds.size()[1]], embeds.device()),
-                None => {
-                    return Err(RustBertError::ValueError(
-                        "At least one of input ids or input embeddings must be set".into(),
-                    ));
-                }
-            },
-        };
+        let (input_shape, device) =
+            get_shape_and_device_from_ids_embeddings_pair(input_ids, input_embeds)?;

-        let mask = mask.unwrap_or_else(|| Tensor::ones(&input_shape, (Kind::Int64, device)));
+        let calc_mask = if mask.is_none() {
+            Some(Tensor::ones(&input_shape, (Kind::Int64, device)))
+        } else {
+            None
+        };
+        let mask = mask.unwrap_or_else(|| calc_mask.as_ref().unwrap());

         let extended_attention_mask = match mask.dim() {
             3 => mask.unsqueeze(1),
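Because `mask` is now an `Option<&Tensor>`, a default tensor can no longer be produced inside `unwrap_or_else` directly (the closure would return a reference to a temporary). The commit's workaround, visible above, is to materialize the default into a longer-lived `Option` first and then borrow from it. A self-contained sketch of the same pattern (`Vec<i64>` standing in for `Tensor`):

```rust
fn main() {
    let provided: Option<&Vec<i64>> = None;

    // Compute the default only when nothing was provided; keeping it in a
    // named Option makes it outlive the closure below, so borrowing is safe.
    let calc: Option<Vec<i64>> = if provided.is_none() {
        Some(vec![1, 1, 1]) // stand-in for Tensor::ones(...)
    } else {
        None
    };
    let value: &Vec<i64> = provided.unwrap_or_else(|| calc.as_ref().unwrap());
    assert_eq!(value, &vec![1, 1, 1]);
}
```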
@@ -590,10 +580,10 @@ impl ElectraForMaskedLM {
     ///
     /// let model_output = no_grad(|| {
     ///     electra_model.forward_t(
-    ///         Some(input_tensor),
-    ///         Some(mask),
-    ///         Some(token_type_ids),
-    ///         Some(position_ids),
+    ///         Some(&input_tensor),
+    ///         Some(&mask),
+    ///         Some(&token_type_ids),
+    ///         Some(&position_ids),
     ///         None,
     ///         false,
     ///     )
@@ -601,11 +591,11 @@ impl ElectraForMaskedLM {
     /// ```
     pub fn forward_t(
         &self,
-        input_ids: Option<Tensor>,
-        mask: Option<Tensor>,
-        token_type_ids: Option<Tensor>,
-        position_ids: Option<Tensor>,
-        input_embeds: Option<Tensor>,
+        input_ids: Option<&Tensor>,
+        mask: Option<&Tensor>,
+        token_type_ids: Option<&Tensor>,
+        position_ids: Option<&Tensor>,
+        input_embeds: Option<&Tensor>,
         train: bool,
     ) -> ElectraMaskedLMOutput {
         let base_model_output = self
@@ -717,21 +707,21 @@ impl ElectraDiscriminator {
     ///
     /// let model_output = no_grad(|| {
     ///     electra_model
-    ///         .forward_t(Some(input_tensor),
-    ///                    Some(mask),
-    ///                    Some(token_type_ids),
-    ///                    Some(position_ids),
+    ///         .forward_t(Some(&input_tensor),
+    ///                    Some(&mask),
+    ///                    Some(&token_type_ids),
+    ///                    Some(&position_ids),
     ///                    None,
     ///                    false)
     /// });
     /// ```
     pub fn forward_t(
         &self,
-        input_ids: Option<Tensor>,
-        mask: Option<Tensor>,
-        token_type_ids: Option<Tensor>,
-        position_ids: Option<Tensor>,
-        input_embeds: Option<Tensor>,
+        input_ids: Option<&Tensor>,
+        mask: Option<&Tensor>,
+        token_type_ids: Option<&Tensor>,
+        position_ids: Option<&Tensor>,
+        input_embeds: Option<&Tensor>,
         train: bool,
     ) -> ElectraDiscriminatorOutput {
         let base_model_output = self
@@ -858,21 +848,21 @@ impl ElectraForTokenClassification {
     ///
     /// let model_output = no_grad(|| {
     ///     electra_model
-    ///         .forward_t(Some(input_tensor),
-    ///                    Some(mask),
-    ///                    Some(token_type_ids),
-    ///                    Some(position_ids),
+    ///         .forward_t(Some(&input_tensor),
+    ///                    Some(&mask),
+    ///                    Some(&token_type_ids),
+    ///                    Some(&position_ids),
     ///                    None,
     ///                    false)
     /// });
     /// ```
     pub fn forward_t(
         &self,
-        input_ids: Option<Tensor>,
-        mask: Option<Tensor>,
-        token_type_ids: Option<Tensor>,
-        position_ids: Option<Tensor>,
-        input_embeds: Option<Tensor>,
+        input_ids: Option<&Tensor>,
+        mask: Option<&Tensor>,
+        token_type_ids: Option<&Tensor>,
+        position_ids: Option<&Tensor>,
+        input_embeds: Option<&Tensor>,
         train: bool,
     ) -> ElectraTokenClassificationOutput {
         let base_model_output = self


@@ -13,6 +13,7 @@
 // limitations under the License.

 use crate::common::dropout::Dropout;
+use crate::common::embeddings::process_ids_embeddings_pair;
 use crate::electra::electra_model::ElectraConfig;
 use crate::RustBertError;
 use std::borrow::Borrow;
@@ -84,50 +85,40 @@ impl ElectraEmbeddings {
     /// ```
     pub fn forward_t(
         &self,
-        input_ids: Option<Tensor>,
-        token_type_ids: Option<Tensor>,
-        position_ids: Option<Tensor>,
-        input_embeds: Option<Tensor>,
+        input_ids: Option<&Tensor>,
+        token_type_ids: Option<&Tensor>,
+        position_ids: Option<&Tensor>,
+        input_embeds: Option<&Tensor>,
         train: bool,
     ) -> Result<Tensor, RustBertError> {
-        let (input_embeddings, input_shape) = match input_ids {
-            Some(input_value) => match input_embeds {
-                Some(_) => {
-                    return Err(RustBertError::ValueError(
-                        "Only one of input ids or input embeddings may be set".into(),
-                    ));
-                }
-                None => (
-                    input_value.apply_t(&self.word_embeddings, train),
-                    input_value.size(),
-                ),
-            },
-            None => match input_embeds {
-                Some(embeds) => {
-                    let size = vec![embeds.size()[0], embeds.size()[1]];
-                    (embeds, size)
-                }
-                None => {
-                    return Err(RustBertError::ValueError(
-                        "At least one of input ids or input embeddings must be set".into(),
-                    ));
-                }
-            },
-        };
+        let (calc_input_embeddings, input_shape, _) =
+            process_ids_embeddings_pair(input_ids, input_embeds, &self.word_embeddings)?;

-        let seq_length = input_embeddings.as_ref().size()[1].to_owned();
+        let input_embeddings =
+            input_embeds.unwrap_or_else(|| calc_input_embeddings.as_ref().unwrap());
+        let seq_length = input_embeddings.size()[1].to_owned();

-        let position_ids = match position_ids {
-            Some(value) => value,
-            None => Tensor::arange(seq_length, (Kind::Int64, input_embeddings.device()))
-                .unsqueeze(0)
-                .expand(&input_shape, true),
+        let calc_position_ids = if position_ids.is_none() {
+            Some(
+                Tensor::arange(seq_length, (Kind::Int64, input_embeddings.device()))
+                    .unsqueeze(0)
+                    .expand(&input_shape, true),
+            )
+        } else {
+            None
         };
+        let position_ids = position_ids.unwrap_or_else(|| calc_position_ids.as_ref().unwrap());

-        let token_type_ids = match token_type_ids {
-            Some(value) => value,
-            None => Tensor::zeros(&input_shape, (Kind::Int64, input_embeddings.device())),
+        let calc_token_type_ids = if token_type_ids.is_none() {
+            Some(Tensor::zeros(
+                &input_shape,
+                (Kind::Int64, input_embeddings.device()),
+            ))
+        } else {
+            None
         };
+        let token_type_ids =
+            token_type_ids.unwrap_or_else(|| calc_token_type_ids.as_ref().unwrap());

         let position_embeddings = position_ids.apply(&self.position_embeddings);
         let token_type_embeddings = token_type_ids.apply(&self.token_type_embeddings);
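The repeated ids/embeddings validation now lives in `process_ids_embeddings_pair` (in `crate::common::embeddings`). Its definition is not part of this diff; from the call site it plausibly checks that exactly one of the two inputs is set and returns optionally-computed embeddings plus the input shape. A hypothetical stand-in for that contract (names and types are assumptions, not the real helper):

```rust
// Hypothetical sketch: embeddings are only computed when ids are given;
// otherwise the caller keeps borrowing its own input_embeds.
fn process_pair(
    ids: Option<&Vec<i64>>,
    embeds: Option<&Vec<f32>>,
) -> Result<(Option<Vec<f32>>, usize), String> {
    match (ids, embeds) {
        (Some(_), Some(_)) => Err("Only one of input ids or input embeddings may be set".into()),
        (Some(ids), None) => Ok((Some(ids.iter().map(|&i| i as f32).collect()), ids.len())),
        (None, Some(embeds)) => Ok((None, embeds.len())),
        (None, None) => Err("At least one of input ids or input embeddings must be set".into()),
    }
}

fn main() {
    let ids = vec![101_i64, 102];
    let (calc, shape) = process_pair(Some(&ids), None).unwrap();
    assert_eq!((calc.unwrap().len(), shape), (2, 2));
}
```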


@@ -545,12 +545,12 @@ impl TokenClassificationOption {
             Self::Longformer(ref model) => {
                 model
                     .forward_t(
-                        input_ids.as_ref(),
-                        mask.as_ref(),
+                        input_ids,
+                        mask,
                         None,
-                        token_type_ids.as_ref(),
-                        position_ids.as_ref(),
-                        input_embeds.as_ref(),
+                        token_type_ids,
+                        position_ids,
+                        input_embeds,
                         train,
                     )
                     .expect("Error in longformer forward_t")
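The `.as_ref()` calls could be dropped here because the surrounding pipeline code now holds `Option<&Tensor>` values directly. For reference, `Option::as_ref` is what bridges the two styles, borrowing the contents of an `Option` without consuming it:

```rust
fn main() {
    let owned: Option<String> = Some("mask".to_string());

    // as_ref borrows the contents: &Option<String> -> Option<&String>.
    let borrowed: Option<&String> = owned.as_ref();
    assert_eq!(borrowed.map(|s| s.len()), Some(4));
    assert!(owned.is_some()); // still owned and usable afterwards
}
```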


@@ -390,38 +390,34 @@ impl XLNetModel {
         perm_mask: Option<&Tensor>,
         target_mapping: Option<&Tensor>,
         token_type_ids: Option<&Tensor>,
-        input_embeds: Option<Tensor>,
+        input_embeds: Option<&Tensor>,
         train: bool,
     ) -> Result<XLNetModelOutput, RustBertError> {
-        let (word_emb_k, input_shape) = match input_ids {
-            Some(input_value) => match input_embeds {
-                Some(_) => {
-                    return Err(RustBertError::ValueError(
-                        "Only one of input ids or input embeddings may be set".into(),
-                    ));
-                }
-                None => {
-                    let size = input_value.size();
-                    (
-                        input_value
-                            .transpose(0, 1)
-                            .contiguous()
-                            .apply_t(&self.word_embeddings, train),
-                        vec![size[1], size[0]],
-                    )
-                }
-            },
-            None => match input_embeds {
-                Some(embeds) => {
-                    let size = vec![embeds.size()[1], embeds.size()[0]];
-                    (embeds.transpose(0, 1).contiguous(), size)
-                }
-                None => {
-                    return Err(RustBertError::ValueError(
-                        "At least one of input ids or input embeddings must be set".into(),
-                    ));
-                }
-            },
+        let (word_emb_k, input_shape) = match (input_ids, input_embeds) {
+            (Some(_), Some(_)) => {
+                return Err(RustBertError::ValueError(
+                    "Only one of input ids or input embeddings may be set".into(),
+                ));
+            }
+            (Some(input_value), None) => {
+                let size = input_value.size();
+                (
+                    input_value
+                        .transpose(0, 1)
+                        .contiguous()
+                        .apply_t(&self.word_embeddings, train),
+                    vec![size[1], size[0]],
+                )
+            }
+            (None, Some(embeds)) => {
+                let size = vec![embeds.size()[1], embeds.size()[0]];
+                (embeds.transpose(0, 1).contiguous(), size)
+            }
+            (None, None) => {
+                return Err(RustBertError::ValueError(
+                    "At least one of input ids or input embeddings must be set".into(),
+                ));
+            }
         };

         let token_type_ids =
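The XLNet refactor above also collapses the two nested matches into a single match over the tuple `(input_ids, input_embeds)`, which makes the four mutually exclusive cases explicit. The same shape in a minimal, runnable form (stand-in types, not the rust-bert code):

```rust
// Exactly one of the two inputs must be provided.
fn resolve<'a>(ids: Option<&'a str>, embeds: Option<&'a str>) -> Result<&'a str, String> {
    match (ids, embeds) {
        (Some(_), Some(_)) => Err("Only one of input ids or input embeddings may be set".into()),
        (Some(ids), None) => Ok(ids),
        (None, Some(embeds)) => Ok(embeds),
        (None, None) => Err("At least one of input ids or input embeddings must be set".into()),
    }
}

fn main() {
    assert!(resolve(Some("ids"), None).is_ok());
    assert!(resolve(None, None).is_err());
}
```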
@@ -715,7 +711,7 @@ impl XLNetLMHeadModel {
         perm_mask: Option<&Tensor>,
         target_mapping: Option<&Tensor>,
         token_type_ids: Option<&Tensor>,
-        input_embeds: Option<Tensor>,
+        input_embeds: Option<&Tensor>,
         train: bool,
     ) -> Result<LMModelOutput, RustBertError> {
         let base_model_output = self.base_model.forward_t(
@@ -966,7 +962,7 @@ impl XLNetForSequenceClassification {
         perm_mask: Option<&Tensor>,
         target_mapping: Option<&Tensor>,
         token_type_ids: Option<&Tensor>,
-        input_embeds: Option<Tensor>,
+        input_embeds: Option<&Tensor>,
         train: bool,
     ) -> XLNetSequenceClassificationOutput {
         let base_model_output = self
@@ -1124,7 +1120,7 @@ impl XLNetForTokenClassification {
         perm_mask: Option<&Tensor>,
         target_mapping: Option<&Tensor>,
         token_type_ids: Option<&Tensor>,
-        input_embeds: Option<Tensor>,
+        input_embeds: Option<&Tensor>,
         train: bool,
     ) -> XLNetTokenClassificationOutput {
         let base_model_output = self
@@ -1273,7 +1269,7 @@ impl XLNetForMultipleChoice {
         perm_mask: Option<&Tensor>,
         target_mapping: Option<&Tensor>,
         token_type_ids: Option<&Tensor>,
-        input_embeds: Option<Tensor>,
+        input_embeds: Option<&Tensor>,
         train: bool,
     ) -> XLNetSequenceClassificationOutput {
         let (input_ids, num_choices) = match input_ids {
@@ -1305,7 +1301,7 @@ impl XLNetForMultipleChoice {
                 perm_mask,
                 target_mapping,
                 token_type_ids.as_ref(),
-                input_embeds,
+                input_embeds.as_ref(),
                 train,
             )
             .unwrap();
@@ -1444,7 +1440,7 @@ impl XLNetForQuestionAnswering {
         perm_mask: Option<&Tensor>,
         target_mapping: Option<&Tensor>,
         token_type_ids: Option<&Tensor>,
-        input_embeds: Option<Tensor>,
+        input_embeds: Option<&Tensor>,
         train: bool,
     ) -> XLNetQuestionAnsweringOutput {
         let base_model_output = self


@@ -62,7 +62,7 @@ fn albert_masked_lm() -> anyhow::Result<()> {

     // Forward pass
     let model_output =
-        no_grad(|| albert_model.forward_t(Some(input_tensor), None, None, None, None, false));
+        no_grad(|| albert_model.forward_t(Some(&input_tensor), None, None, None, None, false));

     // Print masked tokens
     let index_1 = model_output
@@ -135,7 +135,7 @@ fn albert_for_sequence_classification() -> anyhow::Result<()> {

     // Forward pass
     let model_output =
-        no_grad(|| albert_model.forward_t(Some(input_tensor), None, None, None, None, false));
+        no_grad(|| albert_model.forward_t(Some(&input_tensor), None, None, None, None, false));

     assert_eq!(model_output.logits.size(), &[2, 3]);
     assert_eq!(
@@ -199,7 +199,7 @@ fn albert_for_multiple_choice() -> anyhow::Result<()> {
     // Forward pass
     let model_output = no_grad(|| {
         albert_model
-            .forward_t(Some(input_tensor), None, None, None, None, false)
+            .forward_t(Some(&input_tensor), None, None, None, None, false)
             .unwrap()
     });

@@ -268,7 +268,7 @@ fn albert_for_token_classification() -> anyhow::Result<()> {

     // Forward pass
     let model_output =
-        no_grad(|| bert_model.forward_t(Some(input_tensor), None, None, None, None, false));
+        no_grad(|| bert_model.forward_t(Some(&input_tensor), None, None, None, None, false));

     assert_eq!(model_output.logits.size(), &[2, 12, 4]);
     assert_eq!(
@@ -329,7 +329,7 @@ fn albert_for_question_answering() -> anyhow::Result<()> {

     // Forward pass
     let model_output =
-        no_grad(|| albert_model.forward_t(Some(input_tensor), None, None, None, None, false));
+        no_grad(|| albert_model.forward_t(Some(&input_tensor), None, None, None, None, false));

     assert_eq!(model_output.start_logits.size(), &[2, 12]);
     assert_eq!(model_output.end_logits.size(), &[2, 12]);
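All of these integration tests wrap the forward pass in `no_grad`, which runs a closure with gradient tracking disabled and returns its result. A stand-in sketch of that shape (the real guard lives in tch; this mock only mirrors the closure-in, value-out API):

```rust
// Mock of the no_grad pattern: run `f` with gradients disabled, return its value.
fn no_grad_mock<T>(f: impl FnOnce() -> T) -> T {
    // a real implementation would switch off gradient recording here,
    let out = f();
    // and restore the previous mode here before returning.
    out
}

fn main() {
    let logits = no_grad_mock(|| vec![0.1_f32, 0.9]);
    assert_eq!(logits.len(), 2);
}
```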


@@ -72,13 +72,13 @@ fn bert_masked_lm() -> anyhow::Result<()> {
     // Forward pass
     let model_output = no_grad(|| {
         bert_model.forward_t(
-            Some(input_tensor),
+            Some(&input_tensor),
             None,
             None,
             None,
             None,
-            None,
-            None,
+            &None,
+            &None,
             false,
         )
     });
@@ -152,7 +152,7 @@ fn bert_for_sequence_classification() -> anyhow::Result<()> {

     // Forward pass
     let model_output =
-        no_grad(|| bert_model.forward_t(Some(input_tensor), None, None, None, None, false));
+        no_grad(|| bert_model.forward_t(Some(&input_tensor), None, None, None, None, false));

     assert_eq!(model_output.logits.size(), &[2, 3]);
     assert_eq!(
@@ -212,7 +212,7 @@ fn bert_for_multiple_choice() -> anyhow::Result<()> {
         .unsqueeze(0);

     // Forward pass
-    let model_output = no_grad(|| bert_model.forward_t(input_tensor, None, None, None, false));
+    let model_output = no_grad(|| bert_model.forward_t(&input_tensor, None, None, None, false));

     assert_eq!(model_output.logits.size(), &[1, 2]);
     assert_eq!(
@@ -277,7 +277,7 @@ fn bert_for_token_classification() -> anyhow::Result<()> {

     // Forward pass
     let model_output =
-        no_grad(|| bert_model.forward_t(Some(input_tensor), None, None, None, None, false));
+        no_grad(|| bert_model.forward_t(Some(&input_tensor), None, None, None, None, false));

     assert_eq!(model_output.logits.size(), &[2, 11, 4]);
     assert_eq!(
@@ -336,7 +336,7 @@ fn bert_for_question_answering() -> anyhow::Result<()> {

     // Forward pass
     let model_output =
-        no_grad(|| bert_model.forward_t(Some(input_tensor), None, None, None, None, false));
+        no_grad(|| bert_model.forward_t(Some(&input_tensor), None, None, None, None, false));

     assert_eq!(model_output.start_logits.size(), &[2, 11]);
     assert_eq!(model_output.end_logits.size(), &[2, 11]);


@@ -96,7 +96,7 @@ fn distilbert_masked_lm() -> anyhow::Result<()> {
     // Forward pass
     let model_output = no_grad(|| {
         distil_bert_model
-            .forward_t(Some(input_tensor), None, None, false)
+            .forward_t(Some(&input_tensor), None, None, false)
             .unwrap()
     });

@@ -167,7 +167,7 @@ fn distilbert_for_question_answering() -> anyhow::Result<()> {
     // Forward pass
     let model_output = no_grad(|| {
         distil_bert_model
-            .forward_t(Some(input_tensor), None, None, false)
+            .forward_t(Some(&input_tensor), None, None, false)
             .unwrap()
     });

@@ -238,7 +238,7 @@ fn distilbert_for_token_classification() -> anyhow::Result<()> {
     // Forward pass
     let model_output = no_grad(|| {
         distil_bert_model
-            .forward_t(Some(input_tensor), None, None, false)
+            .forward_t(Some(&input_tensor), None, None, false)
             .unwrap()
     });



@@ -59,7 +59,7 @@ fn electra_masked_lm() -> anyhow::Result<()> {

     // Forward pass
     let model_output =
-        no_grad(|| electra_model.forward_t(Some(input_tensor), None, None, None, None, false));
+        no_grad(|| electra_model.forward_t(Some(&input_tensor), None, None, None, None, false));

     // Decode output
     let index_1 = model_output
@@ -138,7 +138,7 @@ fn electra_discriminator() -> anyhow::Result<()> {

     // Forward pass
     let model_output =
-        no_grad(|| electra_model.forward_t(Some(input_tensor), None, None, None, None, false));
+        no_grad(|| electra_model.forward_t(Some(&input_tensor), None, None, None, None, false));

     // Validate model predictions
     let expected_probabilities = vec![


@@ -82,13 +82,13 @@ fn roberta_masked_lm() -> anyhow::Result<()> {
     // Forward pass
     let model_output = no_grad(|| {
         roberta_model.forward_t(
-            Some(input_tensor),
+            Some(&input_tensor),
             None,
             None,
             None,
             None,
-            None,
-            None,
+            &None,
+            &None,
             false,
         )
     });
@@ -172,7 +172,7 @@ fn roberta_for_sequence_classification() -> anyhow::Result<()> {

     // Forward pass
     let model_output =
-        no_grad(|| roberta_model.forward_t(Some(input_tensor), None, None, None, None, false));
+        no_grad(|| roberta_model.forward_t(Some(&input_tensor), None, None, None, None, false));

     assert_eq!(model_output.logits.size(), &[2, 3]);
     assert_eq!(
@@ -242,7 +242,7 @@ fn roberta_for_multiple_choice() -> anyhow::Result<()> {
         .unsqueeze(0);

     // Forward pass
-    let model_output = no_grad(|| roberta_model.forward_t(input_tensor, None, None, None, false));
+    let model_output = no_grad(|| roberta_model.forward_t(&input_tensor, None, None, None, false));

     assert_eq!(model_output.logits.size(), &[1, 2]);
     assert_eq!(
@@ -317,7 +317,7 @@ fn roberta_for_token_classification() -> anyhow::Result<()> {

     // Forward pass
     let model_output =
-        no_grad(|| roberta_model.forward_t(Some(input_tensor), None, None, None, None, false));
+        no_grad(|| roberta_model.forward_t(Some(&input_tensor), None, None, None, None, false));

     assert_eq!(model_output.logits.size(), &[2, 9, 4]);
     assert_eq!(