Updated BERT and DistilBERT (clippy)

Guillaume B 2020-09-12 17:12:49 +02:00
parent 7e9d6d7e39
commit 5aa1e635ba
16 changed files with 355 additions and 190 deletions

View File

@ -65,7 +65,7 @@ fn main() -> anyhow::Result<()> {
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let (output, _, _) = no_grad(|| {
let model_output = no_grad(|| {
bert_model.forward_t(
Some(input_tensor),
None,
@ -79,8 +79,16 @@ fn main() -> anyhow::Result<()> {
});
// Print masked tokens
let index_1 = output.get(0).get(4).argmax(0, false);
let index_2 = output.get(1).get(7).argmax(0, false);
let index_1 = model_output
.prediction_scores
.get(0)
.get(4)
.argmax(0, false);
let index_2 = model_output
.prediction_scores
.get(1)
.get(7)
.argmax(0, false);
let word_1 = tokenizer.vocab().id_to_token(&index_1.int64_value(&[]));
let word_2 = tokenizer.vocab().id_to_token(&index_2.int64_value(&[]));
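For context, a shape-annotated sketch of the indexing used above (not part of the diff; variable names here are illustrative): `prediction_scores` has shape [batch_size, sequence_length, vocab_size], so the chained `get` calls pick one sentence and one token position before taking the argmax over the vocabulary.

```rust
// Sketch only: how the chained calls above narrow prediction_scores step by step.
let scores_sentence_0 = model_output.prediction_scores.get(0); // [sequence_length, vocab_size]
let scores_position_4 = scores_sentence_0.get(4);              // [vocab_size]
let best_token_id = scores_position_4.argmax(0, false);        // 0-dim tensor holding the token id
```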

View File

@ -77,15 +77,23 @@ fn main() -> anyhow::Result<()> {
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let (output, _, _) = no_grad(|| {
let model_output = no_grad(|| {
distil_bert_model
.forward_t(Some(input_tensor), None, None, false)
.unwrap()
});
// Print masked tokens
let index_1 = output.get(0).get(4).argmax(0, false);
let index_2 = output.get(1).get(6).argmax(0, false);
let index_1 = model_output
.prediction_scores
.get(0)
.get(4)
.argmax(0, false);
let index_2 = model_output
.prediction_scores
.get(1)
.get(6)
.argmax(0, false);
let word_1 = tokenizer.vocab().id_to_token(&index_1.int64_value(&[]));
let word_2 = tokenizer.vocab().id_to_token(&index_2.int64_value(&[]));

View File

@ -224,7 +224,7 @@ impl<T: BertEmbedding> BertModel<T> {
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
/// .expand(&[batch_size, sequence_length], true);
///
/// let (output, pooled_output, all_hidden_states, all_attentions) = no_grad(|| {
/// let model_output = no_grad(|| {
/// bert_model
/// .forward_t(
/// Some(input_tensor),
@ -249,7 +249,7 @@ impl<T: BertEmbedding> BertModel<T> {
encoder_hidden_states: &Option<Tensor>,
encoder_mask: &Option<Tensor>,
train: bool,
) -> Result<(Tensor, Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
) -> Result<BertModelOutput, &'static str> {
let (input_shape, device) = match &input_ids {
Some(input_value) => match &input_embeds {
Some(_) => {
@ -275,7 +275,7 @@ impl<T: BertEmbedding> BertModel<T> {
2 => {
if self.is_decoder {
let seq_ids = Tensor::arange(input_shape[1], (Float, device));
let causal_mask = seq_ids.unsqueeze(0).unsqueeze(0).repeat(&vec![
let causal_mask = seq_ids.unsqueeze(0).unsqueeze(0).repeat(&[
input_shape[0],
input_shape[1],
1,
@ -342,12 +342,12 @@ impl<T: BertEmbedding> BertModel<T> {
let pooled_output = self.pooler.forward(&hidden_state);
Ok((
Ok(BertModelOutput {
hidden_state,
pooled_output,
all_hidden_states,
all_attentions,
))
})
}
}
@ -510,7 +510,7 @@ impl BertForMaskedLM {
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
/// .expand(&[batch_size, sequence_length], true);
///
/// let (output, all_hidden_states, all_attentions) = no_grad(|| {
/// let model_output = no_grad(|| {
/// bert_model.forward_t(
/// Some(input_tensor),
/// Some(mask),
@ -533,8 +533,8 @@ impl BertForMaskedLM {
encoder_hidden_states: &Option<Tensor>,
encoder_mask: &Option<Tensor>,
train: bool,
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
let (hidden_state, _, all_hidden_states, all_attentions) = self
) -> BertMaskedLMOutput {
let model_output = self
.bert
.forward_t(
input_ids,
@ -548,8 +548,12 @@ impl BertForMaskedLM {
)
.unwrap();
let prediction_scores = self.cls.forward(&hidden_state);
(prediction_scores, all_hidden_states, all_attentions)
let prediction_scores = self.cls.forward(&model_output.hidden_state);
BertMaskedLMOutput {
prediction_scores,
all_hidden_states: model_output.all_hidden_states,
all_attentions: model_output.all_attentions,
}
}
}
@ -650,7 +654,7 @@ impl BertForSequenceClassification {
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
/// .expand(&[batch_size, sequence_length], true);
///
/// let (labels, all_hidden_states, all_attentions) = no_grad(|| {
/// let model_output = no_grad(|| {
/// bert_model.forward_t(
/// Some(input_tensor),
/// Some(mask),
@ -669,8 +673,8 @@ impl BertForSequenceClassification {
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool,
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
let (_, pooled_output, all_hidden_states, all_attentions) = self
) -> BertSequenceClassificationOutput {
let model_output = self
.bert
.forward_t(
input_ids,
@ -684,10 +688,15 @@ impl BertForSequenceClassification {
)
.unwrap();
let output = pooled_output
let logits = model_output
.pooled_output
.apply_t(&self.dropout, train)
.apply(&self.classifier);
(output, all_hidden_states, all_attentions)
BertSequenceClassificationOutput {
logits,
all_hidden_states: model_output.all_hidden_states,
all_attentions: model_output.all_attentions,
}
}
}
@ -779,7 +788,7 @@ impl BertForMultipleChoice {
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
/// .expand(&[num_choices, sequence_length], true);
///
/// let (choices, all_hidden_states, all_attentions) = no_grad(|| {
/// let model_output = no_grad(|| {
/// bert_model.forward_t(
/// input_tensor,
/// Some(mask),
@ -796,7 +805,7 @@ impl BertForMultipleChoice {
token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>,
train: bool,
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
) -> BertSequenceClassificationOutput {
let num_choices = input_ids.size()[1];
let input_ids = input_ids.view((-1, *input_ids.size().last().unwrap()));
@ -813,7 +822,7 @@ impl BertForMultipleChoice {
None => None,
};
let (_, pooled_output, all_hidden_states, all_attentions) = self
let model_output = self
.bert
.forward_t(
Some(input_ids),
@ -827,11 +836,16 @@ impl BertForMultipleChoice {
)
.unwrap();
let output = pooled_output
let logits = model_output
.pooled_output
.apply_t(&self.dropout, train)
.apply(&self.classifier)
.view((-1, num_choices));
(output, all_hidden_states, all_attentions)
BertSequenceClassificationOutput {
logits,
all_hidden_states: model_output.all_hidden_states,
all_attentions: model_output.all_attentions,
}
}
}
@ -933,7 +947,7 @@ impl BertForTokenClassification {
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
/// .expand(&[batch_size, sequence_length], true);
///
/// let (token_labels, all_hidden_states, all_attentions) = no_grad(|| {
/// let model_output = no_grad(|| {
/// bert_model.forward_t(
/// Some(input_tensor),
/// Some(mask),
@ -952,8 +966,8 @@ impl BertForTokenClassification {
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool,
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
let (hidden_state, _, all_hidden_states, all_attentions) = self
) -> BertTokenClassificationOutput {
let model_output = self
.bert
.forward_t(
input_ids,
@ -967,10 +981,15 @@ impl BertForTokenClassification {
)
.unwrap();
let sequence_output = hidden_state
let logits = model_output
.hidden_state
.apply_t(&self.dropout, train)
.apply(&self.classifier);
(sequence_output, all_hidden_states, all_attentions)
BertTokenClassificationOutput {
logits,
all_hidden_states: model_output.all_hidden_states,
all_attentions: model_output.all_attentions,
}
}
}
@ -1064,7 +1083,7 @@ impl BertForQuestionAnswering {
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
/// .expand(&[batch_size, sequence_length], true);
///
/// let (start_scores, end_scores, all_hidden_states, all_attentions) = no_grad(|| {
/// let model_output = no_grad(|| {
/// bert_model.forward_t(
/// Some(input_tensor),
/// Some(mask),
@ -1083,8 +1102,8 @@ impl BertForQuestionAnswering {
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool,
) -> (Tensor, Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
let (hidden_state, _, all_hidden_states, all_attentions) = self
) -> BertQuestionAnsweringOutput {
let model_output = self
.bert
.forward_t(
input_ids,
@ -1098,12 +1117,49 @@ impl BertForQuestionAnswering {
)
.unwrap();
let sequence_output = hidden_state.apply(&self.qa_outputs);
let sequence_output = model_output.hidden_state.apply(&self.qa_outputs);
let logits = sequence_output.split(1, -1);
let (start_logits, end_logits) = (&logits[0], &logits[1]);
let start_logits = start_logits.squeeze1(-1);
let end_logits = end_logits.squeeze1(-1);
(start_logits, end_logits, all_hidden_states, all_attentions)
BertQuestionAnsweringOutput {
start_logits,
end_logits,
all_hidden_states: model_output.all_hidden_states,
all_attentions: model_output.all_attentions,
}
}
}
pub struct BertModelOutput {
pub hidden_state: Tensor,
pub pooled_output: Tensor,
pub all_hidden_states: Option<Vec<Tensor>>,
pub all_attentions: Option<Vec<Tensor>>,
}
pub struct BertMaskedLMOutput {
pub prediction_scores: Tensor,
pub all_hidden_states: Option<Vec<Tensor>>,
pub all_attentions: Option<Vec<Tensor>>,
}
pub struct BertSequenceClassificationOutput {
pub logits: Tensor,
pub all_hidden_states: Option<Vec<Tensor>>,
pub all_attentions: Option<Vec<Tensor>>,
}
pub struct BertTokenClassificationOutput {
pub logits: Tensor,
pub all_hidden_states: Option<Vec<Tensor>>,
pub all_attentions: Option<Vec<Tensor>>,
}
pub struct BertQuestionAnsweringOutput {
pub start_logits: Tensor,
pub end_logits: Tensor,
pub all_hidden_states: Option<Vec<Tensor>>,
pub all_attentions: Option<Vec<Tensor>>,
}
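As a consumer-side illustration of the new struct-based return values, a minimal sketch (not part of this commit; `pooled_and_depth` is a hypothetical helper assumed to live where `BertModelOutput` is in scope):

```rust
// Hypothetical helper: read the named fields instead of destructuring a positional tuple.
fn pooled_and_depth(model_output: &BertModelOutput) -> (Vec<i64>, usize) {
    // Shape of the pooled representation, typically [batch_size, hidden_size].
    let pooled_shape = model_output.pooled_output.size();
    // Number of intermediate hidden states, if the model was configured to return them.
    let depth = model_output
        .all_hidden_states
        .as_ref()
        .map(|states| states.len())
        .unwrap_or(0);
    (pooled_shape, depth)
}
```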

View File

@ -34,7 +34,7 @@ impl BertLayer {
let attention = BertAttention::new(p / "attention", &config);
let (is_decoder, cross_attention) = match config.is_decoder {
Some(value) => {
if value == true {
if value {
(
value,
Some(BertAttention::new(p / "cross_attention", &config)),
@ -150,28 +150,23 @@ impl BertEncoder {
let mut hidden_state = hidden_states.copy();
let mut attention_weights: Option<Tensor>;
let mut layers = self.layers.iter();
loop {
match layers.next() {
Some(layer) => {
if let Some(hidden_states) = all_hidden_states.borrow_mut() {
hidden_states.push(hidden_state.as_ref().copy());
};
let temp = layer.forward_t(
&hidden_state,
&mask,
encoder_hidden_states,
encoder_mask,
train,
);
hidden_state = temp.0;
attention_weights = temp.1;
if let Some(attentions) = all_attentions.borrow_mut() {
attentions.push(attention_weights.as_ref().unwrap().copy());
};
}
None => break,
for layer in &self.layers {
if let Some(hidden_states) = all_hidden_states.borrow_mut() {
hidden_states.push(hidden_state.as_ref().copy());
};
let temp = layer.forward_t(
&hidden_state,
&mask,
encoder_hidden_states,
encoder_mask,
train,
);
hidden_state = temp.0;
attention_weights = temp.1;
if let Some(attentions) = all_attentions.borrow_mut() {
attentions.push(attention_weights.as_ref().unwrap().copy());
};
}
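The hunk above is one instance of the clippy-driven loop cleanup referenced in the commit title; a generic sketch of the rewrite (illustrative only, and assumed to come from clippy's `while_let_loop`/`while_let_on_iterator` lints):

```rust
// Before: let mut iter = layers.iter();
//         loop { match iter.next() { Some(layer) => { /* ... */ }, None => break } }
// After: iterate the collection directly with a `for` loop.
fn visit_all(layers: &[String]) {
    for layer in layers {
        println!("visiting {}", layer);
    }
}
```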

View File

@ -152,7 +152,7 @@ lazy_static! {
}
fn _get_cache_directory() -> PathBuf {
let home = match env::var("RUSTBERT_CACHE") {
match env::var("RUSTBERT_CACHE") {
Ok(value) => PathBuf::from(value),
Err(_) => {
let mut home = dirs::home_dir().unwrap();
@ -160,8 +160,7 @@ fn _get_cache_directory() -> PathBuf {
home.push(".rustbert");
home
}
};
home
}
}
#[deprecated(

View File

@ -65,7 +65,7 @@ impl MultiHeadSelfAttention {
fn flatten(&self, x: Tensor, bs: i64, dim_per_head: i64) -> Tensor {
x.transpose(1, 2)
.contiguous()
.view((bs, -1, &self.n_heads * dim_per_head))
.view((bs, -1, self.n_heads * dim_per_head))
}
pub fn forward_t(

View File

@ -15,7 +15,7 @@ extern crate tch;
use self::tch::{nn, Tensor};
use crate::common::dropout::Dropout;
use crate::distilbert::embeddings::DistilBertEmbedding;
use crate::distilbert::transformer::Transformer;
use crate::distilbert::transformer::{DistilBertTransformerOutput, Transformer};
use crate::Config;
use serde::{Deserialize, Serialize};
use std::{borrow::Borrow, collections::HashMap};
@ -202,7 +202,7 @@ impl DistilBertModel {
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
///
/// let (output, all_hidden_states, all_attentions) = no_grad(|| {
/// let model_output = no_grad(|| {
/// distilbert_model
/// .forward_t(Some(input_tensor), Some(mask), None, false)
/// .unwrap()
@ -214,7 +214,7 @@ impl DistilBertModel {
mask: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool,
) -> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
) -> Result<DistilBertTransformerOutput, &'static str> {
let input_embeddings = match input {
Some(input_value) => match input_embeds {
Some(_) => {
@ -335,7 +335,7 @@ impl DistilBertModelClassifier {
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
///
/// let (output, all_hidden_states, all_attentions) = no_grad(|| {
/// let model_output = no_grad(|| {
/// distilbert_model
/// .forward_t(Some(input_tensor),
/// Some(mask),
@ -349,24 +349,28 @@ impl DistilBertModelClassifier {
mask: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool,
) -> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
let (output, all_hidden_states, all_attentions) =
match self
.distil_bert_model
.forward_t(input, mask, input_embeds, train)
{
Ok(value) => value,
Err(err) => return Err(err),
};
) -> Result<DistilBertSequenceClassificationOutput, &'static str> {
let model_output = match self
.distil_bert_model
.forward_t(input, mask, input_embeds, train)
{
Ok(value) => value,
Err(err) => return Err(err),
};
let output = output
let logits = model_output
.hidden_state
.select(1, 0)
.apply(&self.pre_classifier)
.relu()
.apply_t(&self.dropout, train)
.apply(&self.classifier);
Ok((output, all_hidden_states, all_attentions))
Ok(DistilBertSequenceClassificationOutput {
logits,
all_hidden_states: model_output.all_hidden_states,
all_attentions: model_output.all_attentions,
})
}
}
@ -473,7 +477,7 @@ impl DistilBertModelMaskedLM {
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
///
/// let (output, all_hidden_states, all_attentions) = no_grad(|| {
/// let model_output = no_grad(|| {
/// distilbert_model
/// .forward_t(Some(input_tensor), Some(mask), None, false)
/// .unwrap()
@ -485,23 +489,27 @@ impl DistilBertModelMaskedLM {
mask: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool,
) -> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
let (output, all_hidden_states, all_attentions) =
match self
.distil_bert_model
.forward_t(input, mask, input_embeds, train)
{
Ok(value) => value,
Err(err) => return Err(err),
};
) -> Result<DistilBertMaskedLMOutput, &'static str> {
let model_output = match self
.distil_bert_model
.forward_t(input, mask, input_embeds, train)
{
Ok(value) => value,
Err(err) => return Err(err),
};
let output = output
let prediction_scores = model_output
.hidden_state
.apply(&self.vocab_transform)
.gelu()
.apply(&self.vocab_layer_norm)
.apply(&self.vocab_projector);
Ok((output, all_hidden_states, all_attentions))
Ok(DistilBertMaskedLMOutput {
prediction_scores,
all_hidden_states: model_output.all_hidden_states,
all_attentions: model_output.all_attentions,
})
}
}
@ -591,7 +599,7 @@ impl DistilBertForQuestionAnswering {
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
///
/// let (start_scores, end_score, all_hidden_states, all_attentions) = no_grad(|| {
/// let model_output = no_grad(|| {
/// distilbert_model
/// .forward_t(Some(input_tensor), Some(mask), None, false)
/// .unwrap()
@ -603,24 +611,31 @@ impl DistilBertForQuestionAnswering {
mask: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool,
) -> Result<(Tensor, Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
let (output, all_hidden_states, all_attentions) =
match self
.distil_bert_model
.forward_t(input, mask, input_embeds, train)
{
Ok(value) => value,
Err(err) => return Err(err),
};
) -> Result<DistilBertQuestionAnsweringOutput, &'static str> {
let model_output = match self
.distil_bert_model
.forward_t(input, mask, input_embeds, train)
{
Ok(value) => value,
Err(err) => return Err(err),
};
let output = output.apply_t(&self.dropout, train).apply(&self.qa_outputs);
let output = model_output
.hidden_state
.apply_t(&self.dropout, train)
.apply(&self.qa_outputs);
let logits = output.split(1, -1);
let (start_logits, end_logits) = (&logits[0], &logits[1]);
let start_logits = start_logits.squeeze1(-1);
let end_logits = end_logits.squeeze1(-1);
Ok((start_logits, end_logits, all_hidden_states, all_attentions))
Ok(DistilBertQuestionAnsweringOutput {
start_logits,
end_logits,
all_hidden_states: model_output.all_hidden_states,
all_attentions: model_output.all_attentions,
})
}
}
@ -715,7 +730,7 @@ impl DistilBertForTokenClassification {
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
///
/// let (output, all_hidden_states, all_attentions) = no_grad(|| {
/// let model_output = no_grad(|| {
/// distilbert_model
/// .forward_t(Some(input_tensor), Some(mask), None, false)
/// .unwrap()
@ -727,18 +742,49 @@ impl DistilBertForTokenClassification {
mask: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool,
) -> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
let (output, all_hidden_states, all_attentions) =
match self
.distil_bert_model
.forward_t(input, mask, input_embeds, train)
{
Ok(value) => value,
Err(err) => return Err(err),
};
) -> Result<DistilBertTokenClassificationOutput, &'static str> {
let model_output = match self
.distil_bert_model
.forward_t(input, mask, input_embeds, train)
{
Ok(value) => value,
Err(err) => return Err(err),
};
let output = output.apply_t(&self.dropout, train).apply(&self.classifier);
let logits = model_output
.hidden_state
.apply_t(&self.dropout, train)
.apply(&self.classifier);
Ok((output, all_hidden_states, all_attentions))
Ok(DistilBertTokenClassificationOutput {
logits,
all_hidden_states: model_output.all_hidden_states,
all_attentions: model_output.all_attentions,
})
}
}
pub struct DistilBertMaskedLMOutput {
pub prediction_scores: Tensor,
pub all_hidden_states: Option<Vec<Tensor>>,
pub all_attentions: Option<Vec<Tensor>>,
}
pub struct DistilBertSequenceClassificationOutput {
pub logits: Tensor,
pub all_hidden_states: Option<Vec<Tensor>>,
pub all_attentions: Option<Vec<Tensor>>,
}
pub struct DistilBertTokenClassificationOutput {
pub logits: Tensor,
pub all_hidden_states: Option<Vec<Tensor>>,
pub all_attentions: Option<Vec<Tensor>>,
}
pub struct DistilBertQuestionAnsweringOutput {
pub start_logits: Tensor,
pub end_logits: Tensor,
pub all_hidden_states: Option<Vec<Tensor>>,
pub all_attentions: Option<Vec<Tensor>>,
}
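A further consumer sketch of the new DistilBERT output structs (hypothetical helper, not part of this commit; it assumes `tch::Tensor` and the structs above are in scope):

```rust
// Pick the most likely answer-span bounds per batch element from the QA output struct.
fn best_span_indices(output: &DistilBertQuestionAnsweringOutput) -> (Tensor, Tensor) {
    // start_logits / end_logits have shape [batch_size, sequence_length];
    // taking the argmax over dimension 1 yields one index per batch element.
    let start = output.start_logits.argmax(1, false);
    let end = output.end_logits.argmax(1, false);
    (start, end)
}
```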

View File

@ -125,10 +125,8 @@ impl ModuleT for DistilBertEmbedding {
let position_embed = position_ids.apply(&self.position_embeddings);
let embeddings = word_embed + position_embed;
let embeddings = embeddings
.apply(&self.layer_norm)
.apply_t(&self.dropout, train);
embeddings
.apply(&self.layer_norm)
.apply_t(&self.dropout, train)
}
}

View File

@ -149,7 +149,7 @@ impl Transformer {
input: &Tensor,
mask: Option<Tensor>,
train: bool,
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
) -> DistilBertTransformerOutput {
let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states {
Some(vec![])
} else {
@ -163,25 +163,30 @@ impl Transformer {
let mut hidden_state = input.copy();
let mut attention_weights: Option<Tensor>;
let mut layers = self.layers.iter();
loop {
match layers.next() {
Some(layer) => {
if let Some(hidden_states) = all_hidden_states.borrow_mut() {
hidden_states.push(hidden_state.as_ref().copy());
};
let temp = layer.forward_t(&hidden_state, &mask, train);
hidden_state = temp.0;
attention_weights = temp.1;
if let Some(attentions) = all_attentions.borrow_mut() {
attentions.push(attention_weights.as_ref().unwrap().copy());
};
}
None => break,
for layer in &self.layers {
if let Some(hidden_states) = all_hidden_states.borrow_mut() {
hidden_states.push(hidden_state.as_ref().copy());
};
let temp = layer.forward_t(&hidden_state, &mask, train);
hidden_state = temp.0;
attention_weights = temp.1;
if let Some(attentions) = all_attentions.borrow_mut() {
attentions.push(attention_weights.as_ref().unwrap().copy());
};
}
(hidden_state, all_hidden_states, all_attentions)
DistilBertTransformerOutput {
hidden_state,
all_hidden_states,
all_attentions,
}
}
}
pub struct DistilBertTransformerOutput {
pub hidden_state: Tensor,
pub all_hidden_states: Option<Vec<Tensor>>,
pub all_attentions: Option<Vec<Tensor>>,
}
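The task-specific heads above all follow the same pattern: the transformer's output struct is mapped field by field into a head-specific output struct. A condensed sketch of that pattern (the function name and `project` closure are hypothetical, and it assumes `Tensor` plus both output structs are in scope):

```rust
// Map the base transformer output into a masked-LM output, reusing the optional
// hidden-state and attention vectors unchanged.
fn to_masked_lm_output(
    transformer_output: DistilBertTransformerOutput,
    project: impl Fn(&Tensor) -> Tensor,
) -> DistilBertMaskedLMOutput {
    DistilBertMaskedLMOutput {
        prediction_scores: project(&transformer_output.hidden_state),
        all_hidden_states: transformer_output.all_hidden_states,
        all_attentions: transformer_output.all_attentions,
    }
}
```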

View File

@ -356,13 +356,13 @@ impl QuestionAnsweringOption {
match *self {
Self::Bert(ref model) => {
let outputs = model.forward_t(input_ids, mask, None, None, input_embeds, train);
(outputs.0, outputs.1)
(outputs.start_logits, outputs.end_logits)
}
Self::DistilBert(ref model) => {
let outputs = model
.forward_t(input_ids, mask, input_embeds, train)
.expect("Error in distilbert forward_t");
(outputs.0, outputs.1)
(outputs.start_logits, outputs.end_logits)
}
Self::Roberta(ref model) | Self::XLMRoberta(ref model) => {
let outputs = model.forward_t(input_ids, mask, None, None, input_embeds, train);

View File

@ -313,13 +313,13 @@ impl SequenceClassificationOption {
input_embeds,
train,
)
.0
.logits
}
Self::DistilBert(ref model) => {
model
.forward_t(input_ids, mask, input_embeds, train)
.expect("Error in distilbert forward_t")
.0
.logits
}
Self::Roberta(ref model) | Self::XLMRoberta(ref model) => {
model

View File

@ -408,13 +408,13 @@ impl TokenClassificationOption {
input_embeds,
train,
)
.0
.logits
}
Self::DistilBert(ref model) => {
model
.forward_t(input_ids, mask, input_embeds, train)
.expect("Error in distilbert forward_t")
.0
.logits
}
Self::Roberta(ref model) | Self::XLMRoberta(ref model) => {
model

View File

@ -344,13 +344,13 @@ impl ZeroShotClassificationOption {
input_embeds,
train,
)
.0
.logits
}
Self::DistilBert(ref model) => {
model
.forward_t(input_ids, mask, input_embeds, train)
.expect("Error in distilbert forward_t")
.0
.logits
}
Self::Roberta(ref model) | Self::XLMRoberta(ref model) => {
model

View File

@ -306,7 +306,7 @@ impl RobertaForMaskedLM {
encoder_mask: &Option<Tensor>,
train: bool,
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
let (hidden_state, _, all_hidden_states, all_attentions) = self
let model_output = self
.roberta
.forward_t(
input_ids,
@ -320,8 +320,12 @@ impl RobertaForMaskedLM {
)
.unwrap();
let prediction_scores = self.lm_head.forward(&hidden_state);
(prediction_scores, all_hidden_states, all_attentions)
let prediction_scores = self.lm_head.forward(&model_output.hidden_state);
(
prediction_scores,
model_output.all_hidden_states,
model_output.all_attentions,
)
}
}
@ -480,7 +484,7 @@ impl RobertaForSequenceClassification {
input_embeds: Option<Tensor>,
train: bool,
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
let (hidden_state, _, all_hidden_states, all_attentions) = self
let model_output = self
.roberta
.forward_t(
input_ids,
@ -494,8 +498,12 @@ impl RobertaForSequenceClassification {
)
.unwrap();
let output = self.classifier.forward_t(&hidden_state, train);
(output, all_hidden_states, all_attentions)
let output = self.classifier.forward_t(&model_output.hidden_state, train);
(
output,
model_output.all_hidden_states,
model_output.all_attentions,
)
}
}
@ -623,7 +631,7 @@ impl RobertaForMultipleChoice {
None => None,
};
let (_, pooled_output, all_hidden_states, all_attentions) = self
let model_output = self
.roberta
.forward_t(
flat_input_ids,
@ -637,11 +645,16 @@ impl RobertaForMultipleChoice {
)
.unwrap();
let output = pooled_output
let output = model_output
.pooled_output
.apply_t(&self.dropout, train)
.apply(&self.classifier)
.view((-1, num_choices));
(output, all_hidden_states, all_attentions)
(
output,
model_output.all_hidden_states,
model_output.all_attentions,
)
}
}
@ -765,7 +778,7 @@ impl RobertaForTokenClassification {
input_embeds: Option<Tensor>,
train: bool,
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
let (hidden_state, _, all_hidden_states, all_attentions) = self
let model_output = self
.roberta
.forward_t(
input_ids,
@ -779,10 +792,15 @@ impl RobertaForTokenClassification {
)
.unwrap();
let sequence_output = hidden_state
let sequence_output = model_output
.hidden_state
.apply_t(&self.dropout, train)
.apply(&self.classifier);
(sequence_output, all_hidden_states, all_attentions)
(
sequence_output,
model_output.all_hidden_states,
model_output.all_attentions,
)
}
}
@ -901,7 +919,7 @@ impl RobertaForQuestionAnswering {
input_embeds: Option<Tensor>,
train: bool,
) -> (Tensor, Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
let (hidden_state, _, all_hidden_states, all_attentions) = self
let model_output = self
.roberta
.forward_t(
input_ids,
@ -915,12 +933,17 @@ impl RobertaForQuestionAnswering {
)
.unwrap();
let sequence_output = hidden_state.apply(&self.qa_outputs);
let sequence_output = model_output.hidden_state.apply(&self.qa_outputs);
let logits = sequence_output.split(1, -1);
let (start_logits, end_logits) = (&logits[0], &logits[1]);
let start_logits = start_logits.squeeze1(-1);
let end_logits = end_logits.squeeze1(-1);
(start_logits, end_logits, all_hidden_states, all_attentions)
(
start_logits,
end_logits,
model_output.all_hidden_states,
model_output.all_attentions,
)
}
}

View File

@ -70,7 +70,7 @@ fn bert_masked_lm() -> anyhow::Result<()> {
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let (output, _, _) = no_grad(|| {
let model_output = no_grad(|| {
bert_model.forward_t(
Some(input_tensor),
None,
@ -84,8 +84,16 @@ fn bert_masked_lm() -> anyhow::Result<()> {
});
// Print masked tokens
let index_1 = output.get(0).get(4).argmax(0, false);
let index_2 = output.get(1).get(6).argmax(0, false);
let index_1 = model_output
.prediction_scores
.get(0)
.get(4)
.argmax(0, false);
let index_2 = model_output
.prediction_scores
.get(1)
.get(6)
.argmax(0, false);
let word_1 = tokenizer.vocab().id_to_token(&index_1.int64_value(&[]));
let word_2 = tokenizer.vocab().id_to_token(&index_2.int64_value(&[]));
@ -144,17 +152,17 @@ fn bert_for_sequence_classification() -> anyhow::Result<()> {
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let (output, all_hidden_states, all_attentions) =
let model_output =
no_grad(|| bert_model.forward_t(Some(input_tensor), None, None, None, None, false));
assert_eq!(output.size(), &[2, 3]);
assert_eq!(model_output.logits.size(), &[2, 3]);
assert_eq!(
config.num_hidden_layers as usize,
all_hidden_states.unwrap().len()
model_output.all_hidden_states.unwrap().len()
);
assert_eq!(
config.num_hidden_layers as usize,
all_attentions.unwrap().len()
model_output.all_attentions.unwrap().len()
);
Ok(())
@ -206,17 +214,16 @@ fn bert_for_multiple_choice() -> anyhow::Result<()> {
.unsqueeze(0);
// Forward pass
let (output, all_hidden_states, all_attentions) =
no_grad(|| bert_model.forward_t(input_tensor, None, None, None, false));
let model_output = no_grad(|| bert_model.forward_t(input_tensor, None, None, None, false));
assert_eq!(output.size(), &[1, 2]);
assert_eq!(model_output.logits.size(), &[1, 2]);
assert_eq!(
config.num_hidden_layers as usize,
all_hidden_states.unwrap().len()
model_output.all_hidden_states.unwrap().len()
);
assert_eq!(
config.num_hidden_layers as usize,
all_attentions.unwrap().len()
model_output.all_attentions.unwrap().len()
);
Ok(())
@ -272,17 +279,17 @@ fn bert_for_token_classification() -> anyhow::Result<()> {
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let (output, all_hidden_states, all_attentions) =
let model_output =
no_grad(|| bert_model.forward_t(Some(input_tensor), None, None, None, None, false));
assert_eq!(output.size(), &[2, 11, 4]);
assert_eq!(model_output.logits.size(), &[2, 11, 4]);
assert_eq!(
config.num_hidden_layers as usize,
all_hidden_states.unwrap().len()
model_output.all_hidden_states.unwrap().len()
);
assert_eq!(
config.num_hidden_layers as usize,
all_attentions.unwrap().len()
model_output.all_attentions.unwrap().len()
);
Ok(())
@ -332,18 +339,18 @@ fn bert_for_question_answering() -> anyhow::Result<()> {
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let (start_scores, end_scores, all_hidden_states, all_attentions) =
let model_output =
no_grad(|| bert_model.forward_t(Some(input_tensor), None, None, None, None, false));
assert_eq!(start_scores.size(), &[2, 11]);
assert_eq!(end_scores.size(), &[2, 11]);
assert_eq!(model_output.start_logits.size(), &[2, 11]);
assert_eq!(model_output.end_logits.size(), &[2, 11]);
assert_eq!(
config.num_hidden_layers as usize,
all_hidden_states.unwrap().len()
model_output.all_hidden_states.unwrap().len()
);
assert_eq!(
config.num_hidden_layers as usize,
all_attentions.unwrap().len()
model_output.all_attentions.unwrap().len()
);
Ok(())

View File

@ -96,15 +96,23 @@ fn distilbert_masked_lm() -> anyhow::Result<()> {
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let (output, _, _) = no_grad(|| {
let model_output = no_grad(|| {
distil_bert_model
.forward_t(Some(input_tensor), None, None, false)
.unwrap()
});
// Print masked tokens
let index_1 = output.get(0).get(4).argmax(0, false);
let index_2 = output.get(1).get(6).argmax(0, false);
let index_1 = model_output
.prediction_scores
.get(0)
.get(4)
.argmax(0, false);
let index_2 = model_output
.prediction_scores
.get(1)
.get(6)
.argmax(0, false);
let word_1 = tokenizer.vocab().id_to_token(&index_1.int64_value(&[]));
let word_2 = tokenizer.vocab().id_to_token(&index_2.int64_value(&[]));
@ -160,16 +168,22 @@ fn distilbert_for_question_answering() -> anyhow::Result<()> {
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let (start_scores, end_scores, all_hidden_states, all_attentions) = no_grad(|| {
let model_output = no_grad(|| {
distil_bert_model
.forward_t(Some(input_tensor), None, None, false)
.unwrap()
});
assert_eq!(start_scores.size(), &[2, 11]);
assert_eq!(end_scores.size(), &[2, 11]);
assert_eq!(config.n_layers as usize, all_hidden_states.unwrap().len());
assert_eq!(config.n_layers as usize, all_attentions.unwrap().len());
assert_eq!(model_output.start_logits.size(), &[2, 11]);
assert_eq!(model_output.end_logits.size(), &[2, 11]);
assert_eq!(
config.n_layers as usize,
model_output.all_hidden_states.unwrap().len()
);
assert_eq!(
config.n_layers as usize,
model_output.all_attentions.unwrap().len()
);
Ok(())
}
@ -226,15 +240,21 @@ fn distilbert_for_token_classification() -> anyhow::Result<()> {
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let (output, all_hidden_states, all_attentions) = no_grad(|| {
let model_output = no_grad(|| {
distil_bert_model
.forward_t(Some(input_tensor), None, None, false)
.unwrap()
});
assert_eq!(output.size(), &[2, 11, 4]);
assert_eq!(config.n_layers as usize, all_hidden_states.unwrap().len());
assert_eq!(config.n_layers as usize, all_attentions.unwrap().len());
assert_eq!(model_output.logits.size(), &[2, 11, 4]);
assert_eq!(
config.n_layers as usize,
model_output.all_hidden_states.unwrap().len()
);
assert_eq!(
config.n_layers as usize,
model_output.all_attentions.unwrap().len()
);
Ok(())
}