mirror of
https://github.com/guillaume-be/rust-bert.git
synced 2024-08-16 16:10:25 +03:00
Updated borrowing for XLNet, integration tests
This commit is contained in:
parent
2baf659e9b
commit
cb6bc34eb4
@ -67,13 +67,13 @@ fn main() -> anyhow::Result<()> {
|
||||
// Forward pass
|
||||
let model_output = no_grad(|| {
|
||||
bert_model.forward_t(
|
||||
Some(input_tensor),
|
||||
Some(&input_tensor),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
&None,
|
||||
&None,
|
||||
false,
|
||||
)
|
||||
});
|
||||
|
@ -15,6 +15,7 @@
|
||||
use crate::bert::BertConfig;
|
||||
use crate::common::activations::Activation;
|
||||
use crate::common::dropout::Dropout;
|
||||
use crate::common::embeddings::get_shape_and_device_from_ids_embeddings_pair;
|
||||
use crate::electra::embeddings::ElectraEmbeddings;
|
||||
use crate::{bert::encoder::BertEncoder, common::activations::TensorFunction};
|
||||
use crate::{Config, RustBertError};
|
||||
@ -216,10 +217,10 @@ impl ElectraModel {
|
||||
/// let model_output = no_grad(|| {
|
||||
/// electra_model
|
||||
/// .forward_t(
|
||||
/// Some(input_tensor),
|
||||
/// Some(mask),
|
||||
/// Some(token_type_ids),
|
||||
/// Some(position_ids),
|
||||
/// Some(&input_tensor),
|
||||
/// Some(&mask),
|
||||
/// Some(&token_type_ids),
|
||||
/// Some(&position_ids),
|
||||
/// None,
|
||||
/// false,
|
||||
/// )
|
||||
@ -228,33 +229,22 @@ impl ElectraModel {
|
||||
/// ```
|
||||
pub fn forward_t(
|
||||
&self,
|
||||
input_ids: Option<Tensor>,
|
||||
mask: Option<Tensor>,
|
||||
token_type_ids: Option<Tensor>,
|
||||
position_ids: Option<Tensor>,
|
||||
input_embeds: Option<Tensor>,
|
||||
input_ids: Option<&Tensor>,
|
||||
mask: Option<&Tensor>,
|
||||
token_type_ids: Option<&Tensor>,
|
||||
position_ids: Option<&Tensor>,
|
||||
input_embeds: Option<&Tensor>,
|
||||
train: bool,
|
||||
) -> Result<ElectraModelOutput, RustBertError> {
|
||||
let (input_shape, device) = match &input_ids {
|
||||
Some(input_value) => match &input_embeds {
|
||||
Some(_) => {
|
||||
return Err(RustBertError::ValueError(
|
||||
"Only one of input ids or input embeddings may be set".into(),
|
||||
));
|
||||
}
|
||||
None => (input_value.size(), input_value.device()),
|
||||
},
|
||||
None => match &input_embeds {
|
||||
Some(embeds) => (vec![embeds.size()[0], embeds.size()[1]], embeds.device()),
|
||||
None => {
|
||||
return Err(RustBertError::ValueError(
|
||||
"At least one of input ids or input embeddings must be set".into(),
|
||||
));
|
||||
}
|
||||
},
|
||||
};
|
||||
let (input_shape, device) =
|
||||
get_shape_and_device_from_ids_embeddings_pair(input_ids, input_embeds)?;
|
||||
|
||||
let mask = mask.unwrap_or_else(|| Tensor::ones(&input_shape, (Kind::Int64, device)));
|
||||
let calc_mask = if mask.is_none() {
|
||||
Some(Tensor::ones(&input_shape, (Kind::Int64, device)))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let mask = mask.unwrap_or_else(|| calc_mask.as_ref().unwrap());
|
||||
|
||||
let extended_attention_mask = match mask.dim() {
|
||||
3 => mask.unsqueeze(1),
|
||||
@ -590,10 +580,10 @@ impl ElectraForMaskedLM {
|
||||
///
|
||||
/// let model_output = no_grad(|| {
|
||||
/// electra_model.forward_t(
|
||||
/// Some(input_tensor),
|
||||
/// Some(mask),
|
||||
/// Some(token_type_ids),
|
||||
/// Some(position_ids),
|
||||
/// Some(&input_tensor),
|
||||
/// Some(&mask),
|
||||
/// Some(&token_type_ids),
|
||||
/// Some(&position_ids),
|
||||
/// None,
|
||||
/// false,
|
||||
/// )
|
||||
@ -601,11 +591,11 @@ impl ElectraForMaskedLM {
|
||||
/// ```
|
||||
pub fn forward_t(
|
||||
&self,
|
||||
input_ids: Option<Tensor>,
|
||||
mask: Option<Tensor>,
|
||||
token_type_ids: Option<Tensor>,
|
||||
position_ids: Option<Tensor>,
|
||||
input_embeds: Option<Tensor>,
|
||||
input_ids: Option<&Tensor>,
|
||||
mask: Option<&Tensor>,
|
||||
token_type_ids: Option<&Tensor>,
|
||||
position_ids: Option<&Tensor>,
|
||||
input_embeds: Option<&Tensor>,
|
||||
train: bool,
|
||||
) -> ElectraMaskedLMOutput {
|
||||
let base_model_output = self
|
||||
@ -717,21 +707,21 @@ impl ElectraDiscriminator {
|
||||
///
|
||||
/// let model_output = no_grad(|| {
|
||||
/// electra_model
|
||||
/// .forward_t(Some(input_tensor),
|
||||
/// Some(mask),
|
||||
/// Some(token_type_ids),
|
||||
/// Some(position_ids),
|
||||
/// .forward_t(Some(&input_tensor),
|
||||
/// Some(&mask),
|
||||
/// Some(&token_type_ids),
|
||||
/// Some(&position_ids),
|
||||
/// None,
|
||||
/// false)
|
||||
/// });
|
||||
/// ```
|
||||
pub fn forward_t(
|
||||
&self,
|
||||
input_ids: Option<Tensor>,
|
||||
mask: Option<Tensor>,
|
||||
token_type_ids: Option<Tensor>,
|
||||
position_ids: Option<Tensor>,
|
||||
input_embeds: Option<Tensor>,
|
||||
input_ids: Option<&Tensor>,
|
||||
mask: Option<&Tensor>,
|
||||
token_type_ids: Option<&Tensor>,
|
||||
position_ids: Option<&Tensor>,
|
||||
input_embeds: Option<&Tensor>,
|
||||
train: bool,
|
||||
) -> ElectraDiscriminatorOutput {
|
||||
let base_model_output = self
|
||||
@ -858,21 +848,21 @@ impl ElectraForTokenClassification {
|
||||
///
|
||||
/// let model_output = no_grad(|| {
|
||||
/// electra_model
|
||||
/// .forward_t(Some(input_tensor),
|
||||
/// Some(mask),
|
||||
/// Some(token_type_ids),
|
||||
/// Some(position_ids),
|
||||
/// .forward_t(Some(&input_tensor),
|
||||
/// Some(&mask),
|
||||
/// Some(&token_type_ids),
|
||||
/// Some(&position_ids),
|
||||
/// None,
|
||||
/// false)
|
||||
/// });
|
||||
/// ```
|
||||
pub fn forward_t(
|
||||
&self,
|
||||
input_ids: Option<Tensor>,
|
||||
mask: Option<Tensor>,
|
||||
token_type_ids: Option<Tensor>,
|
||||
position_ids: Option<Tensor>,
|
||||
input_embeds: Option<Tensor>,
|
||||
input_ids: Option<&Tensor>,
|
||||
mask: Option<&Tensor>,
|
||||
token_type_ids: Option<&Tensor>,
|
||||
position_ids: Option<&Tensor>,
|
||||
input_embeds: Option<&Tensor>,
|
||||
train: bool,
|
||||
) -> ElectraTokenClassificationOutput {
|
||||
let base_model_output = self
|
||||
|
@ -13,6 +13,7 @@
|
||||
// limitations under the License.
|
||||
|
||||
use crate::common::dropout::Dropout;
|
||||
use crate::common::embeddings::process_ids_embeddings_pair;
|
||||
use crate::electra::electra_model::ElectraConfig;
|
||||
use crate::RustBertError;
|
||||
use std::borrow::Borrow;
|
||||
@ -84,50 +85,40 @@ impl ElectraEmbeddings {
|
||||
|
||||
pub fn forward_t(
|
||||
&self,
|
||||
input_ids: Option<Tensor>,
|
||||
token_type_ids: Option<Tensor>,
|
||||
position_ids: Option<Tensor>,
|
||||
input_embeds: Option<Tensor>,
|
||||
input_ids: Option<&Tensor>,
|
||||
token_type_ids: Option<&Tensor>,
|
||||
position_ids: Option<&Tensor>,
|
||||
input_embeds: Option<&Tensor>,
|
||||
train: bool,
|
||||
) -> Result<Tensor, RustBertError> {
|
||||
let (input_embeddings, input_shape) = match input_ids {
|
||||
Some(input_value) => match input_embeds {
|
||||
Some(_) => {
|
||||
return Err(RustBertError::ValueError(
|
||||
"Only one of input ids or input embeddings may be set".into(),
|
||||
));
|
||||
}
|
||||
None => (
|
||||
input_value.apply_t(&self.word_embeddings, train),
|
||||
input_value.size(),
|
||||
),
|
||||
},
|
||||
None => match input_embeds {
|
||||
Some(embeds) => {
|
||||
let size = vec![embeds.size()[0], embeds.size()[1]];
|
||||
(embeds, size)
|
||||
}
|
||||
None => {
|
||||
return Err(RustBertError::ValueError(
|
||||
"At least one of input ids or input embeddings must be set".into(),
|
||||
));
|
||||
}
|
||||
},
|
||||
};
|
||||
let (calc_input_embeddings, input_shape, _) =
|
||||
process_ids_embeddings_pair(input_ids, input_embeds, &self.word_embeddings)?;
|
||||
|
||||
let seq_length = input_embeddings.as_ref().size()[1].to_owned();
|
||||
let input_embeddings =
|
||||
input_embeds.unwrap_or_else(|| calc_input_embeddings.as_ref().unwrap());
|
||||
let seq_length = input_embeddings.size()[1].to_owned();
|
||||
|
||||
let position_ids = match position_ids {
|
||||
Some(value) => value,
|
||||
None => Tensor::arange(seq_length, (Kind::Int64, input_embeddings.device()))
|
||||
.unsqueeze(0)
|
||||
.expand(&input_shape, true),
|
||||
let calc_position_ids = if position_ids.is_none() {
|
||||
Some(
|
||||
Tensor::arange(seq_length, (Kind::Int64, input_embeddings.device()))
|
||||
.unsqueeze(0)
|
||||
.expand(&input_shape, true),
|
||||
)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let position_ids = position_ids.unwrap_or_else(|| calc_position_ids.as_ref().unwrap());
|
||||
|
||||
let token_type_ids = match token_type_ids {
|
||||
Some(value) => value,
|
||||
None => Tensor::zeros(&input_shape, (Kind::Int64, input_embeddings.device())),
|
||||
let calc_token_type_ids = if token_type_ids.is_none() {
|
||||
Some(Tensor::zeros(
|
||||
&input_shape,
|
||||
(Kind::Int64, input_embeddings.device()),
|
||||
))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let token_type_ids =
|
||||
token_type_ids.unwrap_or_else(|| calc_token_type_ids.as_ref().unwrap());
|
||||
|
||||
let position_embeddings = position_ids.apply(&self.position_embeddings);
|
||||
let token_type_embeddings = token_type_ids.apply(&self.token_type_embeddings);
|
||||
|
@ -545,12 +545,12 @@ impl TokenClassificationOption {
|
||||
Self::Longformer(ref model) => {
|
||||
model
|
||||
.forward_t(
|
||||
input_ids.as_ref(),
|
||||
mask.as_ref(),
|
||||
input_ids,
|
||||
mask,
|
||||
None,
|
||||
token_type_ids.as_ref(),
|
||||
position_ids.as_ref(),
|
||||
input_embeds.as_ref(),
|
||||
token_type_ids,
|
||||
position_ids,
|
||||
input_embeds,
|
||||
train,
|
||||
)
|
||||
.expect("Error in longformer forward_t")
|
||||
|
@ -390,38 +390,34 @@ impl XLNetModel {
|
||||
perm_mask: Option<&Tensor>,
|
||||
target_mapping: Option<&Tensor>,
|
||||
token_type_ids: Option<&Tensor>,
|
||||
input_embeds: Option<Tensor>,
|
||||
input_embeds: Option<&Tensor>,
|
||||
train: bool,
|
||||
) -> Result<XLNetModelOutput, RustBertError> {
|
||||
let (word_emb_k, input_shape) = match input_ids {
|
||||
Some(input_value) => match input_embeds {
|
||||
Some(_) => {
|
||||
return Err(RustBertError::ValueError(
|
||||
"Only one of input ids or input embeddings may be set".into(),
|
||||
));
|
||||
}
|
||||
None => {
|
||||
let size = input_value.size();
|
||||
(
|
||||
input_value
|
||||
.transpose(0, 1)
|
||||
.contiguous()
|
||||
.apply_t(&self.word_embeddings, train),
|
||||
vec![size[1], size[0]],
|
||||
)
|
||||
}
|
||||
},
|
||||
None => match input_embeds {
|
||||
Some(embeds) => {
|
||||
let size = vec![embeds.size()[1], embeds.size()[0]];
|
||||
(embeds.transpose(0, 1).contiguous(), size)
|
||||
}
|
||||
None => {
|
||||
return Err(RustBertError::ValueError(
|
||||
"At least one of input ids or input embeddings must be set".into(),
|
||||
));
|
||||
}
|
||||
},
|
||||
let (word_emb_k, input_shape) = match (input_ids, input_embeds) {
|
||||
(Some(_), Some(_)) => {
|
||||
return Err(RustBertError::ValueError(
|
||||
"Only one of input ids or input embeddings may be set".into(),
|
||||
));
|
||||
}
|
||||
(Some(input_value), None) => {
|
||||
let size = input_value.size();
|
||||
(
|
||||
input_value
|
||||
.transpose(0, 1)
|
||||
.contiguous()
|
||||
.apply_t(&self.word_embeddings, train),
|
||||
vec![size[1], size[0]],
|
||||
)
|
||||
}
|
||||
(None, Some(embeds)) => {
|
||||
let size = vec![embeds.size()[1], embeds.size()[0]];
|
||||
(embeds.transpose(0, 1).contiguous(), size)
|
||||
}
|
||||
(None, None) => {
|
||||
return Err(RustBertError::ValueError(
|
||||
"At least one of input ids or input embeddings must be set".into(),
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
let token_type_ids =
|
||||
@ -715,7 +711,7 @@ impl XLNetLMHeadModel {
|
||||
perm_mask: Option<&Tensor>,
|
||||
target_mapping: Option<&Tensor>,
|
||||
token_type_ids: Option<&Tensor>,
|
||||
input_embeds: Option<Tensor>,
|
||||
input_embeds: Option<&Tensor>,
|
||||
train: bool,
|
||||
) -> Result<LMModelOutput, RustBertError> {
|
||||
let base_model_output = self.base_model.forward_t(
|
||||
@ -966,7 +962,7 @@ impl XLNetForSequenceClassification {
|
||||
perm_mask: Option<&Tensor>,
|
||||
target_mapping: Option<&Tensor>,
|
||||
token_type_ids: Option<&Tensor>,
|
||||
input_embeds: Option<Tensor>,
|
||||
input_embeds: Option<&Tensor>,
|
||||
train: bool,
|
||||
) -> XLNetSequenceClassificationOutput {
|
||||
let base_model_output = self
|
||||
@ -1124,7 +1120,7 @@ impl XLNetForTokenClassification {
|
||||
perm_mask: Option<&Tensor>,
|
||||
target_mapping: Option<&Tensor>,
|
||||
token_type_ids: Option<&Tensor>,
|
||||
input_embeds: Option<Tensor>,
|
||||
input_embeds: Option<&Tensor>,
|
||||
train: bool,
|
||||
) -> XLNetTokenClassificationOutput {
|
||||
let base_model_output = self
|
||||
@ -1273,7 +1269,7 @@ impl XLNetForMultipleChoice {
|
||||
perm_mask: Option<&Tensor>,
|
||||
target_mapping: Option<&Tensor>,
|
||||
token_type_ids: Option<&Tensor>,
|
||||
input_embeds: Option<Tensor>,
|
||||
input_embeds: Option<&Tensor>,
|
||||
train: bool,
|
||||
) -> XLNetSequenceClassificationOutput {
|
||||
let (input_ids, num_choices) = match input_ids {
|
||||
@ -1305,7 +1301,7 @@ impl XLNetForMultipleChoice {
|
||||
perm_mask,
|
||||
target_mapping,
|
||||
token_type_ids.as_ref(),
|
||||
input_embeds,
|
||||
input_embeds.as_ref(),
|
||||
train,
|
||||
)
|
||||
.unwrap();
|
||||
@ -1444,7 +1440,7 @@ impl XLNetForQuestionAnswering {
|
||||
perm_mask: Option<&Tensor>,
|
||||
target_mapping: Option<&Tensor>,
|
||||
token_type_ids: Option<&Tensor>,
|
||||
input_embeds: Option<Tensor>,
|
||||
input_embeds: Option<&Tensor>,
|
||||
train: bool,
|
||||
) -> XLNetQuestionAnsweringOutput {
|
||||
let base_model_output = self
|
||||
|
@ -62,7 +62,7 @@ fn albert_masked_lm() -> anyhow::Result<()> {
|
||||
|
||||
// Forward pass
|
||||
let model_output =
|
||||
no_grad(|| albert_model.forward_t(Some(input_tensor), None, None, None, None, false));
|
||||
no_grad(|| albert_model.forward_t(Some(&input_tensor), None, None, None, None, false));
|
||||
|
||||
// Print masked tokens
|
||||
let index_1 = model_output
|
||||
@ -135,7 +135,7 @@ fn albert_for_sequence_classification() -> anyhow::Result<()> {
|
||||
|
||||
// Forward pass
|
||||
let model_output =
|
||||
no_grad(|| albert_model.forward_t(Some(input_tensor), None, None, None, None, false));
|
||||
no_grad(|| albert_model.forward_t(Some(&input_tensor), None, None, None, None, false));
|
||||
|
||||
assert_eq!(model_output.logits.size(), &[2, 3]);
|
||||
assert_eq!(
|
||||
@ -199,7 +199,7 @@ fn albert_for_multiple_choice() -> anyhow::Result<()> {
|
||||
// Forward pass
|
||||
let model_output = no_grad(|| {
|
||||
albert_model
|
||||
.forward_t(Some(input_tensor), None, None, None, None, false)
|
||||
.forward_t(Some(&input_tensor), None, None, None, None, false)
|
||||
.unwrap()
|
||||
});
|
||||
|
||||
@ -268,7 +268,7 @@ fn albert_for_token_classification() -> anyhow::Result<()> {
|
||||
|
||||
// Forward pass
|
||||
let model_output =
|
||||
no_grad(|| bert_model.forward_t(Some(input_tensor), None, None, None, None, false));
|
||||
no_grad(|| bert_model.forward_t(Some(&input_tensor), None, None, None, None, false));
|
||||
|
||||
assert_eq!(model_output.logits.size(), &[2, 12, 4]);
|
||||
assert_eq!(
|
||||
@ -329,7 +329,7 @@ fn albert_for_question_answering() -> anyhow::Result<()> {
|
||||
|
||||
// Forward pass
|
||||
let model_output =
|
||||
no_grad(|| albert_model.forward_t(Some(input_tensor), None, None, None, None, false));
|
||||
no_grad(|| albert_model.forward_t(Some(&input_tensor), None, None, None, None, false));
|
||||
|
||||
assert_eq!(model_output.start_logits.size(), &[2, 12]);
|
||||
assert_eq!(model_output.end_logits.size(), &[2, 12]);
|
||||
|
@ -72,13 +72,13 @@ fn bert_masked_lm() -> anyhow::Result<()> {
|
||||
// Forward pass
|
||||
let model_output = no_grad(|| {
|
||||
bert_model.forward_t(
|
||||
Some(input_tensor),
|
||||
Some(&input_tensor),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
&None,
|
||||
&None,
|
||||
false,
|
||||
)
|
||||
});
|
||||
@ -152,7 +152,7 @@ fn bert_for_sequence_classification() -> anyhow::Result<()> {
|
||||
|
||||
// Forward pass
|
||||
let model_output =
|
||||
no_grad(|| bert_model.forward_t(Some(input_tensor), None, None, None, None, false));
|
||||
no_grad(|| bert_model.forward_t(Some(&input_tensor), None, None, None, None, false));
|
||||
|
||||
assert_eq!(model_output.logits.size(), &[2, 3]);
|
||||
assert_eq!(
|
||||
@ -212,7 +212,7 @@ fn bert_for_multiple_choice() -> anyhow::Result<()> {
|
||||
.unsqueeze(0);
|
||||
|
||||
// Forward pass
|
||||
let model_output = no_grad(|| bert_model.forward_t(input_tensor, None, None, None, false));
|
||||
let model_output = no_grad(|| bert_model.forward_t(&input_tensor, None, None, None, false));
|
||||
|
||||
assert_eq!(model_output.logits.size(), &[1, 2]);
|
||||
assert_eq!(
|
||||
@ -277,7 +277,7 @@ fn bert_for_token_classification() -> anyhow::Result<()> {
|
||||
|
||||
// Forward pass
|
||||
let model_output =
|
||||
no_grad(|| bert_model.forward_t(Some(input_tensor), None, None, None, None, false));
|
||||
no_grad(|| bert_model.forward_t(Some(&input_tensor), None, None, None, None, false));
|
||||
|
||||
assert_eq!(model_output.logits.size(), &[2, 11, 4]);
|
||||
assert_eq!(
|
||||
@ -336,7 +336,7 @@ fn bert_for_question_answering() -> anyhow::Result<()> {
|
||||
|
||||
// Forward pass
|
||||
let model_output =
|
||||
no_grad(|| bert_model.forward_t(Some(input_tensor), None, None, None, None, false));
|
||||
no_grad(|| bert_model.forward_t(Some(&input_tensor), None, None, None, None, false));
|
||||
|
||||
assert_eq!(model_output.start_logits.size(), &[2, 11]);
|
||||
assert_eq!(model_output.end_logits.size(), &[2, 11]);
|
||||
|
@ -96,7 +96,7 @@ fn distilbert_masked_lm() -> anyhow::Result<()> {
|
||||
// Forward pass
|
||||
let model_output = no_grad(|| {
|
||||
distil_bert_model
|
||||
.forward_t(Some(input_tensor), None, None, false)
|
||||
.forward_t(Some(&input_tensor), None, None, false)
|
||||
.unwrap()
|
||||
});
|
||||
|
||||
@ -167,7 +167,7 @@ fn distilbert_for_question_answering() -> anyhow::Result<()> {
|
||||
// Forward pass
|
||||
let model_output = no_grad(|| {
|
||||
distil_bert_model
|
||||
.forward_t(Some(input_tensor), None, None, false)
|
||||
.forward_t(Some(&input_tensor), None, None, false)
|
||||
.unwrap()
|
||||
});
|
||||
|
||||
@ -238,7 +238,7 @@ fn distilbert_for_token_classification() -> anyhow::Result<()> {
|
||||
// Forward pass
|
||||
let model_output = no_grad(|| {
|
||||
distil_bert_model
|
||||
.forward_t(Some(input_tensor), None, None, false)
|
||||
.forward_t(Some(&input_tensor), None, None, false)
|
||||
.unwrap()
|
||||
});
|
||||
|
||||
|
@ -59,7 +59,7 @@ fn electra_masked_lm() -> anyhow::Result<()> {
|
||||
|
||||
// Forward pass
|
||||
let model_output =
|
||||
no_grad(|| electra_model.forward_t(Some(input_tensor), None, None, None, None, false));
|
||||
no_grad(|| electra_model.forward_t(Some(&input_tensor), None, None, None, None, false));
|
||||
|
||||
// Decode output
|
||||
let index_1 = model_output
|
||||
@ -138,7 +138,7 @@ fn electra_discriminator() -> anyhow::Result<()> {
|
||||
|
||||
// Forward pass
|
||||
let model_output =
|
||||
no_grad(|| electra_model.forward_t(Some(input_tensor), None, None, None, None, false));
|
||||
no_grad(|| electra_model.forward_t(Some(&input_tensor), None, None, None, None, false));
|
||||
|
||||
// Validate model predictions
|
||||
let expected_probabilities = vec![
|
||||
|
@ -82,13 +82,13 @@ fn roberta_masked_lm() -> anyhow::Result<()> {
|
||||
// Forward pass
|
||||
let model_output = no_grad(|| {
|
||||
roberta_model.forward_t(
|
||||
Some(input_tensor),
|
||||
Some(&input_tensor),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
&None,
|
||||
&None,
|
||||
false,
|
||||
)
|
||||
});
|
||||
@ -172,7 +172,7 @@ fn roberta_for_sequence_classification() -> anyhow::Result<()> {
|
||||
|
||||
// Forward pass
|
||||
let model_output =
|
||||
no_grad(|| roberta_model.forward_t(Some(input_tensor), None, None, None, None, false));
|
||||
no_grad(|| roberta_model.forward_t(Some(&input_tensor), None, None, None, None, false));
|
||||
|
||||
assert_eq!(model_output.logits.size(), &[2, 3]);
|
||||
assert_eq!(
|
||||
@ -242,7 +242,7 @@ fn roberta_for_multiple_choice() -> anyhow::Result<()> {
|
||||
.unsqueeze(0);
|
||||
|
||||
// Forward pass
|
||||
let model_output = no_grad(|| roberta_model.forward_t(input_tensor, None, None, None, false));
|
||||
let model_output = no_grad(|| roberta_model.forward_t(&input_tensor, None, None, None, false));
|
||||
|
||||
assert_eq!(model_output.logits.size(), &[1, 2]);
|
||||
assert_eq!(
|
||||
@ -317,7 +317,7 @@ fn roberta_for_token_classification() -> anyhow::Result<()> {
|
||||
|
||||
// Forward pass
|
||||
let model_output =
|
||||
no_grad(|| roberta_model.forward_t(Some(input_tensor), None, None, None, None, false));
|
||||
no_grad(|| roberta_model.forward_t(Some(&input_tensor), None, None, None, None, false));
|
||||
|
||||
assert_eq!(model_output.logits.size(), &[2, 9, 4]);
|
||||
assert_eq!(
|
||||
|
Loading…
Reference in New Issue
Block a user