mirror of
https://github.com/guillaume-be/rust-bert.git
synced 2024-10-26 14:07:25 +03:00
Updated borrowing for XLNet, integration tests
This commit is contained in:
parent
2baf659e9b
commit
cb6bc34eb4
@ -67,13 +67,13 @@ fn main() -> anyhow::Result<()> {
|
|||||||
// Forward pass
|
// Forward pass
|
||||||
let model_output = no_grad(|| {
|
let model_output = no_grad(|| {
|
||||||
bert_model.forward_t(
|
bert_model.forward_t(
|
||||||
Some(input_tensor),
|
Some(&input_tensor),
|
||||||
|
None,
|
||||||
|
None,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
&None,
|
|
||||||
&None,
|
|
||||||
false,
|
false,
|
||||||
)
|
)
|
||||||
});
|
});
|
||||||
|
@ -15,6 +15,7 @@
|
|||||||
use crate::bert::BertConfig;
|
use crate::bert::BertConfig;
|
||||||
use crate::common::activations::Activation;
|
use crate::common::activations::Activation;
|
||||||
use crate::common::dropout::Dropout;
|
use crate::common::dropout::Dropout;
|
||||||
|
use crate::common::embeddings::get_shape_and_device_from_ids_embeddings_pair;
|
||||||
use crate::electra::embeddings::ElectraEmbeddings;
|
use crate::electra::embeddings::ElectraEmbeddings;
|
||||||
use crate::{bert::encoder::BertEncoder, common::activations::TensorFunction};
|
use crate::{bert::encoder::BertEncoder, common::activations::TensorFunction};
|
||||||
use crate::{Config, RustBertError};
|
use crate::{Config, RustBertError};
|
||||||
@ -216,10 +217,10 @@ impl ElectraModel {
|
|||||||
/// let model_output = no_grad(|| {
|
/// let model_output = no_grad(|| {
|
||||||
/// electra_model
|
/// electra_model
|
||||||
/// .forward_t(
|
/// .forward_t(
|
||||||
/// Some(input_tensor),
|
/// Some(&input_tensor),
|
||||||
/// Some(mask),
|
/// Some(&mask),
|
||||||
/// Some(token_type_ids),
|
/// Some(&token_type_ids),
|
||||||
/// Some(position_ids),
|
/// Some(&position_ids),
|
||||||
/// None,
|
/// None,
|
||||||
/// false,
|
/// false,
|
||||||
/// )
|
/// )
|
||||||
@ -228,33 +229,22 @@ impl ElectraModel {
|
|||||||
/// ```
|
/// ```
|
||||||
pub fn forward_t(
|
pub fn forward_t(
|
||||||
&self,
|
&self,
|
||||||
input_ids: Option<Tensor>,
|
input_ids: Option<&Tensor>,
|
||||||
mask: Option<Tensor>,
|
mask: Option<&Tensor>,
|
||||||
token_type_ids: Option<Tensor>,
|
token_type_ids: Option<&Tensor>,
|
||||||
position_ids: Option<Tensor>,
|
position_ids: Option<&Tensor>,
|
||||||
input_embeds: Option<Tensor>,
|
input_embeds: Option<&Tensor>,
|
||||||
train: bool,
|
train: bool,
|
||||||
) -> Result<ElectraModelOutput, RustBertError> {
|
) -> Result<ElectraModelOutput, RustBertError> {
|
||||||
let (input_shape, device) = match &input_ids {
|
let (input_shape, device) =
|
||||||
Some(input_value) => match &input_embeds {
|
get_shape_and_device_from_ids_embeddings_pair(input_ids, input_embeds)?;
|
||||||
Some(_) => {
|
|
||||||
return Err(RustBertError::ValueError(
|
|
||||||
"Only one of input ids or input embeddings may be set".into(),
|
|
||||||
));
|
|
||||||
}
|
|
||||||
None => (input_value.size(), input_value.device()),
|
|
||||||
},
|
|
||||||
None => match &input_embeds {
|
|
||||||
Some(embeds) => (vec![embeds.size()[0], embeds.size()[1]], embeds.device()),
|
|
||||||
None => {
|
|
||||||
return Err(RustBertError::ValueError(
|
|
||||||
"At least one of input ids or input embeddings must be set".into(),
|
|
||||||
));
|
|
||||||
}
|
|
||||||
},
|
|
||||||
};
|
|
||||||
|
|
||||||
let mask = mask.unwrap_or_else(|| Tensor::ones(&input_shape, (Kind::Int64, device)));
|
let calc_mask = if mask.is_none() {
|
||||||
|
Some(Tensor::ones(&input_shape, (Kind::Int64, device)))
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
let mask = mask.unwrap_or_else(|| calc_mask.as_ref().unwrap());
|
||||||
|
|
||||||
let extended_attention_mask = match mask.dim() {
|
let extended_attention_mask = match mask.dim() {
|
||||||
3 => mask.unsqueeze(1),
|
3 => mask.unsqueeze(1),
|
||||||
@ -590,10 +580,10 @@ impl ElectraForMaskedLM {
|
|||||||
///
|
///
|
||||||
/// let model_output = no_grad(|| {
|
/// let model_output = no_grad(|| {
|
||||||
/// electra_model.forward_t(
|
/// electra_model.forward_t(
|
||||||
/// Some(input_tensor),
|
/// Some(&input_tensor),
|
||||||
/// Some(mask),
|
/// Some(&mask),
|
||||||
/// Some(token_type_ids),
|
/// Some(&token_type_ids),
|
||||||
/// Some(position_ids),
|
/// Some(&position_ids),
|
||||||
/// None,
|
/// None,
|
||||||
/// false,
|
/// false,
|
||||||
/// )
|
/// )
|
||||||
@ -601,11 +591,11 @@ impl ElectraForMaskedLM {
|
|||||||
/// ```
|
/// ```
|
||||||
pub fn forward_t(
|
pub fn forward_t(
|
||||||
&self,
|
&self,
|
||||||
input_ids: Option<Tensor>,
|
input_ids: Option<&Tensor>,
|
||||||
mask: Option<Tensor>,
|
mask: Option<&Tensor>,
|
||||||
token_type_ids: Option<Tensor>,
|
token_type_ids: Option<&Tensor>,
|
||||||
position_ids: Option<Tensor>,
|
position_ids: Option<&Tensor>,
|
||||||
input_embeds: Option<Tensor>,
|
input_embeds: Option<&Tensor>,
|
||||||
train: bool,
|
train: bool,
|
||||||
) -> ElectraMaskedLMOutput {
|
) -> ElectraMaskedLMOutput {
|
||||||
let base_model_output = self
|
let base_model_output = self
|
||||||
@ -717,21 +707,21 @@ impl ElectraDiscriminator {
|
|||||||
///
|
///
|
||||||
/// let model_output = no_grad(|| {
|
/// let model_output = no_grad(|| {
|
||||||
/// electra_model
|
/// electra_model
|
||||||
/// .forward_t(Some(input_tensor),
|
/// .forward_t(Some(&input_tensor),
|
||||||
/// Some(mask),
|
/// Some(&mask),
|
||||||
/// Some(token_type_ids),
|
/// Some(&token_type_ids),
|
||||||
/// Some(position_ids),
|
/// Some(&position_ids),
|
||||||
/// None,
|
/// None,
|
||||||
/// false)
|
/// false)
|
||||||
/// });
|
/// });
|
||||||
/// ```
|
/// ```
|
||||||
pub fn forward_t(
|
pub fn forward_t(
|
||||||
&self,
|
&self,
|
||||||
input_ids: Option<Tensor>,
|
input_ids: Option<&Tensor>,
|
||||||
mask: Option<Tensor>,
|
mask: Option<&Tensor>,
|
||||||
token_type_ids: Option<Tensor>,
|
token_type_ids: Option<&Tensor>,
|
||||||
position_ids: Option<Tensor>,
|
position_ids: Option<&Tensor>,
|
||||||
input_embeds: Option<Tensor>,
|
input_embeds: Option<&Tensor>,
|
||||||
train: bool,
|
train: bool,
|
||||||
) -> ElectraDiscriminatorOutput {
|
) -> ElectraDiscriminatorOutput {
|
||||||
let base_model_output = self
|
let base_model_output = self
|
||||||
@ -858,21 +848,21 @@ impl ElectraForTokenClassification {
|
|||||||
///
|
///
|
||||||
/// let model_output = no_grad(|| {
|
/// let model_output = no_grad(|| {
|
||||||
/// electra_model
|
/// electra_model
|
||||||
/// .forward_t(Some(input_tensor),
|
/// .forward_t(Some(&input_tensor),
|
||||||
/// Some(mask),
|
/// Some(&mask),
|
||||||
/// Some(token_type_ids),
|
/// Some(&token_type_ids),
|
||||||
/// Some(position_ids),
|
/// Some(&position_ids),
|
||||||
/// None,
|
/// None,
|
||||||
/// false)
|
/// false)
|
||||||
/// });
|
/// });
|
||||||
/// ```
|
/// ```
|
||||||
pub fn forward_t(
|
pub fn forward_t(
|
||||||
&self,
|
&self,
|
||||||
input_ids: Option<Tensor>,
|
input_ids: Option<&Tensor>,
|
||||||
mask: Option<Tensor>,
|
mask: Option<&Tensor>,
|
||||||
token_type_ids: Option<Tensor>,
|
token_type_ids: Option<&Tensor>,
|
||||||
position_ids: Option<Tensor>,
|
position_ids: Option<&Tensor>,
|
||||||
input_embeds: Option<Tensor>,
|
input_embeds: Option<&Tensor>,
|
||||||
train: bool,
|
train: bool,
|
||||||
) -> ElectraTokenClassificationOutput {
|
) -> ElectraTokenClassificationOutput {
|
||||||
let base_model_output = self
|
let base_model_output = self
|
||||||
|
@ -13,6 +13,7 @@
|
|||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
use crate::common::dropout::Dropout;
|
use crate::common::dropout::Dropout;
|
||||||
|
use crate::common::embeddings::process_ids_embeddings_pair;
|
||||||
use crate::electra::electra_model::ElectraConfig;
|
use crate::electra::electra_model::ElectraConfig;
|
||||||
use crate::RustBertError;
|
use crate::RustBertError;
|
||||||
use std::borrow::Borrow;
|
use std::borrow::Borrow;
|
||||||
@ -84,50 +85,40 @@ impl ElectraEmbeddings {
|
|||||||
|
|
||||||
pub fn forward_t(
|
pub fn forward_t(
|
||||||
&self,
|
&self,
|
||||||
input_ids: Option<Tensor>,
|
input_ids: Option<&Tensor>,
|
||||||
token_type_ids: Option<Tensor>,
|
token_type_ids: Option<&Tensor>,
|
||||||
position_ids: Option<Tensor>,
|
position_ids: Option<&Tensor>,
|
||||||
input_embeds: Option<Tensor>,
|
input_embeds: Option<&Tensor>,
|
||||||
train: bool,
|
train: bool,
|
||||||
) -> Result<Tensor, RustBertError> {
|
) -> Result<Tensor, RustBertError> {
|
||||||
let (input_embeddings, input_shape) = match input_ids {
|
let (calc_input_embeddings, input_shape, _) =
|
||||||
Some(input_value) => match input_embeds {
|
process_ids_embeddings_pair(input_ids, input_embeds, &self.word_embeddings)?;
|
||||||
Some(_) => {
|
|
||||||
return Err(RustBertError::ValueError(
|
|
||||||
"Only one of input ids or input embeddings may be set".into(),
|
|
||||||
));
|
|
||||||
}
|
|
||||||
None => (
|
|
||||||
input_value.apply_t(&self.word_embeddings, train),
|
|
||||||
input_value.size(),
|
|
||||||
),
|
|
||||||
},
|
|
||||||
None => match input_embeds {
|
|
||||||
Some(embeds) => {
|
|
||||||
let size = vec![embeds.size()[0], embeds.size()[1]];
|
|
||||||
(embeds, size)
|
|
||||||
}
|
|
||||||
None => {
|
|
||||||
return Err(RustBertError::ValueError(
|
|
||||||
"At least one of input ids or input embeddings must be set".into(),
|
|
||||||
));
|
|
||||||
}
|
|
||||||
},
|
|
||||||
};
|
|
||||||
|
|
||||||
let seq_length = input_embeddings.as_ref().size()[1].to_owned();
|
let input_embeddings =
|
||||||
|
input_embeds.unwrap_or_else(|| calc_input_embeddings.as_ref().unwrap());
|
||||||
|
let seq_length = input_embeddings.size()[1].to_owned();
|
||||||
|
|
||||||
let position_ids = match position_ids {
|
let calc_position_ids = if position_ids.is_none() {
|
||||||
Some(value) => value,
|
Some(
|
||||||
None => Tensor::arange(seq_length, (Kind::Int64, input_embeddings.device()))
|
Tensor::arange(seq_length, (Kind::Int64, input_embeddings.device()))
|
||||||
.unsqueeze(0)
|
.unsqueeze(0)
|
||||||
.expand(&input_shape, true),
|
.expand(&input_shape, true),
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
};
|
};
|
||||||
|
let position_ids = position_ids.unwrap_or_else(|| calc_position_ids.as_ref().unwrap());
|
||||||
|
|
||||||
let token_type_ids = match token_type_ids {
|
let calc_token_type_ids = if token_type_ids.is_none() {
|
||||||
Some(value) => value,
|
Some(Tensor::zeros(
|
||||||
None => Tensor::zeros(&input_shape, (Kind::Int64, input_embeddings.device())),
|
&input_shape,
|
||||||
|
(Kind::Int64, input_embeddings.device()),
|
||||||
|
))
|
||||||
|
} else {
|
||||||
|
None
|
||||||
};
|
};
|
||||||
|
let token_type_ids =
|
||||||
|
token_type_ids.unwrap_or_else(|| calc_token_type_ids.as_ref().unwrap());
|
||||||
|
|
||||||
let position_embeddings = position_ids.apply(&self.position_embeddings);
|
let position_embeddings = position_ids.apply(&self.position_embeddings);
|
||||||
let token_type_embeddings = token_type_ids.apply(&self.token_type_embeddings);
|
let token_type_embeddings = token_type_ids.apply(&self.token_type_embeddings);
|
||||||
|
@ -545,12 +545,12 @@ impl TokenClassificationOption {
|
|||||||
Self::Longformer(ref model) => {
|
Self::Longformer(ref model) => {
|
||||||
model
|
model
|
||||||
.forward_t(
|
.forward_t(
|
||||||
input_ids.as_ref(),
|
input_ids,
|
||||||
mask.as_ref(),
|
mask,
|
||||||
None,
|
None,
|
||||||
token_type_ids.as_ref(),
|
token_type_ids,
|
||||||
position_ids.as_ref(),
|
position_ids,
|
||||||
input_embeds.as_ref(),
|
input_embeds,
|
||||||
train,
|
train,
|
||||||
)
|
)
|
||||||
.expect("Error in longformer forward_t")
|
.expect("Error in longformer forward_t")
|
||||||
|
@ -390,38 +390,34 @@ impl XLNetModel {
|
|||||||
perm_mask: Option<&Tensor>,
|
perm_mask: Option<&Tensor>,
|
||||||
target_mapping: Option<&Tensor>,
|
target_mapping: Option<&Tensor>,
|
||||||
token_type_ids: Option<&Tensor>,
|
token_type_ids: Option<&Tensor>,
|
||||||
input_embeds: Option<Tensor>,
|
input_embeds: Option<&Tensor>,
|
||||||
train: bool,
|
train: bool,
|
||||||
) -> Result<XLNetModelOutput, RustBertError> {
|
) -> Result<XLNetModelOutput, RustBertError> {
|
||||||
let (word_emb_k, input_shape) = match input_ids {
|
let (word_emb_k, input_shape) = match (input_ids, input_embeds) {
|
||||||
Some(input_value) => match input_embeds {
|
(Some(_), Some(_)) => {
|
||||||
Some(_) => {
|
return Err(RustBertError::ValueError(
|
||||||
return Err(RustBertError::ValueError(
|
"Only one of input ids or input embeddings may be set".into(),
|
||||||
"Only one of input ids or input embeddings may be set".into(),
|
));
|
||||||
));
|
}
|
||||||
}
|
(Some(input_value), None) => {
|
||||||
None => {
|
let size = input_value.size();
|
||||||
let size = input_value.size();
|
(
|
||||||
(
|
input_value
|
||||||
input_value
|
.transpose(0, 1)
|
||||||
.transpose(0, 1)
|
.contiguous()
|
||||||
.contiguous()
|
.apply_t(&self.word_embeddings, train),
|
||||||
.apply_t(&self.word_embeddings, train),
|
vec![size[1], size[0]],
|
||||||
vec![size[1], size[0]],
|
)
|
||||||
)
|
}
|
||||||
}
|
(None, Some(embeds)) => {
|
||||||
},
|
let size = vec![embeds.size()[1], embeds.size()[0]];
|
||||||
None => match input_embeds {
|
(embeds.transpose(0, 1).contiguous(), size)
|
||||||
Some(embeds) => {
|
}
|
||||||
let size = vec![embeds.size()[1], embeds.size()[0]];
|
(None, None) => {
|
||||||
(embeds.transpose(0, 1).contiguous(), size)
|
return Err(RustBertError::ValueError(
|
||||||
}
|
"At least one of input ids or input embeddings must be set".into(),
|
||||||
None => {
|
));
|
||||||
return Err(RustBertError::ValueError(
|
}
|
||||||
"At least one of input ids or input embeddings must be set".into(),
|
|
||||||
));
|
|
||||||
}
|
|
||||||
},
|
|
||||||
};
|
};
|
||||||
|
|
||||||
let token_type_ids =
|
let token_type_ids =
|
||||||
@ -715,7 +711,7 @@ impl XLNetLMHeadModel {
|
|||||||
perm_mask: Option<&Tensor>,
|
perm_mask: Option<&Tensor>,
|
||||||
target_mapping: Option<&Tensor>,
|
target_mapping: Option<&Tensor>,
|
||||||
token_type_ids: Option<&Tensor>,
|
token_type_ids: Option<&Tensor>,
|
||||||
input_embeds: Option<Tensor>,
|
input_embeds: Option<&Tensor>,
|
||||||
train: bool,
|
train: bool,
|
||||||
) -> Result<LMModelOutput, RustBertError> {
|
) -> Result<LMModelOutput, RustBertError> {
|
||||||
let base_model_output = self.base_model.forward_t(
|
let base_model_output = self.base_model.forward_t(
|
||||||
@ -966,7 +962,7 @@ impl XLNetForSequenceClassification {
|
|||||||
perm_mask: Option<&Tensor>,
|
perm_mask: Option<&Tensor>,
|
||||||
target_mapping: Option<&Tensor>,
|
target_mapping: Option<&Tensor>,
|
||||||
token_type_ids: Option<&Tensor>,
|
token_type_ids: Option<&Tensor>,
|
||||||
input_embeds: Option<Tensor>,
|
input_embeds: Option<&Tensor>,
|
||||||
train: bool,
|
train: bool,
|
||||||
) -> XLNetSequenceClassificationOutput {
|
) -> XLNetSequenceClassificationOutput {
|
||||||
let base_model_output = self
|
let base_model_output = self
|
||||||
@ -1124,7 +1120,7 @@ impl XLNetForTokenClassification {
|
|||||||
perm_mask: Option<&Tensor>,
|
perm_mask: Option<&Tensor>,
|
||||||
target_mapping: Option<&Tensor>,
|
target_mapping: Option<&Tensor>,
|
||||||
token_type_ids: Option<&Tensor>,
|
token_type_ids: Option<&Tensor>,
|
||||||
input_embeds: Option<Tensor>,
|
input_embeds: Option<&Tensor>,
|
||||||
train: bool,
|
train: bool,
|
||||||
) -> XLNetTokenClassificationOutput {
|
) -> XLNetTokenClassificationOutput {
|
||||||
let base_model_output = self
|
let base_model_output = self
|
||||||
@ -1273,7 +1269,7 @@ impl XLNetForMultipleChoice {
|
|||||||
perm_mask: Option<&Tensor>,
|
perm_mask: Option<&Tensor>,
|
||||||
target_mapping: Option<&Tensor>,
|
target_mapping: Option<&Tensor>,
|
||||||
token_type_ids: Option<&Tensor>,
|
token_type_ids: Option<&Tensor>,
|
||||||
input_embeds: Option<Tensor>,
|
input_embeds: Option<&Tensor>,
|
||||||
train: bool,
|
train: bool,
|
||||||
) -> XLNetSequenceClassificationOutput {
|
) -> XLNetSequenceClassificationOutput {
|
||||||
let (input_ids, num_choices) = match input_ids {
|
let (input_ids, num_choices) = match input_ids {
|
||||||
@ -1305,7 +1301,7 @@ impl XLNetForMultipleChoice {
|
|||||||
perm_mask,
|
perm_mask,
|
||||||
target_mapping,
|
target_mapping,
|
||||||
token_type_ids.as_ref(),
|
token_type_ids.as_ref(),
|
||||||
input_embeds,
|
input_embeds.as_ref(),
|
||||||
train,
|
train,
|
||||||
)
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
@ -1444,7 +1440,7 @@ impl XLNetForQuestionAnswering {
|
|||||||
perm_mask: Option<&Tensor>,
|
perm_mask: Option<&Tensor>,
|
||||||
target_mapping: Option<&Tensor>,
|
target_mapping: Option<&Tensor>,
|
||||||
token_type_ids: Option<&Tensor>,
|
token_type_ids: Option<&Tensor>,
|
||||||
input_embeds: Option<Tensor>,
|
input_embeds: Option<&Tensor>,
|
||||||
train: bool,
|
train: bool,
|
||||||
) -> XLNetQuestionAnsweringOutput {
|
) -> XLNetQuestionAnsweringOutput {
|
||||||
let base_model_output = self
|
let base_model_output = self
|
||||||
|
@ -62,7 +62,7 @@ fn albert_masked_lm() -> anyhow::Result<()> {
|
|||||||
|
|
||||||
// Forward pass
|
// Forward pass
|
||||||
let model_output =
|
let model_output =
|
||||||
no_grad(|| albert_model.forward_t(Some(input_tensor), None, None, None, None, false));
|
no_grad(|| albert_model.forward_t(Some(&input_tensor), None, None, None, None, false));
|
||||||
|
|
||||||
// Print masked tokens
|
// Print masked tokens
|
||||||
let index_1 = model_output
|
let index_1 = model_output
|
||||||
@ -135,7 +135,7 @@ fn albert_for_sequence_classification() -> anyhow::Result<()> {
|
|||||||
|
|
||||||
// Forward pass
|
// Forward pass
|
||||||
let model_output =
|
let model_output =
|
||||||
no_grad(|| albert_model.forward_t(Some(input_tensor), None, None, None, None, false));
|
no_grad(|| albert_model.forward_t(Some(&input_tensor), None, None, None, None, false));
|
||||||
|
|
||||||
assert_eq!(model_output.logits.size(), &[2, 3]);
|
assert_eq!(model_output.logits.size(), &[2, 3]);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
@ -199,7 +199,7 @@ fn albert_for_multiple_choice() -> anyhow::Result<()> {
|
|||||||
// Forward pass
|
// Forward pass
|
||||||
let model_output = no_grad(|| {
|
let model_output = no_grad(|| {
|
||||||
albert_model
|
albert_model
|
||||||
.forward_t(Some(input_tensor), None, None, None, None, false)
|
.forward_t(Some(&input_tensor), None, None, None, None, false)
|
||||||
.unwrap()
|
.unwrap()
|
||||||
});
|
});
|
||||||
|
|
||||||
@ -268,7 +268,7 @@ fn albert_for_token_classification() -> anyhow::Result<()> {
|
|||||||
|
|
||||||
// Forward pass
|
// Forward pass
|
||||||
let model_output =
|
let model_output =
|
||||||
no_grad(|| bert_model.forward_t(Some(input_tensor), None, None, None, None, false));
|
no_grad(|| bert_model.forward_t(Some(&input_tensor), None, None, None, None, false));
|
||||||
|
|
||||||
assert_eq!(model_output.logits.size(), &[2, 12, 4]);
|
assert_eq!(model_output.logits.size(), &[2, 12, 4]);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
@ -329,7 +329,7 @@ fn albert_for_question_answering() -> anyhow::Result<()> {
|
|||||||
|
|
||||||
// Forward pass
|
// Forward pass
|
||||||
let model_output =
|
let model_output =
|
||||||
no_grad(|| albert_model.forward_t(Some(input_tensor), None, None, None, None, false));
|
no_grad(|| albert_model.forward_t(Some(&input_tensor), None, None, None, None, false));
|
||||||
|
|
||||||
assert_eq!(model_output.start_logits.size(), &[2, 12]);
|
assert_eq!(model_output.start_logits.size(), &[2, 12]);
|
||||||
assert_eq!(model_output.end_logits.size(), &[2, 12]);
|
assert_eq!(model_output.end_logits.size(), &[2, 12]);
|
||||||
|
@ -72,13 +72,13 @@ fn bert_masked_lm() -> anyhow::Result<()> {
|
|||||||
// Forward pass
|
// Forward pass
|
||||||
let model_output = no_grad(|| {
|
let model_output = no_grad(|| {
|
||||||
bert_model.forward_t(
|
bert_model.forward_t(
|
||||||
Some(input_tensor),
|
Some(&input_tensor),
|
||||||
|
None,
|
||||||
|
None,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
&None,
|
|
||||||
&None,
|
|
||||||
false,
|
false,
|
||||||
)
|
)
|
||||||
});
|
});
|
||||||
@ -152,7 +152,7 @@ fn bert_for_sequence_classification() -> anyhow::Result<()> {
|
|||||||
|
|
||||||
// Forward pass
|
// Forward pass
|
||||||
let model_output =
|
let model_output =
|
||||||
no_grad(|| bert_model.forward_t(Some(input_tensor), None, None, None, None, false));
|
no_grad(|| bert_model.forward_t(Some(&input_tensor), None, None, None, None, false));
|
||||||
|
|
||||||
assert_eq!(model_output.logits.size(), &[2, 3]);
|
assert_eq!(model_output.logits.size(), &[2, 3]);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
@ -212,7 +212,7 @@ fn bert_for_multiple_choice() -> anyhow::Result<()> {
|
|||||||
.unsqueeze(0);
|
.unsqueeze(0);
|
||||||
|
|
||||||
// Forward pass
|
// Forward pass
|
||||||
let model_output = no_grad(|| bert_model.forward_t(input_tensor, None, None, None, false));
|
let model_output = no_grad(|| bert_model.forward_t(&input_tensor, None, None, None, false));
|
||||||
|
|
||||||
assert_eq!(model_output.logits.size(), &[1, 2]);
|
assert_eq!(model_output.logits.size(), &[1, 2]);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
@ -277,7 +277,7 @@ fn bert_for_token_classification() -> anyhow::Result<()> {
|
|||||||
|
|
||||||
// Forward pass
|
// Forward pass
|
||||||
let model_output =
|
let model_output =
|
||||||
no_grad(|| bert_model.forward_t(Some(input_tensor), None, None, None, None, false));
|
no_grad(|| bert_model.forward_t(Some(&input_tensor), None, None, None, None, false));
|
||||||
|
|
||||||
assert_eq!(model_output.logits.size(), &[2, 11, 4]);
|
assert_eq!(model_output.logits.size(), &[2, 11, 4]);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
@ -336,7 +336,7 @@ fn bert_for_question_answering() -> anyhow::Result<()> {
|
|||||||
|
|
||||||
// Forward pass
|
// Forward pass
|
||||||
let model_output =
|
let model_output =
|
||||||
no_grad(|| bert_model.forward_t(Some(input_tensor), None, None, None, None, false));
|
no_grad(|| bert_model.forward_t(Some(&input_tensor), None, None, None, None, false));
|
||||||
|
|
||||||
assert_eq!(model_output.start_logits.size(), &[2, 11]);
|
assert_eq!(model_output.start_logits.size(), &[2, 11]);
|
||||||
assert_eq!(model_output.end_logits.size(), &[2, 11]);
|
assert_eq!(model_output.end_logits.size(), &[2, 11]);
|
||||||
|
@ -96,7 +96,7 @@ fn distilbert_masked_lm() -> anyhow::Result<()> {
|
|||||||
// Forward pass
|
// Forward pass
|
||||||
let model_output = no_grad(|| {
|
let model_output = no_grad(|| {
|
||||||
distil_bert_model
|
distil_bert_model
|
||||||
.forward_t(Some(input_tensor), None, None, false)
|
.forward_t(Some(&input_tensor), None, None, false)
|
||||||
.unwrap()
|
.unwrap()
|
||||||
});
|
});
|
||||||
|
|
||||||
@ -167,7 +167,7 @@ fn distilbert_for_question_answering() -> anyhow::Result<()> {
|
|||||||
// Forward pass
|
// Forward pass
|
||||||
let model_output = no_grad(|| {
|
let model_output = no_grad(|| {
|
||||||
distil_bert_model
|
distil_bert_model
|
||||||
.forward_t(Some(input_tensor), None, None, false)
|
.forward_t(Some(&input_tensor), None, None, false)
|
||||||
.unwrap()
|
.unwrap()
|
||||||
});
|
});
|
||||||
|
|
||||||
@ -238,7 +238,7 @@ fn distilbert_for_token_classification() -> anyhow::Result<()> {
|
|||||||
// Forward pass
|
// Forward pass
|
||||||
let model_output = no_grad(|| {
|
let model_output = no_grad(|| {
|
||||||
distil_bert_model
|
distil_bert_model
|
||||||
.forward_t(Some(input_tensor), None, None, false)
|
.forward_t(Some(&input_tensor), None, None, false)
|
||||||
.unwrap()
|
.unwrap()
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -59,7 +59,7 @@ fn electra_masked_lm() -> anyhow::Result<()> {
|
|||||||
|
|
||||||
// Forward pass
|
// Forward pass
|
||||||
let model_output =
|
let model_output =
|
||||||
no_grad(|| electra_model.forward_t(Some(input_tensor), None, None, None, None, false));
|
no_grad(|| electra_model.forward_t(Some(&input_tensor), None, None, None, None, false));
|
||||||
|
|
||||||
// Decode output
|
// Decode output
|
||||||
let index_1 = model_output
|
let index_1 = model_output
|
||||||
@ -138,7 +138,7 @@ fn electra_discriminator() -> anyhow::Result<()> {
|
|||||||
|
|
||||||
// Forward pass
|
// Forward pass
|
||||||
let model_output =
|
let model_output =
|
||||||
no_grad(|| electra_model.forward_t(Some(input_tensor), None, None, None, None, false));
|
no_grad(|| electra_model.forward_t(Some(&input_tensor), None, None, None, None, false));
|
||||||
|
|
||||||
// Validate model predictions
|
// Validate model predictions
|
||||||
let expected_probabilities = vec![
|
let expected_probabilities = vec![
|
||||||
|
@ -82,13 +82,13 @@ fn roberta_masked_lm() -> anyhow::Result<()> {
|
|||||||
// Forward pass
|
// Forward pass
|
||||||
let model_output = no_grad(|| {
|
let model_output = no_grad(|| {
|
||||||
roberta_model.forward_t(
|
roberta_model.forward_t(
|
||||||
Some(input_tensor),
|
Some(&input_tensor),
|
||||||
|
None,
|
||||||
|
None,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
&None,
|
|
||||||
&None,
|
|
||||||
false,
|
false,
|
||||||
)
|
)
|
||||||
});
|
});
|
||||||
@ -172,7 +172,7 @@ fn roberta_for_sequence_classification() -> anyhow::Result<()> {
|
|||||||
|
|
||||||
// Forward pass
|
// Forward pass
|
||||||
let model_output =
|
let model_output =
|
||||||
no_grad(|| roberta_model.forward_t(Some(input_tensor), None, None, None, None, false));
|
no_grad(|| roberta_model.forward_t(Some(&input_tensor), None, None, None, None, false));
|
||||||
|
|
||||||
assert_eq!(model_output.logits.size(), &[2, 3]);
|
assert_eq!(model_output.logits.size(), &[2, 3]);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
@ -242,7 +242,7 @@ fn roberta_for_multiple_choice() -> anyhow::Result<()> {
|
|||||||
.unsqueeze(0);
|
.unsqueeze(0);
|
||||||
|
|
||||||
// Forward pass
|
// Forward pass
|
||||||
let model_output = no_grad(|| roberta_model.forward_t(input_tensor, None, None, None, false));
|
let model_output = no_grad(|| roberta_model.forward_t(&input_tensor, None, None, None, false));
|
||||||
|
|
||||||
assert_eq!(model_output.logits.size(), &[1, 2]);
|
assert_eq!(model_output.logits.size(), &[1, 2]);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
@ -317,7 +317,7 @@ fn roberta_for_token_classification() -> anyhow::Result<()> {
|
|||||||
|
|
||||||
// Forward pass
|
// Forward pass
|
||||||
let model_output =
|
let model_output =
|
||||||
no_grad(|| roberta_model.forward_t(Some(input_tensor), None, None, None, None, false));
|
no_grad(|| roberta_model.forward_t(Some(&input_tensor), None, None, None, None, false));
|
||||||
|
|
||||||
assert_eq!(model_output.logits.size(), &[2, 9, 4]);
|
assert_eq!(model_output.logits.size(), &[2, 9, 4]);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
|
Loading…
Reference in New Issue
Block a user