Merge branch 'master' of https://github.com/guillaume-be/rust-bert into xlnet_implementation

This commit is contained in:
Guillaume B 2020-09-16 17:19:22 +02:00
commit 6f907ff995
71 changed files with 2356 additions and 1955 deletions

View File

@ -7,6 +7,10 @@ jobs:
- rustup component add rustfmt
script:
- cargo fmt -- --check
- before_script:
- rustup component add clippy
script:
- cargo clippy --all-targets --all-features -- -D warnings -A clippy::assign_op_pattern
- script:
- cargo build --verbose
- os:

View File

@ -1,6 +1,6 @@
[package]
name = "rust-bert"
version = "0.9.0"
version = "0.10.0"
authors = ["Guillaume Becquin <guillaume.becquin@gmail.com>"]
edition = "2018"
description = "Ready-to-use NLP pipelines and transformer-based models (BERT, DistilBERT, GPT2,...)"
@ -37,7 +37,7 @@ serde = { version = "1.0.114", features = ["derive"] }
dirs = "3.0.1"
itertools = "0.9.0"
ordered-float = "2.0.0"
cached-path = "0.4.3"
cached-path = "0.4.5"
lazy_static = "1.4.0"
uuid = { version = "0.8.1", features = ["v4"] }
thiserror = "1.0.20"

clippy.toml Normal file
View File

@ -0,0 +1 @@
too-many-arguments-threshold = 10
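For context: Clippy's `too_many_arguments` lint fires above 7 arguments by default, which the long `forward_t` signatures in this crate would trip; raising the threshold to 10 keeps them lint-clean without per-function `#[allow]` attributes. A hypothetical sketch of the kind of signature the new limit accommodates (illustration only, not from the crate):

// Hypothetical 8-argument function: over Clippy's default limit of 7,
// under the new threshold of 10 set in clippy.toml.
fn forward_like(
    _input_ids: Option<i64>,
    _mask: Option<i64>,
    _token_type_ids: Option<i64>,
    _position_ids: Option<i64>,
    _input_embeds: Option<i64>,
    _encoder_hidden_states: Option<i64>,
    _encoder_mask: Option<i64>,
    train: bool,
) -> bool {
    train
}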

View File

@ -70,12 +70,23 @@ fn main() -> anyhow::Result<()> {
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let (output, _, _) =
let model_output =
no_grad(|| albert_model.forward_t(Some(input_tensor), None, None, None, None, false));
println!("{:?}", output.double_value(&[0, 0, 0]));
println!(
"{:?}",
model_output.prediction_scores.double_value(&[0, 0, 0])
);
// Print masked tokens
let index_1 = output.get(0).get(4).argmax(0, false);
let index_2 = output.get(1).get(7).argmax(0, false);
let index_1 = model_output
.prediction_scores
.get(0)
.get(4)
.argmax(0, false);
let index_2 = model_output
.prediction_scores
.get(1)
.get(7)
.argmax(0, false);
let word_1 = tokenizer.vocab().id_to_token(&index_1.int64_value(&[]));
let word_2 = tokenizer.vocab().id_to_token(&index_2.int64_value(&[]));

View File

@ -73,12 +73,12 @@ fn main() -> anyhow::Result<()> {
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let (decoder_output, encoder_output, _, _, _, _, _) =
let model_output =
no_grad(|| bart_model.forward_t(Some(&input_tensor), None, None, None, None, None, false));
// Print masked tokens
println!("{:?}", encoder_output);
println!("{:?}", decoder_output);
println!("{:?}", decoder_output.double_value(&[0, 0, 0]));
println!("{:?}", model_output.encoder_hidden_state);
println!("{:?}", model_output.decoder_output);
println!("{:?}", model_output.decoder_output.double_value(&[0, 0, 0]));
Ok(())
}

View File

@ -65,7 +65,7 @@ fn main() -> anyhow::Result<()> {
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let (output, _, _) = no_grad(|| {
let model_output = no_grad(|| {
bert_model.forward_t(
Some(input_tensor),
None,
@ -79,8 +79,16 @@ fn main() -> anyhow::Result<()> {
});
// Print masked tokens
let index_1 = output.get(0).get(4).argmax(0, false);
let index_2 = output.get(1).get(7).argmax(0, false);
let index_1 = model_output
.prediction_scores
.get(0)
.get(4)
.argmax(0, false);
let index_2 = model_output
.prediction_scores
.get(1)
.get(7)
.argmax(0, false);
let word_1 = tokenizer.vocab().id_to_token(&index_1.int64_value(&[]));
let word_2 = tokenizer.vocab().id_to_token(&index_2.int64_value(&[]));

View File

@ -77,15 +77,23 @@ fn main() -> anyhow::Result<()> {
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let (output, _, _) = no_grad(|| {
let model_output = no_grad(|| {
distil_bert_model
.forward_t(Some(input_tensor), None, None, false)
.unwrap()
});
// Print masked tokens
let index_1 = output.get(0).get(4).argmax(0, false);
let index_2 = output.get(1).get(6).argmax(0, false);
let index_1 = model_output
.prediction_scores
.get(0)
.get(4)
.argmax(0, false);
let index_2 = model_output
.prediction_scores
.get(1)
.get(6)
.argmax(0, false);
let word_1 = tokenizer.vocab().id_to_token(&index_1.int64_value(&[]));
let word_2 = tokenizer.vocab().id_to_token(&index_2.int64_value(&[]));

View File

@ -66,12 +66,12 @@ fn main() -> anyhow::Result<()> {
let input_tensor = Tensor::stack(encoded_input.as_slice(), 0).to(device);
// Forward pass
let (output, _, _) =
let model_output =
no_grad(|| electra_model.forward_t(Some(input_tensor), None, None, None, None, false));
// Print model predictions
for (position, token) in tokenized_input[0].token_ids.iter().enumerate() {
let probability = output.double_value(&[position as i64]);
let probability = model_output.probabilities.double_value(&[position as i64]);
let generated = if probability > 0.5 {
"generated"
} else {

View File

@ -69,12 +69,20 @@ fn main() -> anyhow::Result<()> {
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let (output, _, _) =
let model_output =
no_grad(|| electra_model.forward_t(Some(input_tensor), None, None, None, None, false));
// Print masked tokens
let index_1 = output.get(0).get(4).argmax(0, false);
let index_2 = output.get(1).get(7).argmax(0, false);
let index_1 = model_output
.prediction_scores
.get(0)
.get(4)
.argmax(0, false);
let index_2 = model_output
.prediction_scores
.get(1)
.get(7)
.argmax(0, false);
let word_1 = tokenizer.vocab().id_to_token(&index_1.int64_value(&[]));
let word_2 = tokenizer.vocab().id_to_token(&index_2.int64_value(&[]));

View File

@ -70,7 +70,7 @@ fn main() -> anyhow::Result<()> {
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let (output, _, _, _, _) = gpt2_model
let model_output = gpt2_model
.forward_t(
&Some(input_tensor),
Cache::None,
@ -84,7 +84,12 @@ fn main() -> anyhow::Result<()> {
)
.unwrap();
let next_word_id = output.get(0).get(-1).argmax(-1, true).int64_value(&[0]);
let next_word_id = model_output
.lm_logits
.get(0)
.get(-1)
.argmax(-1, true)
.int64_value(&[0]);
let next_word = tokenizer.decode(vec![next_word_id], true, true);
println!("Provided input: {}", input[0]);
println!("Next word: {}", next_word);

View File

@ -75,7 +75,7 @@ fn main() -> anyhow::Result<()> {
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let (output, _, _, _, _) = openai_gpt
let model_output = openai_gpt
.forward_t(
&Some(input_tensor),
Cache::None,
@ -89,7 +89,12 @@ fn main() -> anyhow::Result<()> {
)
.unwrap();
let next_word_id = output.get(0).get(-1).argmax(-1, true).int64_value(&[0]);
let next_word_id = model_output
.lm_logits
.get(0)
.get(-1)
.argmax(-1, true)
.int64_value(&[0]);
let next_word = tokenizer.decode(vec![next_word_id], true, true);
println!("Provided input: {}", input[0]);
println!("Next word: {}", next_word);

View File

@ -33,7 +33,7 @@ fn main() -> anyhow::Result<()> {
};
// Get answer
let answers = qa_model.predict(&vec![qa_input_1, qa_input_2], 1, 32);
let answers = qa_model.predict(&[qa_input_1, qa_input_2], 1, 32);
println!("{:?}", answers);
Ok(())
}

View File

@ -51,7 +51,7 @@ fn main() -> anyhow::Result<()> {
};
// Get answer
let answers = qa_model.predict(&vec![qa_input_1, qa_input_2], 1, 32);
let answers = qa_model.predict(&[qa_input_1, qa_input_2], 1, 32);
println!("{:?}", answers);
Ok(())
}

View File

@ -85,7 +85,7 @@ fn main() -> anyhow::Result<()> {
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let (output, _, _) = no_grad(|| {
let model_output = no_grad(|| {
bert_model.forward_t(
Some(input_tensor),
None,
@ -99,8 +99,16 @@ fn main() -> anyhow::Result<()> {
});
// Print masked tokens
let index_1 = output.get(0).get(4).argmax(0, false);
let index_2 = output.get(1).get(5).argmax(0, false);
let index_1 = model_output
.prediction_scores
.get(0)
.get(4)
.argmax(0, false);
let index_2 = model_output
.prediction_scores
.get(1)
.get(5)
.argmax(0, false);
let word_1 = tokenizer.vocab().id_to_token(&index_1.int64_value(&[]));
let word_2 = tokenizer.vocab().id_to_token(&index_2.int64_value(&[]));

View File

@ -27,8 +27,10 @@ fn main() -> anyhow::Result<()> {
// Run model
let output = sequence_classification_model.predict_multilabel(&input, 0.05);
for label in output {
println!("{:?}", label);
if let Ok(labels) = output {
for label in labels {
println!("{:?}", label);
}
}
Ok(())
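Since `predict_multilabel` now returns a `Result`, the error can also be propagated with `?` instead of the `if let Ok(..)` pattern above. A minimal sketch, assuming the surrounding `main` returns `anyhow::Result<()>` as in the other examples:

// Alternative to the `if let Ok(..)` above: propagate the error with `?`
// (same `sequence_classification_model` and `input` as in this example).
let labels = sequence_classification_model.predict_multilabel(&input, 0.05)?;
for label in labels {
    println!("{:?}", label);
}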

View File

@ -15,7 +15,7 @@ use crate::albert::embeddings::AlbertEmbeddings;
use crate::albert::encoder::AlbertTransformer;
use crate::common::activations::{_gelu, _gelu_new, _mish, _relu, _tanh};
use crate::common::dropout::Dropout;
use crate::Config;
use crate::{Config, RustBertError};
use serde::{Deserialize, Serialize};
use std::{borrow::Borrow, collections::HashMap};
use tch::nn::Module;
@ -176,11 +176,11 @@ impl AlbertModel {
/// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
///
/// # Returns
///
/// * `output` - `Tensor` of shape (*batch size*, *sequence_length*, *hidden_size*)
/// * `pooled_output` - `Tensor` of shape (*batch size*, *hidden_size*)
/// * `hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// * `attentions` - `Option<Vec<Vec<Tensor>>>` of length *num_hidden_layers* of nested length *inner_group_num* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// * `AlbertOutput` containing:
/// - `hidden_state` - `Tensor` of shape (*batch size*, *sequence_length*, *hidden_size*)
/// - `pooled_output` - `Tensor` of shape (*batch size*, *hidden_size*)
/// - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// - `all_attentions` - `Option<Vec<Vec<Tensor>>>` of length *num_hidden_layers* of nested length *inner_group_num* with shape (*batch size*, *sequence_length*, *hidden_size*)
///
/// # Example
///
@ -202,7 +202,7 @@ impl AlbertModel {
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
/// .expand(&[batch_size, sequence_length], true);
///
/// let (output, pooled_output, all_hidden_states, all_attentions) = no_grad(|| {
/// let model_output = no_grad(|| {
/// albert_model
/// .forward_t(
/// Some(input_tensor),
@ -223,26 +223,22 @@ impl AlbertModel {
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool,
) -> Result<
(
Tensor,
Tensor,
Option<Vec<Tensor>>,
Option<Vec<Vec<Tensor>>>,
),
&'static str,
> {
) -> Result<AlbertOutput, RustBertError> {
let (input_shape, device) = match &input_ids {
Some(input_value) => match &input_embeds {
Some(_) => {
return Err("Only one of input ids or input embeddings may be set");
return Err(RustBertError::ValueError(
"Only one of input ids or input embeddings may be set".into(),
));
}
None => (input_value.size(), input_value.device()),
},
None => match &input_embeds {
Some(embeds) => (vec![embeds.size()[0], embeds.size()[1]], embeds.device()),
None => {
return Err("At least one of input ids or input embeddings must be set");
return Err(RustBertError::ValueError(
"At least one of input ids or input embeddings must be set".into(),
));
}
},
};
@ -269,19 +265,21 @@ impl AlbertModel {
}
};
let (hidden_state, all_hidden_states, all_attentions) =
let transformer_output =
self.encoder
.forward_t(&embedding_output, Some(extended_attention_mask), train);
let pooled_output = self.pooler.forward(&hidden_state.select(1, 0));
let pooled_output = self
.pooler
.forward(&transformer_output.hidden_state.select(1, 0));
let pooled_output = (self.pooler_activation)(&pooled_output);
Ok((
hidden_state,
Ok(AlbertOutput {
hidden_state: transformer_output.hidden_state,
pooled_output,
all_hidden_states,
all_attentions,
))
all_hidden_states: transformer_output.all_hidden_states,
all_attentions: transformer_output.all_attentions,
})
}
}
@ -406,9 +404,10 @@ impl AlbertForMaskedLM {
///
/// # Returns
///
/// * `output` - `Tensor` of shape (*batch size*, *sequence_length*, *vocab_size*)
/// * `hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// * `attentions` - `Option<Vec<Vec<Tensor>>>` of length *num_hidden_layers* of nested length *inner_group_num* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// * `AlbertMaskedLMOutput` containing:
/// - `prediction_scores` - `Tensor` of shape (*batch size*, *sequence_length*, *vocab_size*)
/// - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// - `all_attentions` - `Option<Vec<Vec<Tensor>>>` of length *num_hidden_layers* of nested length *inner_group_num* with shape (*batch size*, *sequence_length*, *hidden_size*)
///
/// # Example
///
@ -430,7 +429,7 @@ impl AlbertForMaskedLM {
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
/// .expand(&[batch_size, sequence_length], true);
///
/// let (output, all_hidden_states, all_attentions) = no_grad(|| {
/// let masked_lm_output = no_grad(|| {
/// albert_model.forward_t(
/// Some(input_tensor),
/// Some(mask),
@ -449,8 +448,8 @@ impl AlbertForMaskedLM {
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool,
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Vec<Tensor>>>) {
let (hidden_state, _, all_hidden_states, all_attentions) = self
) -> AlbertMaskedLMOutput {
let base_model_output = self
.albert
.forward_t(
input_ids,
@ -461,8 +460,12 @@ impl AlbertForMaskedLM {
train,
)
.unwrap();
let prediction_scores = self.predictions.forward(&hidden_state);
(prediction_scores, all_hidden_states, all_attentions)
let prediction_scores = self.predictions.forward(&base_model_output.hidden_state);
AlbertMaskedLMOutput {
prediction_scores,
all_hidden_states: base_model_output.all_hidden_states,
all_attentions: base_model_output.all_attentions,
}
}
}
@ -545,9 +548,10 @@ impl AlbertForSequenceClassification {
///
/// # Returns
///
/// * `output` - `Tensor` of shape (*batch size*, *num_labels*)
/// * `hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// * `attentions` - `Option<Vec<Vec<Tensor>>>` of length *num_hidden_layers* of nested length *inner_group_num* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// * `AlbertSequenceClassificationOutput` containing:
/// - `logits` - `Tensor` of shape (*batch size*, *num_labels*)
/// - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// - `all_attentions` - `Option<Vec<Vec<Tensor>>>` of length *num_hidden_layers* of nested length *inner_group_num* with shape (*batch size*, *sequence_length*, *hidden_size*)
///
/// # Example
///
@ -568,7 +572,7 @@ impl AlbertForSequenceClassification {
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true);
///
/// let (output, all_hidden_states, all_attentions) = no_grad(|| {
/// let classification_output = no_grad(|| {
/// albert_model
/// .forward_t(Some(input_tensor),
/// Some(mask),
@ -586,8 +590,8 @@ impl AlbertForSequenceClassification {
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool,
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Vec<Tensor>>>) {
let (_, pooled_output, all_hidden_states, all_attentions) = self
) -> AlbertSequenceClassificationOutput {
let base_model_output = self
.albert
.forward_t(
input_ids,
@ -598,10 +602,15 @@ impl AlbertForSequenceClassification {
train,
)
.unwrap();
let logits = pooled_output
let logits = base_model_output
.pooled_output
.apply_t(&self.dropout, train)
.apply(&self.classifier);
(logits, all_hidden_states, all_attentions)
AlbertSequenceClassificationOutput {
logits,
all_hidden_states: base_model_output.all_hidden_states,
all_attentions: base_model_output.all_attentions,
}
}
}
@ -681,9 +690,10 @@ impl AlbertForTokenClassification {
///
/// # Returns
///
/// * `output` - `Tensor` of shape (*batch size*, *sequence_length*, *num_labels*) containing the logits for each of the input tokens and classes
/// * `hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// * `attentions` - `Option<Vec<Vec<Tensor>>>` of length *num_hidden_layers* of nested length *inner_group_num* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// * `AlbertTokenClassificationOutput` containing:
/// - `logits` - `Tensor` of shape (*batch size*, *sequence_length*, *num_labels*) containing the logits for each of the input tokens and classes
/// - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// - `all_attentions` - `Option<Vec<Vec<Tensor>>>` of length *num_hidden_layers* of nested length *inner_group_num* with shape (*batch size*, *sequence_length*, *hidden_size*)
///
/// # Example
///
@ -704,7 +714,7 @@ impl AlbertForTokenClassification {
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true);
///
/// let (output, all_hidden_states, all_attentions) = no_grad(|| {
/// let model_output = no_grad(|| {
/// albert_model
/// .forward_t(Some(input_tensor),
/// Some(mask),
@ -722,8 +732,8 @@ impl AlbertForTokenClassification {
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool,
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Vec<Tensor>>>) {
let (sequence_output, _, all_hidden_states, all_attentions) = self
) -> AlbertTokenClassificationOutput {
let base_model_output = self
.albert
.forward_t(
input_ids,
@ -734,10 +744,15 @@ impl AlbertForTokenClassification {
train,
)
.unwrap();
let logits = sequence_output
let logits = base_model_output
.hidden_state
.apply_t(&self.dropout, train)
.apply(&self.classifier);
(logits, all_hidden_states, all_attentions)
AlbertTokenClassificationOutput {
logits,
all_hidden_states: base_model_output.all_hidden_states,
all_attentions: base_model_output.all_attentions,
}
}
}
@ -806,10 +821,11 @@ impl AlbertForQuestionAnswering {
///
/// # Returns
///
/// * `start_scores` - `Tensor` of shape (*batch size*, *sequence_length*) containing the logits for start of the answer
/// * `end_scores` - `Tensor` of shape (*batch size*, *sequence_length*) containing the logits for end of the answer
/// * `hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// * `attentions` - `Option<Vec<Vec<Tensor>>>` of length *num_hidden_layers* of nested length *inner_group_num* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// * `AlbertQuestionAnsweringOutput` containing:
/// - `start_logits` - `Tensor` of shape (*batch size*, *sequence_length*) containing the logits for start of the answer
/// - `end_logits` - `Tensor` of shape (*batch size*, *sequence_length*) containing the logits for end of the answer
/// - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// - `all_attentions` - `Option<Vec<Vec<Tensor>>>` of length *num_hidden_layers* of nested length *inner_group_num* with shape (*batch size*, *sequence_length*, *hidden_size*)
///
/// # Example
///
@ -830,7 +846,7 @@ impl AlbertForQuestionAnswering {
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true);
///
/// let (start_logits, end_logits, all_hidden_states, all_attentions) = no_grad(|| {
/// let model_output = no_grad(|| {
/// albert_model
/// .forward_t(Some(input_tensor),
/// Some(mask),
@ -848,13 +864,8 @@ impl AlbertForQuestionAnswering {
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool,
) -> (
Tensor,
Tensor,
Option<Vec<Tensor>>,
Option<Vec<Vec<Tensor>>>,
) {
let (sequence_output, _, all_hidden_states, all_attentions) = self
) -> AlbertQuestionAnsweringOutput {
let base_model_output = self
.albert
.forward_t(
input_ids,
@ -865,12 +876,20 @@ impl AlbertForQuestionAnswering {
train,
)
.unwrap();
let logits = sequence_output.apply(&self.qa_outputs).split(1, -1);
let logits = base_model_output
.hidden_state
.apply(&self.qa_outputs)
.split(1, -1);
let (start_logits, end_logits) = (&logits[0], &logits[1]);
let start_logits = start_logits.squeeze1(-1);
let end_logits = end_logits.squeeze1(-1);
(start_logits, end_logits, all_hidden_states, all_attentions)
AlbertQuestionAnsweringOutput {
start_logits,
end_logits,
all_hidden_states: base_model_output.all_hidden_states,
all_attentions: base_model_output.all_attentions,
}
}
}
@ -946,9 +965,10 @@ impl AlbertForMultipleChoice {
///
/// # Returns
///
/// * `output` - `Tensor` of shape (*1*, *batch size*) containing the logits for each of the alternatives given
/// * `hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// * `attentions` - `Option<Vec<Vec<Tensor>>>` of length *num_hidden_layers* of nested length *inner_group_num* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// * `AlbertSequenceClassificationOutput` containing:
/// - `logits` - `Tensor` of shape (*1*, *batch size*) containing the logits for each of the alternatives given
/// - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// - `all_attentions` - `Option<Vec<Vec<Tensor>>>` of length *num_hidden_layers* of nested length *inner_group_num* with shape (*batch size*, *sequence_length*, *hidden_size*)
///
/// # Example
///
@ -969,7 +989,7 @@ impl AlbertForMultipleChoice {
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true);
///
/// let (output, all_hidden_states, all_attentions) = no_grad(|| {
/// let model_output = no_grad(|| {
/// albert_model
/// .forward_t(Some(input_tensor),
/// Some(mask),
@ -987,11 +1007,13 @@ impl AlbertForMultipleChoice {
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool,
) -> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Vec<Tensor>>>), &'static str> {
) -> Result<AlbertSequenceClassificationOutput, RustBertError> {
let (input_ids, input_embeds, num_choices) = match &input_ids {
Some(input_value) => match &input_embeds {
Some(_) => {
return Err("Only one of input ids or input embeddings may be set");
return Err(RustBertError::ValueError(
"Only one of input ids or input embeddings may be set".into(),
));
}
None => (
Some(input_value.view((-1, *input_value.size().last().unwrap()))),
@ -1006,7 +1028,9 @@ impl AlbertForMultipleChoice {
embeds.size()[1],
),
None => {
return Err("At least one of input ids or input embeddings must be set");
return Err(RustBertError::ValueError(
"At least one of input ids or input embeddings must be set".into(),
));
}
},
};
@ -1024,7 +1048,7 @@ impl AlbertForMultipleChoice {
None => None,
};
let (_, pooled_output, all_hidden_states, all_attentions) = self
let base_model_output = self
.albert
.forward_t(
input_ids,
@ -1035,11 +1059,70 @@ impl AlbertForMultipleChoice {
train,
)
.unwrap();
let logits = pooled_output
let logits = base_model_output
.pooled_output
.apply_t(&self.dropout, train)
.apply(&self.classifier)
.view((-1, num_choices));
Ok((logits, all_hidden_states, all_attentions))
Ok(AlbertSequenceClassificationOutput {
logits,
all_hidden_states: base_model_output.all_hidden_states,
all_attentions: base_model_output.all_attentions,
})
}
}
/// Container for the ALBERT model output.
pub struct AlbertOutput {
/// Last hidden states from the model
pub hidden_state: Tensor,
/// Pooled output (hidden state for the first token)
pub pooled_output: Tensor,
/// Hidden states for all intermediate layers
pub all_hidden_states: Option<Vec<Tensor>>,
/// Attention weights for all intermediate layers
pub all_attentions: Option<Vec<Vec<Tensor>>>,
}
/// Container for the ALBERT masked LM model output.
pub struct AlbertMaskedLMOutput {
/// Logits for the vocabulary items at each sequence position
pub prediction_scores: Tensor,
/// Hidden states for all intermediate layers
pub all_hidden_states: Option<Vec<Tensor>>,
/// Attention weights for all intermediate layers
pub all_attentions: Option<Vec<Vec<Tensor>>>,
}
/// Container for the ALBERT sequence classification model output
pub struct AlbertSequenceClassificationOutput {
/// Logits for each input (sequence) for each target class
pub logits: Tensor,
/// Hidden states for all intermediate layers
pub all_hidden_states: Option<Vec<Tensor>>,
/// Attention weights for all intermediate layers
pub all_attentions: Option<Vec<Vec<Tensor>>>,
}
/// Container for the ALBERT token classification model output
pub struct AlbertTokenClassificationOutput {
/// Logits for each sequence item (token) for each target class
pub logits: Tensor,
/// Hidden states for all intermediate layers
pub all_hidden_states: Option<Vec<Tensor>>,
/// Attention weights for all intermediate layers
pub all_attentions: Option<Vec<Vec<Tensor>>>,
}
/// Container for the ALBERT question answering model output
pub struct AlbertQuestionAnsweringOutput {
/// Logits for the start position, for each token of each input sequence
pub start_logits: Tensor,
/// Logits for the end position, for each token of each input sequence
pub end_logits: Tensor,
/// Hidden states for all intermediate layers
pub all_hidden_states: Option<Vec<Tensor>>,
/// Attention weights for all intermediate layers
pub all_attentions: Option<Vec<Vec<Tensor>>>,
}
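As a consumption sketch (not part of the diff): callers that previously destructured the returned tuples can now bind the named fields of these output structs directly, assuming an `albert_model` (`AlbertModel`) and `input_tensor` set up as in the doc examples above:

// Minimal sketch: binding fields of the new AlbertOutput by name instead of by
// tuple position. `all_hidden_states` and `all_attentions` are also available.
let AlbertOutput {
    hidden_state,
    pooled_output,
    ..
} = no_grad(|| {
    albert_model
        .forward_t(Some(input_tensor), None, None, None, None, false)
        .unwrap()
});
println!("{:?}", hidden_state.size());
println!("{:?}", pooled_output.size());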

View File

@ -13,6 +13,7 @@
use crate::albert::AlbertConfig;
use crate::common::dropout::Dropout;
use crate::RustBertError;
use std::borrow::Borrow;
use tch::nn::{embedding, EmbeddingConfig};
use tch::{nn, Kind, Tensor};
@ -90,11 +91,13 @@ impl AlbertEmbeddings {
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool,
) -> Result<Tensor, &'static str> {
) -> Result<Tensor, RustBertError> {
let (input_embeddings, input_shape) = match input_ids {
Some(input_value) => match input_embeds {
Some(_) => {
return Err("Only one of input ids or input embeddings may be set");
return Err(RustBertError::ValueError(
"Only one of input ids or input embeddings may be set".into(),
));
}
None => (
input_value.apply_t(&self.word_embeddings, train),
@ -107,7 +110,9 @@ impl AlbertEmbeddings {
(embeds, size)
}
None => {
return Err("Only one of input ids or input embeddings may be set");
return Err(RustBertError::ValueError(
"At least one of input ids or input embeddings must be set".into(),
));
}
},
};

View File

@ -11,7 +11,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use crate::albert::albert::Activation;
use crate::albert::albert_model::Activation;
use crate::albert::attention::AlbertSelfAttention;
use crate::albert::AlbertConfig;
use crate::common::activations::{_gelu, _gelu_new, _mish, _relu};
@ -149,22 +149,17 @@ impl AlbertLayerGroup {
let mut hidden_state = hidden_states.copy();
let mut attention_weights: Option<Tensor>;
let mut layers = self.layers.iter();
loop {
match layers.next() {
Some(layer) => {
if let Some(hidden_states) = all_hidden_states.borrow_mut() {
hidden_states.push(hidden_state.as_ref().copy());
};
let temp = layer.forward_t(&hidden_state, &mask, train);
hidden_state = temp.0;
attention_weights = temp.1;
if let Some(attentions) = all_attentions.borrow_mut() {
attentions.push(attention_weights.as_ref().unwrap().copy());
};
}
None => break,
for layer in &self.layers {
if let Some(hidden_states) = all_hidden_states.borrow_mut() {
hidden_states.push(hidden_state.as_ref().copy());
};
let temp = layer.forward_t(&hidden_state, &mask, train);
hidden_state = temp.0;
attention_weights = temp.1;
if let Some(attentions) = all_attentions.borrow_mut() {
attentions.push(attention_weights.as_ref().unwrap().copy());
};
}
@ -226,7 +221,7 @@ impl AlbertTransformer {
hidden_states: &Tensor,
mask: Option<Tensor>,
train: bool,
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Vec<Tensor>>>) {
) -> AlbertTransformerOutput {
let mut hidden_state = hidden_states.apply(&self.embedding_hidden_mapping_in);
let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states {
@ -256,6 +251,20 @@ impl AlbertTransformer {
};
}
(hidden_state, all_hidden_states, all_attentions)
AlbertTransformerOutput {
hidden_state,
all_hidden_states,
all_attentions,
}
}
}
/// Container holding the ALBERT transformer output
pub struct AlbertTransformerOutput {
/// Last hidden states of the transformer
pub hidden_state: Tensor,
/// Hidden states for all intermediate layers
pub all_hidden_states: Option<Vec<Tensor>>,
/// Attention weights for all intermediate layers. As layers in ALBERT can be made of a number of sub-layers, a vector of vectors is used to store all of the attentions
pub all_attentions: Option<Vec<Vec<Tensor>>>,
}

View File

@ -11,7 +11,7 @@
//!
//! # Model set-up and pre-trained weights loading
//!
//! A full working example is provided in `examples/albert.rs`, run with `cargo run --example albert`.
//! A full working example is provided in `examples/albert`, run with `cargo run --example albert`.
//! The example below illustrates a masked language model use case; the structure is similar for other models.
//! All models expect the following resources:
//! - Configuration file expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers)
@ -53,13 +53,15 @@
//! # }
//! ```
mod albert;
mod albert_model;
mod attention;
mod embeddings;
mod encoder;
pub use albert::{
pub use albert_model::{
AlbertConfig, AlbertConfigResources, AlbertForMaskedLM, AlbertForMultipleChoice,
AlbertForQuestionAnswering, AlbertForSequenceClassification, AlbertForTokenClassification,
AlbertModel, AlbertModelResources, AlbertVocabResources,
AlbertMaskedLMOutput, AlbertModel, AlbertModelResources, AlbertOutput,
AlbertQuestionAnsweringOutput, AlbertSequenceClassificationOutput,
AlbertTokenClassificationOutput, AlbertVocabResources,
};
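With these re-exports, downstream crates can name the new output containers alongside the models — a short usage sketch:

// Sketch: the output structs are now part of the public albert module API.
use rust_bert::albert::{AlbertForMaskedLM, AlbertMaskedLMOutput, AlbertOutput};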

View File

@ -13,10 +13,10 @@
use crate::bart::attention::LayerState;
use crate::bart::decoder::BartDecoder;
use crate::bart::encoder::BartEncoder;
use crate::bart::encoder::{BartEncoder, BartEncoderOutput};
use crate::common::dropout::Dropout;
use crate::pipelines::generation::{Cache, LMHeadModel};
use crate::Config;
use crate::pipelines::generation::{Cache, LMHeadModel, LMModelOutput};
use crate::{Config, RustBertError};
use serde::{Deserialize, Serialize};
use std::borrow::Borrow;
use std::collections::HashMap;
@ -315,14 +315,14 @@ impl BartModel {
///
/// # Returns
///
/// * `decoder_output` - `Tensor` of shape (*batch size*, *target_sequence_length*, *hidden_size*) representing the activations of the last decoder hidden state
/// * `encoder_hidden_states` - `Tensor` of shape (*batch size*, *source_sequence_length*, *hidden_size*) representing the activations of the last encoder hidden state
/// * `decoder_cache` - `(Option<Tensor>, Option<Vec<&LayerState, &LayerState>>)` of length *n_layer* containing the encoder padding mask and past keys and values for
/// both the self attention and the encoder cross attention of each layer of the decoder.
/// * `all_encoder_hidden_states` - `Option<Vec<Tensor>>` of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*)
/// * `all_encoder_attentions` - `Option<Vec<Tensor>>` of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*)
/// * `all_decoder_hidden_states` - `Option<Vec<Tensor>>` of length *num_decoder_layers* with shape (*batch size*, *target_sequence_length*, *hidden_size*)
/// * `all_decoder_attentions` - `Option<Vec<Tensor>>` of length *num_decoder_layers* with shape (*batch size*, *target_sequence_length*, *hidden_size*)
/// * `BartModelOutput` containing:
/// - `decoder_output` - `Tensor` of shape (*batch size*, *target_sequence_length*, *hidden_size*) representing the activations of the last decoder hidden state
/// - `encoder_hidden_states` - `Tensor` of shape (*batch size*, *source_sequence_length*, *hidden_size*) representing the activations of the last encoder hidden state
/// - `cache` - `(Option<Tensor>, Option<Vec<&LayerState, &LayerState>>)` of length *n_layer* containing the encoder padding mask and past keys and values for both the self attention and the encoder cross attention of each layer of the decoder.
/// - `all_encoder_hidden_states` - `Option<Vec<Tensor>>` of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*)
/// - `all_encoder_attentions` - `Option<Vec<Tensor>>` of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*)
/// - `all_decoder_hidden_states` - `Option<Vec<Tensor>>` of length *num_decoder_layers* with shape (*batch size*, *target_sequence_length*, *hidden_size*)
/// - `all_decoder_attentions` - `Option<Vec<Tensor>>` of length *num_decoder_layers* with shape (*batch size*, *target_sequence_length*, *hidden_size*)
///
/// # Example
///
@ -346,15 +346,7 @@ impl BartModel {
/// let decoder_attention_mask =
/// Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
///
/// let (
/// decoder_output,
/// encoder_hidden_states,
/// decoder_cache,
/// all_encoder_hidden_states,
/// all_encoder_attentions,
/// all_decoder_hidden_states,
/// all_decoder_attentions,
/// ) = no_grad(|| {
/// let model_output = no_grad(|| {
/// bart_model.forward_t(
/// Some(&input_tensor),
/// Some(&encoder_attention_mask),
@ -371,19 +363,11 @@ impl BartModel {
input_ids: Option<&Tensor>,
attention_mask: Option<&Tensor>,
decoder_input_ids: Option<&Tensor>,
encoder_outputs: Option<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>)>,
encoder_output: Option<BartEncoderOutput>,
decoder_attention_mask: Option<&Tensor>,
layer_states: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
train: bool,
) -> (
Tensor,
Tensor,
Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
) {
) -> BartModelOutput {
let (decoder_input_ids, decoder_padding_mask, causal_mask) = if self.generation_mode {
(decoder_input_ids.unwrap().copy(), None, None)
} else {
@ -398,43 +382,37 @@ impl BartModel {
decoder_attention_mask,
)
};
let (encoder_hidden_states, all_encoder_hidden_states, all_encoder_attentions) =
match encoder_outputs {
Some(value) => value,
None => {
assert!(
input_ids.is_some(),
"input_ids must be provided when encoder output is not pre-computed"
);
self.encoder.forward_t(
input_ids.unwrap(),
attention_mask,
&self.embeddings,
train,
)
}
};
let encoder_output = match encoder_output {
Some(value) => value,
None => {
assert!(
input_ids.is_some(),
"input_ids must be provided when encoder output is not pre-computed"
);
self.encoder
.forward_t(input_ids.unwrap(), attention_mask, &self.embeddings, train)
}
};
let (decoder_outputs, decoder_cache, all_decoder_hidden_states, all_decoder_attentions) =
self.decoder.forward_t(
&decoder_input_ids,
&encoder_hidden_states,
attention_mask,
decoder_padding_mask.as_ref(),
causal_mask.as_ref(),
&self.embeddings,
layer_states,
train,
);
(
decoder_outputs,
encoder_hidden_states,
decoder_cache.1,
all_decoder_hidden_states,
all_decoder_attentions,
all_encoder_hidden_states,
all_encoder_attentions,
)
let decoder_output = self.decoder.forward_t(
&decoder_input_ids,
&encoder_output.hidden_state,
attention_mask,
decoder_padding_mask.as_ref(),
causal_mask.as_ref(),
&self.embeddings,
layer_states,
train,
);
BartModelOutput {
decoder_output: decoder_output.hidden_state,
encoder_hidden_state: encoder_output.hidden_state,
cache: decoder_output.next_decoder_cache,
all_decoder_hidden_states: decoder_output.all_hidden_states,
all_decoder_attentions: decoder_output.all_attentions,
all_encoder_hidden_states: encoder_output.all_hidden_states,
all_encoder_attentions: encoder_output.all_attentions,
}
}
}
@ -498,12 +476,14 @@ impl BartForConditionalGeneration {
///
/// # Returns
///
/// * `lm_logits` - `Tensor` of shape (*batch size*, *target_sequence_length*, *vocab_size*) representing the logits for each vocab item and position
/// * `encoder_hidden_states` - `Tensor` of shape (*batch size*, *source_sequence_length*, *hidden_size*) representing the activations of the last encoder hidden state
/// * `all_encoder_hidden_states` - `Option<Vec<Tensor>>` of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*)
/// * `all_encoder_attentions` - `Option<Vec<Tensor>>` of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*)
/// * `all_decoder_hidden_states` - `Option<Vec<Tensor>>` of length *num_decoder_layers* with shape (*batch size*, *target_sequence_length*, *hidden_size*)
/// * `all_decoder_attentions` - `Option<Vec<Tensor>>` of length *num_decoder_layers* with shape (*batch size*, *target_sequence_length*, *hidden_size*)
/// * `BartModelOutput` containing:
/// - `decoder_output` - `Tensor` of shape (*batch size*, *target_sequence_length*, *vocab_size*) representing the logits for each vocabulary item and position
/// - `encoder_hidden_states` - `Tensor` of shape (*batch size*, *source_sequence_length*, *hidden_size*) representing the activations of the last encoder hidden state
/// - `cache` - `(Option<Tensor>, Option<Vec<&LayerState, &LayerState>>)` of length *n_layer* containing the encoder padding mask and past keys and values for both the self attention and the encoder cross attention of each layer of the decoder.
/// - `all_encoder_hidden_states` - `Option<Vec<Tensor>>` of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*)
/// - `all_encoder_attentions` - `Option<Vec<Tensor>>` of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*)
/// - `all_decoder_hidden_states` - `Option<Vec<Tensor>>` of length *num_decoder_layers* with shape (*batch size*, *target_sequence_length*, *hidden_size*)
/// - `all_decoder_attentions` - `Option<Vec<Tensor>>` of length *num_decoder_layers* with shape (*batch size*, *target_sequence_length*, *hidden_size*)
///
/// # Example
///
@ -525,9 +505,7 @@ impl BartForConditionalGeneration {
/// let encoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
/// let decoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
///
/// let (decoder_output, encoder_hidden_states, cache,
/// all_encoder_hidden_states, all_encoder_attentions,
/// all_decoder_hidden_states, all_decoder_attentions) = no_grad(|| {
/// let model_output = no_grad(|| {
/// bart_model
/// .forward_t(Some(&input_tensor),
/// Some(&encoder_attention_mask),
@ -542,58 +520,41 @@ impl BartForConditionalGeneration {
&self,
input_ids: Option<&Tensor>,
attention_mask: Option<&Tensor>,
encoder_outputs: Option<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>)>,
encoder_output: Option<BartEncoderOutput>,
decoder_input_ids: Option<&Tensor>,
decoder_attention_mask: Option<&Tensor>,
old_layer_states: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
train: bool,
) -> (
Tensor,
Tensor,
Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
) {
let (
decoder_outputs,
encoder_hidden_states,
decoder_cache,
all_decoder_hidden_states,
all_decoder_attentions,
all_encoder_hidden_states,
all_encoder_attentions,
) = self.base_model.forward_t(
) -> BartModelOutput {
let base_model_output = self.base_model.forward_t(
input_ids,
attention_mask,
decoder_input_ids,
encoder_outputs,
encoder_output,
decoder_attention_mask,
old_layer_states,
train,
);
let lm_logits = decoder_outputs.linear::<Tensor>(&self.base_model.embeddings.ws, None);
(
lm_logits,
encoder_hidden_states,
decoder_cache,
all_decoder_hidden_states,
all_decoder_attentions,
all_encoder_hidden_states,
all_encoder_attentions,
)
let lm_logits = base_model_output
.decoder_output
.linear::<Tensor>(&self.base_model.embeddings.ws, None);
BartModelOutput {
decoder_output: lm_logits,
..base_model_output
}
}
pub fn encode(&self, input_ids: &Tensor, attention_mask: Option<&Tensor>) -> Tensor {
let (encoder_hidden_states, _, _) = self.base_model.encoder.forward_t(
input_ids,
attention_mask,
&self.base_model.embeddings,
false,
);
encoder_hidden_states
self.base_model
.encoder
.forward_t(
input_ids,
attention_mask,
&self.base_model.embeddings,
false,
)
.hidden_state
}
}
@ -713,12 +674,14 @@ impl BartForSequenceClassification {
///
/// # Returns
///
/// * `logits` - `Tensor` of shape (*batch size*, *num_classes*) representing the logits for each class item and batch item
/// * `encoder_hidden_states` - `Tensor` of shape (*batch size*, *source_sequence_length*, *hidden_size*) representing the activations of the last encoder hidden state
/// * `all_encoder_hidden_states` - `Option<Vec<Tensor>>` of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*)
/// * `all_encoder_attentions` - `Option<Vec<Tensor>>` of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*)
/// * `all_decoder_hidden_states` - `Option<Vec<Tensor>>` of length *num_decoder_layers* with shape (*batch size*, *target_sequence_length*, *hidden_size*)
/// * `all_decoder_attentions` - `Option<Vec<Tensor>>` of length *num_decoder_layers* with shape (*batch size*, *target_sequence_length*, *hidden_size*)
/// * `BartModelOutput` containing:
/// - `decoder_output` - `Tensor` of shape (*batch size*, *num_classes*) representing the activations for each class and batch item
/// - `encoder_hidden_states` - `Tensor` of shape (*batch size*, *source_sequence_length*, *hidden_size*) representing the activations of the last encoder hidden state
/// - `cache` - `(Option<Tensor>, Option<Vec<&LayerState, &LayerState>>)` of length *n_layer* containing the encoder padding mask and past keys and values for both the self attention and the encoder cross attention of each layer of the decoder.
/// - `all_encoder_hidden_states` - `Option<Vec<Tensor>>` of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*)
/// - `all_encoder_attentions` - `Option<Vec<Tensor>>` of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*)
/// - `all_decoder_hidden_states` - `Option<Vec<Tensor>>` of length *num_decoder_layers* with shape (*batch size*, *target_sequence_length*, *hidden_size*)
/// - `all_decoder_attentions` - `Option<Vec<Tensor>>` of length *num_decoder_layers* with shape (*batch size*, *target_sequence_length*, *hidden_size*)
///
/// # Example
///
@ -740,9 +703,7 @@ impl BartForSequenceClassification {
/// let encoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
/// let decoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
///
/// let (decoder_output, encoder_hidden_states, cache,
/// all_encoder_hidden_states, all_encoder_attentions,
/// all_decoder_hidden_states, all_decoder_attentions) = no_grad(|| {
/// let model_output = no_grad(|| {
/// bart_model
/// .forward_t(Some(&input_tensor),
/// Some(&encoder_attention_mask),
@ -757,60 +718,47 @@ impl BartForSequenceClassification {
&self,
input_ids: &Tensor,
attention_mask: Option<&Tensor>,
encoder_outputs: Option<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>)>,
encoder_output: Option<BartEncoderOutput>,
decoder_input_ids: Option<&Tensor>,
decoder_attention_mask: Option<&Tensor>,
train: bool,
) -> (
Tensor,
Tensor,
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
) {
let (
decoder_outputs,
encoder_hidden_states,
_,
all_decoder_hidden_states,
all_decoder_attentions,
all_encoder_hidden_states,
all_encoder_attentions,
) = self.base_model.forward_t(
) -> BartModelOutput {
let base_model_output = self.base_model.forward_t(
Some(input_ids),
attention_mask,
decoder_input_ids,
encoder_outputs,
encoder_output,
decoder_attention_mask,
None,
train,
);
let eos_mask = input_ids.eq(self.eos_token_id);
let reshape = eos_mask.sum1(&[1], true, Int64);
let sentence_representation = decoder_outputs
let sentence_representation = base_model_output
.decoder_output
.permute(&[2, 0, 1])
.masked_select(&eos_mask)
.view((-1, reshape.size()[0] * reshape.int64_value(&[0, 0])))
.transpose(0, 1)
.view((
decoder_outputs.size()[0],
base_model_output.decoder_output.size()[0],
-1,
*decoder_outputs.size().last().unwrap(),
*base_model_output.decoder_output.size().last().unwrap(),
))
.select(1, -1);
let logits = self
.classification_head
.forward_t(&sentence_representation, train);
(
logits,
encoder_hidden_states,
all_decoder_hidden_states,
all_decoder_attentions,
all_encoder_hidden_states,
all_encoder_attentions,
)
BartModelOutput {
decoder_output: logits,
encoder_hidden_state: base_model_output.encoder_hidden_state,
cache: None,
all_decoder_hidden_states: base_model_output.all_decoder_hidden_states,
all_decoder_attentions: base_model_output.all_decoder_attentions,
all_encoder_hidden_states: base_model_output.all_encoder_hidden_states,
all_encoder_attentions: base_model_output.all_encoder_attentions,
}
}
}
@ -832,12 +780,13 @@ impl LMHeadModel for BartForConditionalGeneration {
///
/// # Returns
///
/// * `lm_logits` - `Tensor` of shape (*batch size*, *sequence_length*, *vocab_size*) representing the logits for each vocab item and position
/// * `past` - `BartCache` made of `Option<Vec<(Option<Vec<&LayerState, &LayerState>>)>>` of length *n_layer* containing the encoder past keys and values for
/// * `LMModelOutput` containing:
/// - `lm_logits` - `Tensor` of shape (*batch size*, *sequence_length*, *vocab_size*) representing the logits for each vocab item and position
/// - `cache` - `BartCache` made of `Option<Vec<(Option<Vec<&LayerState, &LayerState>>)>>` of length *n_layer* containing the encoder past keys and values for
/// both the self attention and the encoder cross attention of each layer of the decoder.
/// * `encoder_hidden_states` - `Option<Tensor>` Hidden states for the encoder
/// * `hidden_states` - None
/// * `attentions` - None
/// - `encoder_hidden_states` - `Option<Tensor>` Hidden states for the encoder
/// - `all_hidden_states` - None
/// - `all_attentions` - None
///
/// # Example
///
@ -860,9 +809,7 @@ impl LMHeadModel for BartForConditionalGeneration {
/// let encoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
/// let decoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
///
/// let (decoder_output, encoder_hidden_states, cache,
/// all_encoder_hidden_states, all_encoder_attentions,
/// all_decoder_hidden_states, all_decoder_attentions) = no_grad(|| {
/// let model_output = no_grad(|| {
/// bart_model
/// .forward_t(Some(&input_tensor),
/// Some(&encoder_attention_mask),
@ -884,22 +831,17 @@ impl LMHeadModel for BartForConditionalGeneration {
encoder_outputs: Option<&Tensor>,
decoder_input_ids: &Option<Tensor>,
train: bool,
) -> Result<
(
Tensor,
Option<Tensor>,
Cache,
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
),
&'static str,
> {
let (decoder_output, encoder_hidden_states, new_cache, _, _, _, _) = match cache {
) -> Result<LMModelOutput, RustBertError> {
let base_model_output = match cache {
Cache::BARTCache(cached_layer_states) => self.base_model.forward_t(
input_ids.as_ref(),
attention_mask.as_ref(),
decoder_input_ids.as_ref(),
Some((encoder_outputs.as_ref().unwrap().copy(), None, None)),
Some(BartEncoderOutput {
hidden_state: encoder_outputs.as_ref().unwrap().copy(),
all_hidden_states: None,
all_attentions: None,
}),
None,
cached_layer_states,
train,
@ -909,21 +851,52 @@ impl LMHeadModel for BartForConditionalGeneration {
input_ids.as_ref(),
attention_mask.as_ref(),
decoder_input_ids.as_ref(),
Some((encoder_outputs.as_ref().unwrap().copy(), None, None)),
Some(BartEncoderOutput {
hidden_state: encoder_outputs.as_ref().unwrap().copy(),
all_hidden_states: None,
all_attentions: None,
}),
None,
None,
train,
),
_ => Err("Cache not compatible with BART Model")?,
_ => {
return Err(RustBertError::ValueError(
"Cache not compatible with BART Model".into(),
));
}
};
let lm_logits = decoder_output.linear::<Tensor>(&self.base_model.embeddings.ws, None);
Ok((
let lm_logits = base_model_output
.decoder_output
.linear::<Tensor>(&self.base_model.embeddings.ws, None);
Ok(LMModelOutput {
lm_logits,
Some(encoder_hidden_states),
Cache::BARTCache(new_cache),
None,
None,
))
encoder_hidden_state: Some(base_model_output.encoder_hidden_state),
cache: Cache::BARTCache(base_model_output.cache),
all_hidden_states: None,
all_attentions: None,
})
}
}
/// Container holding a BART model output. The decoder output may hold the hidden state of
/// the last layer of the decoder, or may hold logits for a custom head module after the
/// decoder (e.g. for classification or language modeling tasks)
pub struct BartModelOutput {
/// Hidden state of the last layer of the decoder, or logits for a custom head
/// module after the decoder (e.g. for classification or language modeling tasks)
pub decoder_output: Tensor,
/// Hidden state for the last layer of the encoder
pub encoder_hidden_state: Tensor,
/// Cached outputs of the model (attention layers keys and values) if the model is used for generation
pub cache: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
/// Hidden states for all layers of the decoder
pub all_decoder_hidden_states: Option<Vec<Tensor>>,
/// Attention weights for all layers of the decoder
pub all_decoder_attentions: Option<Vec<Tensor>>,
/// Hidden states for all layers of the encoder
pub all_encoder_hidden_states: Option<Vec<Tensor>>,
/// Attention weights for all layers of the encoder
pub all_encoder_attentions: Option<Vec<Tensor>>,
}
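The conditional generation head above fills this struct with Rust's struct update syntax (`..base_model_output`), replacing `decoder_output` with the LM logits while keeping every other field. A standalone miniature of that pattern, using a hypothetical `Output` type in place of `BartModelOutput`:

// Hypothetical miniature of the `..base_model_output` pattern used in
// BartForConditionalGeneration::forward_t: swap one field, keep the rest.
struct Output {
    decoder_output: i64,
    encoder_hidden_state: i64,
    cache: Option<i64>,
}

fn with_lm_logits(base: Output, lm_logits: i64) -> Output {
    Output {
        decoder_output: lm_logits,
        ..base
    }
}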

View File

@ -12,7 +12,7 @@
// limitations under the License.
use crate::bart::attention::{LayerState, SelfAttention};
use crate::bart::bart::Activation;
use crate::bart::bart_model::Activation;
use crate::bart::embeddings::{
EmbeddingOption, LearnedPositionalEmbedding, SinusoidalPositionalEmbedding,
};
@ -288,15 +288,7 @@ impl BartDecoder {
embeddings: &nn::Embedding,
old_layer_states: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
train: bool,
) -> (
Tensor,
(
Option<Tensor>,
Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
),
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
) {
) -> BartDecoderOutput {
let encoder_padding_mask = match encoder_padding_mask {
Some(mask) => Some(mask.eq(0).to_kind(Bool)),
None => None,
@ -342,45 +334,54 @@ impl BartDecoder {
};
let encoder_hidden_states = encoder_hidden_states.transpose(0, 1);
let mut attention_weights: Option<Tensor>;
let mut layers = self.layers.iter().enumerate();
loop {
match layers.next() {
Some((layer_idx, layer)) => {
let layer_state = match &next_decoder_cache {
Some(values) => values[layer_idx].to_owned(),
None => (None, None),
};
let temp = layer.forward_t(
&hidden_state,
&encoder_hidden_states,
encoder_padding_mask.as_ref(),
decoder_causal_mask,
decoder_padding_mask,
layer_state,
train,
);
hidden_state = temp.0;
attention_weights = temp.1;
if let Some(hidden_states) = all_hidden_states.borrow_mut() {
hidden_states.push(hidden_state.as_ref().copy().transpose(0, 1));
};
if let Some(attentions) = all_attentions.borrow_mut() {
attentions.push(attention_weights.as_ref().unwrap().copy());
};
if let Some(value) = &mut next_decoder_cache {
value[layer_idx] = temp.2
};
}
None => break,
for (layer_idx, layer) in self.layers.iter().enumerate() {
let layer_state = match &next_decoder_cache {
Some(values) => values[layer_idx].to_owned(),
None => (None, None),
};
let temp = layer.forward_t(
&hidden_state,
&encoder_hidden_states,
encoder_padding_mask.as_ref(),
decoder_causal_mask,
decoder_padding_mask,
layer_state,
train,
);
hidden_state = temp.0;
attention_weights = temp.1;
if let Some(hidden_states) = all_hidden_states.borrow_mut() {
hidden_states.push(hidden_state.as_ref().copy().transpose(0, 1));
};
if let Some(attentions) = all_attentions.borrow_mut() {
attentions.push(attention_weights.as_ref().unwrap().copy());
};
if let Some(value) = &mut next_decoder_cache {
value[layer_idx] = temp.2
};
}
(
hidden_state.transpose(0, 1),
(encoder_padding_mask, next_decoder_cache),
BartDecoderOutput {
hidden_state: hidden_state.transpose(0, 1),
encoder_padding_mask,
next_decoder_cache,
all_hidden_states,
all_attentions,
)
}
}
}
/// Container holding a BART decoder output
pub struct BartDecoderOutput {
/// last decoder layer hidden state
pub hidden_state: Tensor,
/// Padding mask for the encoder positions to attend to
pub encoder_padding_mask: Option<Tensor>,
/// Cached outputs of the model (attention layers keys and values) if the model is used for generation
pub next_decoder_cache: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
/// Hidden states for all intermediate layers
pub all_hidden_states: Option<Vec<Tensor>>,
/// Attention weights for all intermediate layers
pub all_attentions: Option<Vec<Tensor>>,
}

View File

@ -12,7 +12,7 @@
// limitations under the License.
use crate::bart::attention::SelfAttention;
use crate::bart::bart::Activation;
use crate::bart::bart_model::Activation;
use crate::bart::embeddings::{
EmbeddingOption, LearnedPositionalEmbedding, SinusoidalPositionalEmbedding,
};
@ -232,7 +232,7 @@ impl BartEncoder {
attention_mask: Option<&Tensor>,
embeddings: &nn::Embedding,
train: bool,
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
) -> BartEncoderOutput {
let attention_mask = match attention_mask {
Some(mask) => Some(mask.eq(0).to_kind(Bool)),
None => None,
@ -260,33 +260,38 @@ impl BartEncoder {
let mut hidden_state = x.copy();
let mut attention_weights: Option<Tensor>;
let mut layers = self.layers.iter();
loop {
match layers.next() {
Some(layer) => {
if let Some(hidden_states) = all_hidden_states.borrow_mut() {
hidden_states.push(hidden_state.as_ref().copy().transpose(0, 1));
};
for layer in &self.layers {
if let Some(hidden_states) = all_hidden_states.borrow_mut() {
hidden_states.push(hidden_state.as_ref().copy().transpose(0, 1));
};
let temp = layer.forward_t(&hidden_state, attention_mask.as_ref(), train);
hidden_state = temp.0;
attention_weights = temp.1;
if let Some(attentions) = all_attentions.borrow_mut() {
attentions.push(attention_weights.as_ref().unwrap().copy());
};
}
None => break,
let temp = layer.forward_t(&hidden_state, attention_mask.as_ref(), train);
hidden_state = temp.0;
attention_weights = temp.1;
if let Some(attentions) = all_attentions.borrow_mut() {
attentions.push(attention_weights.as_ref().unwrap().copy());
};
}
if let Some(hidden_states) = all_hidden_states.borrow_mut() {
hidden_states.push(hidden_state.as_ref().copy().transpose(0, 1));
};
(
hidden_state.transpose(0, 1),
BartEncoderOutput {
hidden_state: hidden_state.transpose(0, 1),
all_hidden_states,
all_attentions,
)
}
}
}
/// Container holding a BART encoder output
pub struct BartEncoderOutput {
/// Last encoder layer hidden state
pub hidden_state: Tensor,
/// Hidden states for all intermediate layers
pub all_hidden_states: Option<Vec<Tensor>>,
/// Attention weights for all intermediate layers
pub all_attentions: Option<Vec<Tensor>>,
}
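Since `BartEncoderOutput` replaces the former `(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>)` tuple, a pre-computed encoder state is wrapped in the struct before being handed back to `forward_t`, as the `LMHeadModel` implementation above does. A minimal crate-internal sketch (the struct is re-exported as `pub(crate)`); `encoder_hidden_state` is assumed to be a `Tensor` returned by a previous `encode` call:

// Sketch: wrapping a cached encoder hidden state for re-use by the decoder,
// mirroring the LMHeadModel::forward_t implementation above.
let encoder_output = BartEncoderOutput {
    hidden_state: encoder_hidden_state,
    all_hidden_states: None,
    all_attentions: None,
};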

View File

@ -6,7 +6,7 @@
//!
//! # Model set-up and pre-trained weights loading
//!
//! A full working example is provided in `examples/bart.rs`, run with `cargo run --example bart`.
//! A full working example is provided in `examples/bart`, run with `cargo run --example bart`.
//! Alternatively, the summarization capabilities are illustrated in `examples/summarization.rs`, run with `cargo run --example summarization`.
//! All models expect the following resources:
//! - Configuration file expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers)
@ -58,14 +58,16 @@
//! ```
mod attention;
mod bart;
mod bart_model;
mod decoder;
mod embeddings;
mod encoder;
pub use attention::LayerState;
pub use bart::{
pub use bart_model::{
Activation, BartConfig, BartConfigResources, BartForConditionalGeneration,
BartForSequenceClassification, BartMergesResources, BartModel, BartModelResources,
BartVocabResources,
BartForSequenceClassification, BartMergesResources, BartModel, BartModelOutput,
BartModelResources, BartVocabResources,
};
pub(crate) use encoder::BartEncoderOutput;

View File

@ -11,7 +11,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use crate::bert::bert::{Activation, BertConfig};
use crate::bert::bert_model::{Activation, BertConfig};
use crate::common::activations::{_gelu, _mish, _relu};
use crate::common::dropout::Dropout;
use std::borrow::Borrow;
@ -86,7 +86,7 @@ impl BertSelfAttention {
fn flatten(&self, x: Tensor, bs: i64, dim_per_head: i64) -> Tensor {
x.transpose(1, 2)
.contiguous()
.view((bs, -1, &self.num_attention_heads * dim_per_head))
.view((bs, -1, self.num_attention_heads * dim_per_head))
}
pub fn forward_t(

View File

@ -16,7 +16,7 @@ use crate::bert::encoder::{BertEncoder, BertPooler};
use crate::common::activations::{_gelu, _mish, _relu};
use crate::common::dropout::Dropout;
use crate::common::linear::{linear_no_bias, LinearNoBias};
use crate::Config;
use crate::{Config, RustBertError};
use serde::{Deserialize, Serialize};
use std::borrow::Borrow;
use std::collections::HashMap;
@ -199,10 +199,11 @@ impl<T: BertEmbedding> BertModel<T> {
///
/// # Returns
///
/// * `output` - `Tensor` of shape (*batch size*, *sequence_length*, *hidden_size*)
/// * `pooled_output` - `Tensor` of shape (*batch size*, *hidden_size*)
/// * `hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// * `attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// * `BertOutput` containing:
/// - `hidden_state` - `Tensor` of shape (*batch size*, *sequence_length*, *hidden_size*)
/// - `pooled_output` - `Tensor` of shape (*batch size*, *hidden_size*)
/// - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// - `all_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
///
/// # Example
///
@ -224,7 +225,7 @@ impl<T: BertEmbedding> BertModel<T> {
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
/// .expand(&[batch_size, sequence_length], true);
///
/// let (output, pooled_output, all_hidden_states, all_attentions) = no_grad(|| {
/// let model_output = no_grad(|| {
/// bert_model
/// .forward_t(
/// Some(input_tensor),
@ -249,18 +250,22 @@ impl<T: BertEmbedding> BertModel<T> {
encoder_hidden_states: &Option<Tensor>,
encoder_mask: &Option<Tensor>,
train: bool,
) -> Result<(Tensor, Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
) -> Result<BertModelOutput, RustBertError> {
let (input_shape, device) = match &input_ids {
Some(input_value) => match &input_embeds {
Some(_) => {
return Err("Only one of input ids or input embeddings may be set");
return Err(RustBertError::ValueError(
"Only one of input ids or input embeddings may be set".into(),
));
}
None => (input_value.size(), input_value.device()),
},
None => match &input_embeds {
Some(embeds) => (vec![embeds.size()[0], embeds.size()[1]], embeds.device()),
None => {
return Err("At least one of input ids or input embeddings must be set");
return Err(RustBertError::ValueError(
"At least one of input ids or input embeddings must be set".into(),
));
}
},
};
@ -275,7 +280,7 @@ impl<T: BertEmbedding> BertModel<T> {
2 => {
if self.is_decoder {
let seq_ids = Tensor::arange(input_shape[1], (Float, device));
let causal_mask = seq_ids.unsqueeze(0).unsqueeze(0).repeat(&vec![
let causal_mask = seq_ids.unsqueeze(0).unsqueeze(0).repeat(&[
input_shape[0],
input_shape[1],
1,
@ -287,7 +292,9 @@ impl<T: BertEmbedding> BertModel<T> {
}
}
_ => {
return Err("Invalid attention mask dimension, must be 2 or 3");
return Err(RustBertError::ValueError(
"Invalid attention mask dimension, must be 2 or 3".into(),
));
}
};
@ -312,7 +319,9 @@ impl<T: BertEmbedding> BertModel<T> {
2 => Some(encoder_mask.unsqueeze(1).unsqueeze(1)),
3 => Some(encoder_mask.unsqueeze(1)),
_ => {
return Err("Invalid encoder attention mask dimension, must be 2 or 3");
return Err(RustBertError::ValueError(
"Invalid attention mask dimension, must be 2 or 3".into(),
));
}
}
} else {
@ -342,12 +351,12 @@ impl<T: BertEmbedding> BertModel<T> {
let pooled_output = self.pooler.forward(&hidden_state);
Ok((
Ok(BertModelOutput {
hidden_state,
pooled_output,
all_hidden_states,
all_attentions,
))
})
}
}
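Because `forward_t` now returns `Result<BertModelOutput, RustBertError>` rather than a tuple with a `&'static str` error, callers can propagate the typed error with `?`. A minimal sketch, assuming the model and tensors are built as in the doc example above:

use rust_bert::bert::{BertEmbeddings, BertModel};
use rust_bert::RustBertError;
use tch::Tensor;

// Hedged sketch: input validation failures surface as RustBertError::ValueError.
fn pooled_representation(
    bert_model: &BertModel<BertEmbeddings>,
    input_tensor: Tensor,
    mask: Tensor,
) -> Result<Tensor, RustBertError> {
    let model_output = bert_model.forward_t(
        Some(input_tensor),
        Some(mask),
        None,
        None,
        None,
        &None,
        &None,
        false,
    )?;
    Ok(model_output.pooled_output) // (batch size, hidden size)
}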
@ -486,9 +495,10 @@ impl BertForMaskedLM {
///
/// # Returns
///
/// * `output` - `Tensor` of shape (*batch size*, *num_labels*, *vocab_size*)
/// * `hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// * `attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// * `BertMaskedLMOutput` containing:
/// - `prediction_scores` - `Tensor` of shape (*batch size*, *sequence_length*, *vocab_size*)
/// - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// - `all_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
///
/// # Example
///
@ -510,7 +520,7 @@ impl BertForMaskedLM {
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
/// .expand(&[batch_size, sequence_length], true);
///
/// let (output, all_hidden_states, all_attentions) = no_grad(|| {
/// let model_output = no_grad(|| {
/// bert_model.forward_t(
/// Some(input_tensor),
/// Some(mask),
@ -533,8 +543,8 @@ impl BertForMaskedLM {
encoder_hidden_states: &Option<Tensor>,
encoder_mask: &Option<Tensor>,
train: bool,
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
let (hidden_state, _, all_hidden_states, all_attentions) = self
) -> BertMaskedLMOutput {
let base_model_output = self
.bert
.forward_t(
input_ids,
@ -548,8 +558,12 @@ impl BertForMaskedLM {
)
.unwrap();
let prediction_scores = self.cls.forward(&hidden_state);
(prediction_scores, all_hidden_states, all_attentions)
let prediction_scores = self.cls.forward(&base_model_output.hidden_state);
BertMaskedLMOutput {
prediction_scores,
all_hidden_states: base_model_output.all_hidden_states,
all_attentions: base_model_output.all_attentions,
}
}
}
@ -626,9 +640,10 @@ impl BertForSequenceClassification {
///
/// # Returns
///
/// * `labels` - `Tensor` of shape (*batch size*, *num_labels*)
/// * `hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// * `attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// * `BertSequenceClassificationOutput` containing:
/// - `logits` - `Tensor` of shape (*batch size*, *num_labels*)
/// - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// - `all_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
///
/// # Example
///
@ -650,7 +665,7 @@ impl BertForSequenceClassification {
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
/// .expand(&[batch_size, sequence_length], true);
///
/// let (labels, all_hidden_states, all_attentions) = no_grad(|| {
/// let model_output = no_grad(|| {
/// bert_model.forward_t(
/// Some(input_tensor),
/// Some(mask),
@ -669,8 +684,8 @@ impl BertForSequenceClassification {
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool,
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
let (_, pooled_output, all_hidden_states, all_attentions) = self
) -> BertSequenceClassificationOutput {
let base_model_output = self
.bert
.forward_t(
input_ids,
@ -684,10 +699,15 @@ impl BertForSequenceClassification {
)
.unwrap();
let output = pooled_output
let logits = base_model_output
.pooled_output
.apply_t(&self.dropout, train)
.apply(&self.classifier);
(output, all_hidden_states, all_attentions)
BertSequenceClassificationOutput {
logits,
all_hidden_states: base_model_output.all_hidden_states,
all_attentions: base_model_output.all_attentions,
}
}
}
@ -755,9 +775,10 @@ impl BertForMultipleChoice {
///
/// # Returns
///
/// * `output` - `Tensor` of shape (*1*, *batch size*) containing the logits for each of the alternatives given
/// * `hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// * `attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// * `BertSequenceClassificationOutput` containing:
/// - `logits` - `Tensor` of shape (*1*, *batch size*) containing the logits for each of the alternatives given
/// - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// - `all_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
///
/// # Example
///
@ -779,7 +800,7 @@ impl BertForMultipleChoice {
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
/// .expand(&[num_choices, sequence_length], true);
///
/// let (choices, all_hidden_states, all_attentions) = no_grad(|| {
/// let model_output = no_grad(|| {
/// bert_model.forward_t(
/// input_tensor,
/// Some(mask),
@ -796,7 +817,7 @@ impl BertForMultipleChoice {
token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>,
train: bool,
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
) -> BertSequenceClassificationOutput {
let num_choices = input_ids.size()[1];
let input_ids = input_ids.view((-1, *input_ids.size().last().unwrap()));
@ -813,7 +834,7 @@ impl BertForMultipleChoice {
None => None,
};
let (_, pooled_output, all_hidden_states, all_attentions) = self
let base_model_output = self
.bert
.forward_t(
Some(input_ids),
@ -827,11 +848,16 @@ impl BertForMultipleChoice {
)
.unwrap();
let output = pooled_output
let logits = base_model_output
.pooled_output
.apply_t(&self.dropout, train)
.apply(&self.classifier)
.view((-1, num_choices));
(output, all_hidden_states, all_attentions)
BertSequenceClassificationOutput {
logits,
all_hidden_states: base_model_output.all_hidden_states,
all_attentions: base_model_output.all_attentions,
}
}
}
@ -909,9 +935,10 @@ impl BertForTokenClassification {
///
/// # Returns
///
/// * `output` - `Tensor` of shape (*batch size*, *sequence_length*, *num_labels*) containing the logits for each of the input tokens and classes
/// * `hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// * `attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// * `BertTokenClassificationOutput` containing:
/// - `logits` - `Tensor` of shape (*batch size*, *sequence_length*, *num_labels*) containing the logits for each of the input tokens and classes
/// - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// - `all_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
///
/// # Example
///
@ -933,7 +960,7 @@ impl BertForTokenClassification {
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
/// .expand(&[batch_size, sequence_length], true);
///
/// let (token_labels, all_hidden_states, all_attentions) = no_grad(|| {
/// let model_output = no_grad(|| {
/// bert_model.forward_t(
/// Some(input_tensor),
/// Some(mask),
@ -952,8 +979,8 @@ impl BertForTokenClassification {
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool,
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
let (hidden_state, _, all_hidden_states, all_attentions) = self
) -> BertTokenClassificationOutput {
let base_model_output = self
.bert
.forward_t(
input_ids,
@ -967,10 +994,15 @@ impl BertForTokenClassification {
)
.unwrap();
let sequence_output = hidden_state
let logits = base_model_output
.hidden_state
.apply_t(&self.dropout, train)
.apply(&self.classifier);
(sequence_output, all_hidden_states, all_attentions)
BertTokenClassificationOutput {
logits,
all_hidden_states: base_model_output.all_hidden_states,
all_attentions: base_model_output.all_attentions,
}
}
}
@ -1039,10 +1071,11 @@ impl BertForQuestionAnswering {
///
/// # Returns
///
/// * `start_scores` - `Tensor` of shape (*batch size*, *sequence_length*) containing the logits for start of the answer
/// * `end_scores` - `Tensor` of shape (*batch size*, *sequence_length*) containing the logits for end of the answer
/// * `hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// * `attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// * `BertQuestionAnsweringOutput` containing:
/// - `start_logits` - `Tensor` of shape (*batch size*, *sequence_length*) containing the logits for start of the answer
/// - `end_logits` - `Tensor` of shape (*batch size*, *sequence_length*) containing the logits for end of the answer
/// - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// - `all_attentions` - `Option<Vec<Vec<Tensor>>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
///
/// # Example
///
@ -1064,7 +1097,7 @@ impl BertForQuestionAnswering {
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
/// .expand(&[batch_size, sequence_length], true);
///
/// let (start_scores, end_scores, all_hidden_states, all_attentions) = no_grad(|| {
/// let model_output = no_grad(|| {
/// bert_model.forward_t(
/// Some(input_tensor),
/// Some(mask),
@ -1083,8 +1116,8 @@ impl BertForQuestionAnswering {
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool,
) -> (Tensor, Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
let (hidden_state, _, all_hidden_states, all_attentions) = self
) -> BertQuestionAnsweringOutput {
let base_model_output = self
.bert
.forward_t(
input_ids,
@ -1098,12 +1131,71 @@ impl BertForQuestionAnswering {
)
.unwrap();
let sequence_output = hidden_state.apply(&self.qa_outputs);
let sequence_output = base_model_output.hidden_state.apply(&self.qa_outputs);
let logits = sequence_output.split(1, -1);
let (start_logits, end_logits) = (&logits[0], &logits[1]);
let start_logits = start_logits.squeeze1(-1);
let end_logits = end_logits.squeeze1(-1);
(start_logits, end_logits, all_hidden_states, all_attentions)
BertQuestionAnsweringOutput {
start_logits,
end_logits,
all_hidden_states: base_model_output.all_hidden_states,
all_attentions: base_model_output.all_attentions,
}
}
}
/// Container for the BERT model output.
pub struct BertModelOutput {
/// Last hidden states from the model
pub hidden_state: Tensor,
/// Pooled output (hidden state for the first token)
pub pooled_output: Tensor,
/// Hidden states for all intermediate layers
pub all_hidden_states: Option<Vec<Tensor>>,
/// Attention weights for all intermediate layers
pub all_attentions: Option<Vec<Tensor>>,
}
/// Container for the BERT masked LM model output.
pub struct BertMaskedLMOutput {
/// Logits for the vocabulary items at each sequence position
pub prediction_scores: Tensor,
/// Hidden states for all intermediate layers
pub all_hidden_states: Option<Vec<Tensor>>,
/// Attention weights for all intermediate layers
pub all_attentions: Option<Vec<Tensor>>,
}
/// Container for the BERT sequence classification model output.
pub struct BertSequenceClassificationOutput {
/// Logits for each input (sequence) for each target class
pub logits: Tensor,
/// Hidden states for all intermediate layers
pub all_hidden_states: Option<Vec<Tensor>>,
/// Attention weights for all intermediate layers
pub all_attentions: Option<Vec<Tensor>>,
}
/// Container for the BERT token classification model output.
pub struct BertTokenClassificationOutput {
/// Logits for each sequence item (token) for each target class
pub logits: Tensor,
/// Hidden states for all intermediate layers
pub all_hidden_states: Option<Vec<Tensor>>,
/// Attention weights for all intermediate layers
pub all_attentions: Option<Vec<Tensor>>,
}
/// Container for the BERT question answering model output.
pub struct BertQuestionAnsweringOutput {
/// Logits for the start position of the answer for each token of the input sequences
pub start_logits: Tensor,
/// Logits for the end position of the answer for each token of the input sequences
pub end_logits: Tensor,
/// Hidden states for all intermediate layers
pub all_hidden_states: Option<Vec<Tensor>>,
/// Attention weights for all intermediate layers
pub all_attentions: Option<Vec<Tensor>>,
}
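`BertForQuestionAnswering` now bundles `start_logits` and `end_logits` in `BertQuestionAnsweringOutput`. A minimal sketch of a naive argmax span extraction for the first sequence, assuming `qa_model` and the input tensors follow the doc example above:

// Illustrative only; a real pipeline also checks start <= end and scores spans jointly.
let model_output = no_grad(|| {
    qa_model.forward_t(Some(input_tensor), Some(mask), None, None, None, false)
});
let start = model_output.start_logits.get(0).argmax(0, false).int64_value(&[]);
let end = model_output.end_logits.get(0).argmax(0, false).int64_value(&[]);
println!("predicted answer span (token indices): {} to {}", start, end);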

View File

@ -11,8 +11,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use crate::bert::bert::BertConfig;
use crate::bert::bert_model::BertConfig;
use crate::common::dropout::Dropout;
use crate::RustBertError;
use std::borrow::Borrow;
use tch::nn::{embedding, EmbeddingConfig};
use tch::{nn, Kind, Tensor};
@ -31,7 +32,7 @@ pub trait BertEmbedding {
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool,
) -> Result<Tensor, &'static str>;
) -> Result<Tensor, RustBertError>;
}
#[derive(Debug)]
@ -168,11 +169,13 @@ impl BertEmbedding for BertEmbeddings {
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool,
) -> Result<Tensor, &'static str> {
) -> Result<Tensor, RustBertError> {
let (input_embeddings, input_shape) = match input_ids {
Some(input_value) => match input_embeds {
Some(_) => {
return Err("Only one of input ids or input embeddings may be set");
return Err(RustBertError::ValueError(
"Only one of input ids or input embeddings may be set".into(),
));
}
None => (
input_value.apply_t(&self.word_embeddings, train),
@ -185,7 +188,9 @@ impl BertEmbedding for BertEmbeddings {
(embeds, size)
}
None => {
return Err("Only one of input ids or input embeddings may be set");
return Err(RustBertError::ValueError(
"At least one of input ids or input embeddings must be set".into(),
));
}
},
};

View File

@ -12,7 +12,7 @@
// limitations under the License.
use crate::bert::attention::{BertAttention, BertIntermediate, BertOutput};
use crate::bert::bert::BertConfig;
use crate::bert::bert_model::BertConfig;
use std::borrow::{Borrow, BorrowMut};
use tch::{nn, Tensor};
@ -34,7 +34,7 @@ impl BertLayer {
let attention = BertAttention::new(p / "attention", &config);
let (is_decoder, cross_attention) = match config.is_decoder {
Some(value) => {
if value == true {
if value {
(
value,
Some(BertAttention::new(p / "cross_attention", &config)),
@ -150,28 +150,23 @@ impl BertEncoder {
let mut hidden_state = hidden_states.copy();
let mut attention_weights: Option<Tensor>;
let mut layers = self.layers.iter();
loop {
match layers.next() {
Some(layer) => {
if let Some(hidden_states) = all_hidden_states.borrow_mut() {
hidden_states.push(hidden_state.as_ref().copy());
};
let temp = layer.forward_t(
&hidden_state,
&mask,
encoder_hidden_states,
encoder_mask,
train,
);
hidden_state = temp.0;
attention_weights = temp.1;
if let Some(attentions) = all_attentions.borrow_mut() {
attentions.push(attention_weights.as_ref().unwrap().copy());
};
}
None => break,
for layer in &self.layers {
if let Some(hidden_states) = all_hidden_states.borrow_mut() {
hidden_states.push(hidden_state.as_ref().copy());
};
let temp = layer.forward_t(
&hidden_state,
&mask,
encoder_hidden_states,
encoder_mask,
train,
);
hidden_state = temp.0;
attention_weights = temp.1;
if let Some(attentions) = all_attentions.borrow_mut() {
attentions.push(attention_weights.as_ref().unwrap().copy());
};
}

View File

@ -10,7 +10,7 @@
//!
//! # Model set-up and pre-trained weights loading
//!
//! A full working example is provided in `examples/bert.rs`, run with `cargo run --example bert`.
//! A full working example is provided in `examples/bert`, run with `cargo run --example bert`.
//! The example below illustrate a Masked language model example, the structure is similar for other models.
//! All models expect the following resources:
//! - Configuration file expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers)
@ -53,13 +53,15 @@
//! ```
mod attention;
mod bert;
mod bert_model;
mod embeddings;
pub(crate) mod encoder;
pub use bert::{
pub use bert_model::{
Activation, BertConfig, BertConfigResources, BertForMaskedLM, BertForMultipleChoice,
BertForQuestionAnswering, BertForSequenceClassification, BertForTokenClassification, BertModel,
BertModelResources, BertVocabResources,
BertForQuestionAnswering, BertForSequenceClassification, BertForTokenClassification,
BertMaskedLMOutput, BertModel, BertModelOutput, BertModelResources,
BertQuestionAnsweringOutput, BertSequenceClassificationOutput, BertTokenClassificationOutput,
BertVocabResources,
};
pub use embeddings::{BertEmbedding, BertEmbeddings};

View File

@ -18,7 +18,7 @@
//! pre-trained models in each model module.
use crate::common::error::RustBertError;
use cached_path::Cache;
use cached_path::{Cache, Options, ProgressBar};
use lazy_static::lazy_static;
use std::env;
use std::path::PathBuf;
@ -59,8 +59,10 @@ impl Resource {
match self {
Resource::Local(resource) => Ok(resource.local_path.clone()),
Resource::Remote(resource) => {
let cached_path =
CACHE.cached_path_in_subdir(&resource.url, Some(&resource.cache_subdir))?;
let cached_path = CACHE.cached_path_with_options(
&resource.url,
&Options::default().subdir(&resource.cache_subdir),
)?;
Ok(cached_path)
}
}
@ -148,11 +150,14 @@ lazy_static! {
/// # Global cache directory
/// If the environment variable `RUSTBERT_CACHE` is set, will save the cache model files at that
/// location. Otherwise defaults to `~/.cache/.rustbert`.
pub static ref CACHE: Cache = Cache::builder().dir(_get_cache_directory()).build().unwrap();
pub static ref CACHE: Cache = Cache::builder()
.dir(_get_cache_directory())
.progress_bar(Some(ProgressBar::Light))
.build().unwrap();
}
fn _get_cache_directory() -> PathBuf {
let home = match env::var("RUSTBERT_CACHE") {
match env::var("RUSTBERT_CACHE") {
Ok(value) => PathBuf::from(value),
Err(_) => {
let mut home = dirs::home_dir().unwrap();
@ -160,8 +165,7 @@ fn _get_cache_directory() -> PathBuf {
home.push(".rustbert");
home
}
};
home
}
}
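Cached downloads now go through `cached_path_with_options`, storing each resource under its own sub-directory and showing a light progress bar; the cache root still resolves to `RUSTBERT_CACHE` when set, otherwise `~/.cache/.rustbert`. A minimal sketch of redirecting the cache before the lazy `CACHE` is first touched (the `from_pretrained` and `get_local_path` calls are assumed from the crate's resource API):

use rust_bert::bert::BertConfigResources;
use rust_bert::resources::{RemoteResource, Resource};

fn main() -> anyhow::Result<()> {
    // Must be set before any resource is resolved, since CACHE is a lazy_static.
    std::env::set_var("RUSTBERT_CACHE", "/tmp/rustbert-cache");
    let config_resource =
        Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT));
    // First access downloads into <cache root>/<cache_subdir>/ with a progress bar.
    let config_path = config_resource.get_local_path()?;
    println!("configuration cached at {:?}", config_path);
    Ok(())
}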
#[deprecated(

View File

@ -11,7 +11,7 @@
// limitations under the License.
use crate::common::dropout::Dropout;
use crate::distilbert::distilbert::DistilBertConfig;
use crate::distilbert::distilbert_model::DistilBertConfig;
use std::borrow::Borrow;
use tch::kind::Kind::Float;
use tch::{nn, Tensor};
@ -65,7 +65,7 @@ impl MultiHeadSelfAttention {
fn flatten(&self, x: Tensor, bs: i64, dim_per_head: i64) -> Tensor {
x.transpose(1, 2)
.contiguous()
.view((bs, -1, &self.n_heads * dim_per_head))
.view((bs, -1, self.n_heads * dim_per_head))
}
pub fn forward_t(

View File

@ -15,8 +15,8 @@ extern crate tch;
use self::tch::{nn, Tensor};
use crate::common::dropout::Dropout;
use crate::distilbert::embeddings::DistilBertEmbedding;
use crate::distilbert::transformer::Transformer;
use crate::Config;
use crate::distilbert::transformer::{DistilBertTransformerOutput, Transformer};
use crate::{Config, RustBertError};
use serde::{Deserialize, Serialize};
use std::{borrow::Borrow, collections::HashMap};
@ -180,9 +180,10 @@ impl DistilBertModel {
///
/// # Returns
///
/// * `output` - `Tensor` of shape (*batch size*, *sequence_length*, *hidden_size*) representing the activations of the last hidden state
/// * `hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// * `attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// * `DistilBertTransformerOutput` containing:
/// - `hidden_state` - `Tensor` of shape (*batch size*, *sequence_length*, *hidden_size*)
/// - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// - `all_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
///
/// # Example
///
@ -202,7 +203,7 @@ impl DistilBertModel {
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
///
/// let (output, all_hidden_states, all_attentions) = no_grad(|| {
/// let model_output = no_grad(|| {
/// distilbert_model
/// .forward_t(Some(input_tensor), Some(mask), None, false)
/// .unwrap()
@ -214,18 +215,22 @@ impl DistilBertModel {
mask: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool,
) -> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
) -> Result<DistilBertTransformerOutput, RustBertError> {
let input_embeddings = match input {
Some(input_value) => match input_embeds {
Some(_) => {
return Err("Only one of input ids or input embeddings may be set");
return Err(RustBertError::ValueError(
"Only one of input ids or input embeddings may be set".into(),
));
}
None => input_value.apply_t(&self.embeddings, train),
},
None => match input_embeds {
Some(embeds) => embeds,
None => {
return Err("At least one of input ids or input embeddings must be set");
return Err(RustBertError::ValueError(
"At least one of input ids or input embeddings must be set".into(),
));
}
},
};
@ -313,9 +318,10 @@ impl DistilBertModelClassifier {
///
/// # Returns
///
/// * `output` - `Tensor` of shape (*batch size*, *num_labels*) representing the logits for each class to predict
/// * `hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// * `attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// * `DistilBertSequenceClassificationOutput` containing:
/// - `logits` - `Tensor` of shape (*batch size*, *num_labels*)
/// - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// - `all_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
///
/// # Example
///
@ -335,7 +341,7 @@ impl DistilBertModelClassifier {
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
///
/// let (output, all_hidden_states, all_attentions) = no_grad(|| {
/// let model_output = no_grad(|| {
/// distilbert_model
/// .forward_t(Some(input_tensor),
/// Some(mask),
@ -349,24 +355,24 @@ impl DistilBertModelClassifier {
mask: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool,
) -> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
let (output, all_hidden_states, all_attentions) =
match self
.distil_bert_model
.forward_t(input, mask, input_embeds, train)
{
Ok(value) => value,
Err(err) => return Err(err),
};
) -> Result<DistilBertSequenceClassificationOutput, RustBertError> {
let base_model_output =
self.distil_bert_model
.forward_t(input, mask, input_embeds, train)?;
let output = output
let logits = base_model_output
.hidden_state
.select(1, 0)
.apply(&self.pre_classifier)
.relu()
.apply_t(&self.dropout, train)
.apply(&self.classifier);
Ok((output, all_hidden_states, all_attentions))
Ok(DistilBertSequenceClassificationOutput {
logits,
all_hidden_states: base_model_output.all_hidden_states,
all_attentions: base_model_output.all_attentions,
})
}
}
@ -451,9 +457,10 @@ impl DistilBertModelMaskedLM {
///
/// # Returns
///
/// * `output` - `Tensor` of shape (*batch size*, *sequence_length*, *vocab_size*) representing the logits for position and vocabulary index
/// * `hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// * `attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// * `DistilBertMaskedLMOutput` containing:
/// - `prediction_scores` - `Tensor` of shape (*batch size*, *sequence_length*, *vocab_size*)
/// - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// - `all_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
///
/// # Example
///
@ -473,7 +480,7 @@ impl DistilBertModelMaskedLM {
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
///
/// let (output, all_hidden_states, all_attentions) = no_grad(|| {
/// let model_output = no_grad(|| {
/// distilbert_model
/// .forward_t(Some(input_tensor), Some(mask), None, false)
/// .unwrap()
@ -485,23 +492,23 @@ impl DistilBertModelMaskedLM {
mask: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool,
) -> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
let (output, all_hidden_states, all_attentions) =
match self
.distil_bert_model
.forward_t(input, mask, input_embeds, train)
{
Ok(value) => value,
Err(err) => return Err(err),
};
) -> Result<DistilBertMaskedLMOutput, RustBertError> {
let base_model_output =
self.distil_bert_model
.forward_t(input, mask, input_embeds, train)?;
let output = output
let prediction_scores = base_model_output
.hidden_state
.apply(&self.vocab_transform)
.gelu()
.apply(&self.vocab_layer_norm)
.apply(&self.vocab_projector);
Ok((output, all_hidden_states, all_attentions))
Ok(DistilBertMaskedLMOutput {
prediction_scores,
all_hidden_states: base_model_output.all_hidden_states,
all_attentions: base_model_output.all_attentions,
})
}
}
@ -568,10 +575,11 @@ impl DistilBertForQuestionAnswering {
///
/// # Returns
///
/// * `start_scores` - `Tensor` of shape (*batch size*, *sequence_length*) containing the logits for start of the answer
/// * `end_scores` - `Tensor` of shape (*batch size*, *sequence_length*) containing the logits for end of the answer
/// * `hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// * `attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// * `DistilBertQuestionAnsweringOutput` containing:
/// - `start_logits` - `Tensor` of shape (*batch size*, *sequence_length*) containing the logits for start of the answer
/// - `end_logits` - `Tensor` of shape (*batch size*, *sequence_length*) containing the logits for end of the answer
/// - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// - `all_attentions` - `Option<Vec<Vec<Tensor>>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
///
/// # Example
///
@ -591,7 +599,7 @@ impl DistilBertForQuestionAnswering {
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
///
/// let (start_scores, end_score, all_hidden_states, all_attentions) = no_grad(|| {
/// let model_output = no_grad(|| {
/// distilbert_model
/// .forward_t(Some(input_tensor), Some(mask), None, false)
/// .unwrap()
@ -603,24 +611,27 @@ impl DistilBertForQuestionAnswering {
mask: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool,
) -> Result<(Tensor, Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
let (output, all_hidden_states, all_attentions) =
match self
.distil_bert_model
.forward_t(input, mask, input_embeds, train)
{
Ok(value) => value,
Err(err) => return Err(err),
};
) -> Result<DistilBertQuestionAnsweringOutput, RustBertError> {
let base_model_output =
self.distil_bert_model
.forward_t(input, mask, input_embeds, train)?;
let output = output.apply_t(&self.dropout, train).apply(&self.qa_outputs);
let output = base_model_output
.hidden_state
.apply_t(&self.dropout, train)
.apply(&self.qa_outputs);
let logits = output.split(1, -1);
let (start_logits, end_logits) = (&logits[0], &logits[1]);
let start_logits = start_logits.squeeze1(-1);
let end_logits = end_logits.squeeze1(-1);
Ok((start_logits, end_logits, all_hidden_states, all_attentions))
Ok(DistilBertQuestionAnsweringOutput {
start_logits,
end_logits,
all_hidden_states: base_model_output.all_hidden_states,
all_attentions: base_model_output.all_attentions,
})
}
}
@ -693,9 +704,10 @@ impl DistilBertForTokenClassification {
///
/// # Returns
///
/// * `output` - `Tensor` of shape (*batch size*, *sequence_length*, *num_labels*) representing the logits for position and class
/// * `hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// * `attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// * `DistilBertTokenClassificationOutput` containing:
/// - `logits` - `Tensor` of shape (*batch size*, *sequence_length*, *num_labels*) containing the logits for each of the input tokens and classes
/// - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// - `all_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
///
/// # Example
///
@ -715,7 +727,7 @@ impl DistilBertForTokenClassification {
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
///
/// let (output, all_hidden_states, all_attentions) = no_grad(|| {
/// let model_output = no_grad(|| {
/// distilbert_model
/// .forward_t(Some(input_tensor), Some(mask), None, false)
/// .unwrap()
@ -727,18 +739,60 @@ impl DistilBertForTokenClassification {
mask: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool,
) -> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
let (output, all_hidden_states, all_attentions) =
match self
.distil_bert_model
.forward_t(input, mask, input_embeds, train)
{
Ok(value) => value,
Err(err) => return Err(err),
};
) -> Result<DistilBertTokenClassificationOutput, RustBertError> {
let base_model_output =
self.distil_bert_model
.forward_t(input, mask, input_embeds, train)?;
let output = output.apply_t(&self.dropout, train).apply(&self.classifier);
let logits = base_model_output
.hidden_state
.apply_t(&self.dropout, train)
.apply(&self.classifier);
Ok((output, all_hidden_states, all_attentions))
Ok(DistilBertTokenClassificationOutput {
logits,
all_hidden_states: base_model_output.all_hidden_states,
all_attentions: base_model_output.all_attentions,
})
}
}
/// Container for the DistilBERT masked LM model output.
pub struct DistilBertMaskedLMOutput {
/// Logits for the vocabulary items at each sequence position
pub prediction_scores: Tensor,
/// Hidden states for all intermediate layers
pub all_hidden_states: Option<Vec<Tensor>>,
/// Attention weights for all intermediate layers
pub all_attentions: Option<Vec<Tensor>>,
}
/// Container for the DistilBERT sequence classification model output.
pub struct DistilBertSequenceClassificationOutput {
/// Logits for each input (sequence) for each target class
pub logits: Tensor,
/// Hidden states for all intermediate layers
pub all_hidden_states: Option<Vec<Tensor>>,
/// Attention weights for all intermediate layers
pub all_attentions: Option<Vec<Tensor>>,
}
/// Container for the DistilBERT token classification model output.
pub struct DistilBertTokenClassificationOutput {
/// Logits for each sequence item (token) for each target class
pub logits: Tensor,
/// Hidden states for all intermediate layers
pub all_hidden_states: Option<Vec<Tensor>>,
/// Attention weights for all intermediate layers
pub all_attentions: Option<Vec<Tensor>>,
}
/// Container for the DistilBERT question answering model output.
pub struct DistilBertQuestionAnsweringOutput {
/// Logits for the start position of the answer for each token of the input sequences
pub start_logits: Tensor,
/// Logits for the end position of the answer for each token of the input sequences
pub end_logits: Tensor,
/// Hidden states for all intermediate layers
pub all_hidden_states: Option<Vec<Tensor>>,
/// Attention weights for all intermediate layers
pub all_attentions: Option<Vec<Tensor>>,
}
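With the typed error and the named output containers, the explicit match on a `&'static str` error in the classifier head disappears and `?` propagates `RustBertError` instead. A minimal sketch, assuming the model and tensors are built as in the doc example above:

use rust_bert::distilbert::DistilBertModelClassifier;
use rust_bert::RustBertError;
use tch::{Kind, Tensor};

// Hedged sketch: the sequence classification head now yields a typed output struct.
fn classify(
    model: &DistilBertModelClassifier,
    input_tensor: Tensor,
    mask: Tensor,
) -> Result<Tensor, RustBertError> {
    let model_output = model.forward_t(Some(input_tensor), Some(mask), None, false)?;
    // (batch size, num_labels) probabilities derived from the logits field
    Ok(model_output.logits.softmax(-1, Kind::Float))
}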

View File

@ -11,7 +11,7 @@
// limitations under the License.
use crate::common::dropout::Dropout;
use crate::distilbert::distilbert::DistilBertConfig;
use crate::distilbert::distilbert_model::DistilBertConfig;
use std::borrow::Borrow;
use tch::kind::Kind::Float;
use tch::nn::{embedding, EmbeddingConfig, ModuleT};
@ -125,10 +125,8 @@ impl ModuleT for DistilBertEmbedding {
let position_embed = position_ids.apply(&self.position_embeddings);
let embeddings = word_embed + position_embed;
let embeddings = embeddings
.apply(&self.layer_norm)
.apply_t(&self.dropout, train);
embeddings
.apply(&self.layer_norm)
.apply_t(&self.dropout, train)
}
}

View File

@ -55,12 +55,14 @@
//! ```
mod attention;
mod distilbert;
mod distilbert_model;
mod embeddings;
mod transformer;
pub use distilbert::{
pub use distilbert_model::{
Activation, DistilBertConfig, DistilBertConfigResources, DistilBertForQuestionAnswering,
DistilBertForTokenClassification, DistilBertModel, DistilBertModelClassifier,
DistilBertModelMaskedLM, DistilBertModelResources, DistilBertVocabResources,
DistilBertForTokenClassification, DistilBertMaskedLMOutput, DistilBertModel,
DistilBertModelClassifier, DistilBertModelMaskedLM, DistilBertModelResources,
DistilBertQuestionAnsweringOutput, DistilBertSequenceClassificationOutput,
DistilBertTokenClassificationOutput, DistilBertVocabResources,
};

View File

@ -13,7 +13,7 @@
use crate::common::activations::{_gelu, _relu};
use crate::common::dropout::Dropout;
use crate::distilbert::attention::MultiHeadSelfAttention;
use crate::distilbert::distilbert::{Activation, DistilBertConfig};
use crate::distilbert::distilbert_model::{Activation, DistilBertConfig};
use std::borrow::{Borrow, BorrowMut};
use tch::nn::LayerNorm;
use tch::{nn, Tensor};
@ -149,7 +149,7 @@ impl Transformer {
input: &Tensor,
mask: Option<Tensor>,
train: bool,
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
) -> DistilBertTransformerOutput {
let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states {
Some(vec![])
} else {
@ -163,25 +163,34 @@ impl Transformer {
let mut hidden_state = input.copy();
let mut attention_weights: Option<Tensor>;
let mut layers = self.layers.iter();
loop {
match layers.next() {
Some(layer) => {
if let Some(hidden_states) = all_hidden_states.borrow_mut() {
hidden_states.push(hidden_state.as_ref().copy());
};
let temp = layer.forward_t(&hidden_state, &mask, train);
hidden_state = temp.0;
attention_weights = temp.1;
if let Some(attentions) = all_attentions.borrow_mut() {
attentions.push(attention_weights.as_ref().unwrap().copy());
};
}
None => break,
for layer in &self.layers {
if let Some(hidden_states) = all_hidden_states.borrow_mut() {
hidden_states.push(hidden_state.as_ref().copy());
};
let temp = layer.forward_t(&hidden_state, &mask, train);
hidden_state = temp.0;
attention_weights = temp.1;
if let Some(attentions) = all_attentions.borrow_mut() {
attentions.push(attention_weights.as_ref().unwrap().copy());
};
}
(hidden_state, all_hidden_states, all_attentions)
DistilBertTransformerOutput {
hidden_state,
all_hidden_states,
all_attentions,
}
}
}
/// Container for the DistilBERT transformer output.
pub struct DistilBertTransformerOutput {
/// Last hidden states from the model
pub hidden_state: Tensor,
/// Hidden states for all intermediate layers
pub all_hidden_states: Option<Vec<Tensor>>,
/// Attention weights for all intermediate layers
pub all_attentions: Option<Vec<Tensor>>,
}

View File

@ -17,7 +17,7 @@ use crate::bert::{Activation, BertConfig};
use crate::common::activations::{_gelu, _mish, _relu};
use crate::common::dropout::Dropout;
use crate::electra::embeddings::ElectraEmbeddings;
use crate::Config;
use crate::{Config, RustBertError};
use serde::{Deserialize, Serialize};
use std::{borrow::Borrow, collections::HashMap};
use tch::{nn, Kind, Tensor};
@ -188,9 +188,10 @@ impl ElectraModel {
///
/// # Returns
///
/// * `output` - `Tensor` of shape (*batch size*, *sequence_length*, *hidden_size*)
/// * `hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// * `attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// * `ElectraModelOutput` containing:
/// - `hidden_state` - `Tensor` of shape (*batch size*, *sequence_length*, *hidden_size*)
/// - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// - `all_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
///
/// # Example
///
@ -212,7 +213,7 @@ impl ElectraModel {
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
/// .expand(&[batch_size, sequence_length], true);
///
/// let (output, all_hidden_states, all_attentions) = no_grad(|| {
/// let model_output = no_grad(|| {
/// electra_model
/// .forward_t(
/// Some(input_tensor),
@ -233,18 +234,22 @@ impl ElectraModel {
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool,
) -> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
) -> Result<ElectraModelOutput, RustBertError> {
let (input_shape, device) = match &input_ids {
Some(input_value) => match &input_embeds {
Some(_) => {
return Err("Only one of input ids or input embeddings may be set");
return Err(RustBertError::ValueError(
"Only one of input ids or input embeddings may be set".into(),
));
}
None => (input_value.size(), input_value.device()),
},
None => match &input_embeds {
Some(embeds) => (vec![embeds.size()[0], embeds.size()[1]], embeds.device()),
None => {
return Err("At least one of input ids or input embeddings must be set");
return Err(RustBertError::ValueError(
"At least one of input ids or input embeddings must be set".into(),
));
}
},
};
@ -258,7 +263,9 @@ impl ElectraModel {
3 => mask.unsqueeze(1),
2 => mask.unsqueeze(1).unsqueeze(1),
_ => {
return Err("Invalid attention mask dimension, must be 2 or 3");
return Err(RustBertError::ValueError(
"Invalid attention mask dimension, must be 2 or 3".into(),
));
}
};
@ -288,7 +295,11 @@ impl ElectraModel {
train,
);
Ok((hidden_state, all_hidden_states, all_attentions))
Ok(ElectraModelOutput {
hidden_state,
all_hidden_states,
all_attentions,
})
}
}
@ -366,8 +377,6 @@ impl ElectraDiscriminatorHead {
/// # Returns
///
/// * `output` - `Tensor` of shape (*batch size*, *sequence_length*)
/// * `hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// * `attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
///
/// # Example
///
@ -566,9 +575,10 @@ impl ElectraForMaskedLM {
///
/// # Returns
///
/// * `output` - `Tensor` of shape (*batch size*, *sequence_length*, *vocab_size*)
/// * `hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// * `attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// * `ElectraMaskedLMOutput` containing:
/// - `prediction_scores` - `Tensor` of shape (*batch size*, *sequence_length*, *vocab_size*)
/// - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// - `all_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
///
/// # Example
///
@ -590,7 +600,7 @@ impl ElectraForMaskedLM {
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
/// .expand(&[batch_size, sequence_length], true);
///
/// let (output, all_hidden_states, all_attentions) = no_grad(|| {
/// let model_output = no_grad(|| {
/// electra_model.forward_t(
/// Some(input_tensor),
/// Some(mask),
@ -609,8 +619,8 @@ impl ElectraForMaskedLM {
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool,
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
let (hidden_states, all_hidden_states, all_attentions) = self
) -> ElectraMaskedLMOutput {
let base_model_output = self
.electra
.forward_t(
input_ids,
@ -621,9 +631,13 @@ impl ElectraForMaskedLM {
train,
)
.unwrap();
let hidden_states = self.generator_head.forward(&hidden_states);
let hidden_states = hidden_states.apply(&self.lm_head);
(hidden_states, all_hidden_states, all_attentions)
let hidden_states = self.generator_head.forward(&base_model_output.hidden_state);
let prediction_scores = hidden_states.apply(&self.lm_head);
ElectraMaskedLMOutput {
prediction_scores,
all_hidden_states: base_model_output.all_hidden_states,
all_attentions: base_model_output.all_attentions,
}
}
}
@ -689,9 +703,10 @@ impl ElectraDiscriminator {
///
/// # Returns
///
/// * `output` - `Tensor` of shape (*batch size*, *sequence_length*)
/// * `hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// * `attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// * `ElectraDiscriminatorOutput` containing:
/// - `probabilities` - `Tensor` of shape (*batch size*, *sequence_length*) containing the probability for each token to have been generated by a language model
/// - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// - `all_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
///
/// # Example
///
@ -712,7 +727,7 @@ impl ElectraDiscriminator {
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true);
///
/// let (output, all_hidden_states, all_attentions) = no_grad(|| {
/// let model_output = no_grad(|| {
/// electra_model
/// .forward_t(Some(input_tensor),
/// Some(mask),
@ -730,8 +745,8 @@ impl ElectraDiscriminator {
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool,
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
let (hidden_states, all_hidden_states, all_attentions) = self
) -> ElectraDiscriminatorOutput {
let base_model_output = self
.electra
.forward_t(
input_ids,
@ -742,8 +757,15 @@ impl ElectraDiscriminator {
train,
)
.unwrap();
let probabilities = self.discriminator_head.forward(&hidden_states).sigmoid();
(probabilities, all_hidden_states, all_attentions)
let probabilities = self
.discriminator_head
.forward(&base_model_output.hidden_state)
.sigmoid();
ElectraDiscriminatorOutput {
probabilities,
all_hidden_states: base_model_output.all_hidden_states,
all_attentions: base_model_output.all_attentions,
}
}
}
@ -822,9 +844,10 @@ impl ElectraForTokenClassification {
///
/// # Returns
///
/// * `output` - `Tensor` of shape (*batch size*, *sequence_length*, *num_classes*)
/// * `hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// * `attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// * `ElectraTokenClassificationOutput` containing:
/// - `logits` - `Tensor` of shape (*batch size*, *sequence_length*, *num_labels*) containing the logits for each of the input tokens and classes
/// - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// - `all_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
///
/// # Example
///
@ -845,7 +868,7 @@ impl ElectraForTokenClassification {
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true);
///
/// let (output, all_hidden_states, all_attentions) = no_grad(|| {
/// let model_output = no_grad(|| {
/// electra_model
/// .forward_t(Some(input_tensor),
/// Some(mask),
@ -863,8 +886,8 @@ impl ElectraForTokenClassification {
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool,
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
let (hidden_states, all_hidden_states, all_attentions) = self
) -> ElectraTokenClassificationOutput {
let base_model_output = self
.electra
.forward_t(
input_ids,
@ -875,9 +898,54 @@ impl ElectraForTokenClassification {
train,
)
.unwrap();
let output = hidden_states
let logits = base_model_output
.hidden_state
.apply_t(&self.dropout, train)
.apply(&self.classifier);
(output, all_hidden_states, all_attentions)
ElectraTokenClassificationOutput {
logits,
all_hidden_states: base_model_output.all_hidden_states,
all_attentions: base_model_output.all_attentions,
}
}
}
/// Container for the Electra model output.
pub struct ElectraModelOutput {
/// Last hidden states from the model
pub hidden_state: Tensor,
/// Hidden states for all intermediate layers
pub all_hidden_states: Option<Vec<Tensor>>,
/// Attention weights for all intermediate layers
pub all_attentions: Option<Vec<Tensor>>,
}
/// Container for the Electra discriminator model output.
pub struct ElectraDiscriminatorOutput {
/// Probabilities for each sequence item (token) to be generated by a language model
pub probabilities: Tensor,
/// Hidden states for all intermediate layers
pub all_hidden_states: Option<Vec<Tensor>>,
/// Attention weights for all intermediate layers
pub all_attentions: Option<Vec<Tensor>>,
}
/// Container for the Electra masked LM model output.
pub struct ElectraMaskedLMOutput {
/// Logits for the vocabulary items at each sequence position
pub prediction_scores: Tensor,
/// Hidden states for all intermediate layers
pub all_hidden_states: Option<Vec<Tensor>>,
/// Attention weights for all intermediate layers
pub all_attentions: Option<Vec<Tensor>>,
}
/// Container for the Electra token classification model output.
pub struct ElectraTokenClassificationOutput {
/// Logits for each sequence item (token) for each target class
pub logits: Tensor,
/// Hidden states for all intermediate layers
pub all_hidden_states: Option<Vec<Tensor>>,
/// Attention weights for all intermediate layers
pub all_attentions: Option<Vec<Tensor>>,
}
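The discriminator head now returns `ElectraDiscriminatorOutput` with an explicit `probabilities` field (after the sigmoid). A minimal sketch of flagging tokens the discriminator considers generator-produced, assuming `electra_model` and the input tensors follow the doc example above:

// Illustrative threshold only; probabilities has shape (batch size, sequence_length).
let model_output = no_grad(|| {
    electra_model.forward_t(Some(input_tensor), Some(mask), None, None, None, false)
});
let replaced_tokens = model_output.probabilities.ge(0.5);
println!("{:?}", replaced_tokens);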

View File

@ -13,7 +13,8 @@
// limitations under the License.
use crate::common::dropout::Dropout;
use crate::electra::electra::ElectraConfig;
use crate::electra::electra_model::ElectraConfig;
use crate::RustBertError;
use std::borrow::Borrow;
use tch::nn::{embedding, EmbeddingConfig};
use tch::{nn, Kind, Tensor};
@ -91,11 +92,13 @@ impl ElectraEmbeddings {
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool,
) -> Result<Tensor, &'static str> {
) -> Result<Tensor, RustBertError> {
let (input_embeddings, input_shape) = match input_ids {
Some(input_value) => match input_embeds {
Some(_) => {
return Err("Only one of input ids or input embeddings may be set");
return Err(RustBertError::ValueError(
"Only one of input ids or input embeddings may be set".into(),
));
}
None => (
input_value.apply_t(&self.word_embeddings, train),
@ -108,7 +111,9 @@ impl ElectraEmbeddings {
(embeds, size)
}
None => {
return Err("Only one of input ids or input embeddings may be set");
return Err(RustBertError::ValueError(
"At least one of input ids or input embeddings must be set".into(),
));
}
},
};

View File

@ -56,11 +56,12 @@
//! # }
//! ```
mod electra;
mod electra_model;
mod embeddings;
pub use electra::{
pub use electra_model::{
ElectraConfig, ElectraConfigResources, ElectraDiscriminator, ElectraDiscriminatorHead,
ElectraForMaskedLM, ElectraForTokenClassification, ElectraGeneratorHead, ElectraModel,
ElectraModelResources, ElectraVocabResources,
ElectraDiscriminatorOutput, ElectraForMaskedLM, ElectraForTokenClassification,
ElectraGeneratorHead, ElectraMaskedLMOutput, ElectraModel, ElectraModelOutput,
ElectraModelResources, ElectraTokenClassificationOutput, ElectraVocabResources,
};

View File

@ -13,7 +13,7 @@
// limitations under the License.
use crate::common::dropout::Dropout;
use crate::gpt2::gpt2::Gpt2Config;
use crate::gpt2::gpt2_model::Gpt2Config;
use std::borrow::Borrow;
use tch::kind::Kind::Float;
use tch::nn::{Init, Module};
@ -128,20 +128,20 @@ impl Attention {
fn flatten(&self, x: Tensor) -> Tensor {
x.transpose(1, 2)
.contiguous()
.view((x.size()[0], -1, &self.n_head * self.dim_per_head))
.view((x.size()[0], -1, self.n_head * self.dim_per_head))
}
fn attention(
&self,
q: &Tensor,
k: &Tensor,
v: &Tensor,
query: &Tensor,
key: &Tensor,
value: &Tensor,
attention_mask: &Option<Tensor>,
train: bool,
) -> (Tensor, Option<Tensor>) {
let mut w = q.matmul(&k);
let mut w = query.matmul(&key);
if self.scale {
w = w / (*v.size().last().unwrap() as f64).sqrt();
w = w / (*value.size().last().unwrap() as f64).sqrt();
}
let (nd, ns) = (w.size()[2], w.size()[3]);
@ -152,7 +152,7 @@ impl Attention {
w = w + mask;
}
w = w.softmax(-1, Float).apply_t(&self.attn_dropout, train);
let output = w.matmul(&v);
let output = w.matmul(&value);
if self.output_attentions {
(output, Some(w))
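Apart from the clearer parameter names, the helper is still plain scaled dot-product attention: weights = softmax(query · keyᵀ / sqrt(d) + mask), output = weights · value. A self-contained sketch of that step with tch tensors (shapes and dimensions are illustrative only, not taken from this commit):

use tch::{Device, Kind, Tensor};

fn main() {
    let (heads, seq_len, head_dim) = (2i64, 4i64, 8i64);
    let query = Tensor::rand(&[1, heads, seq_len, head_dim], (Kind::Float, Device::Cpu));
    // `key` arrives already transposed in the block above, hence the swapped last dimensions.
    let key = Tensor::rand(&[1, heads, head_dim, seq_len], (Kind::Float, Device::Cpu));
    let value = Tensor::rand(&[1, heads, seq_len, head_dim], (Kind::Float, Device::Cpu));

    let mut weights = query.matmul(&key);
    weights = weights / (head_dim as f64).sqrt(); // scale by the square root of the per-head dimension
    let weights = weights.softmax(-1, Kind::Float); // attention weights
    let output = weights.matmul(&value); // (1, heads, seq_len, head_dim)
    println!("{:?}", output.size());
}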

View File

@ -15,8 +15,8 @@
use crate::common::dropout::Dropout;
use crate::common::linear::{linear_no_bias, LinearNoBias};
use crate::gpt2::transformer::Block;
use crate::pipelines::generation::{Cache, LMHeadModel};
use crate::Config;
use crate::pipelines::generation::{Cache, LMHeadModel, LMModelOutput};
use crate::{Config, RustBertError};
use serde::{Deserialize, Serialize};
use std::borrow::{Borrow, BorrowMut};
use tch::kind::Kind::Int64;
@ -313,10 +313,11 @@ impl Gpt2Model {
///
/// # Returns
///
/// * `output` - `Tensor` of shape (*batch size*, *sequence_length*, *hidden_size*) representing the activations of the last hidden state
/// * `past` - `Option<Vec<Tensor>>` of length *n_layer* containing the past keys and values of each layer of shape (*2*, *batch size*, *number of heads*, *past_sequence_length*, *hidden size per head*)
/// * `hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// * `attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// * `Gpt2ModelOutput` containing:
/// - `output` - `Tensor` of shape (*batch size*, *sequence_length*, *hidden_size*) representing the activations of the last hidden state
/// - `cache` - `Option<Vec<Tensor>>` of length *n_layer* containing the past keys and values of each layer of shape (*2*, *batch size*, *number of heads*, *past_sequence_length*, *hidden size per head*)
/// - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// - `all_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
///
/// # Example
///
@ -352,7 +353,7 @@ impl Gpt2Model {
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
/// .expand(&[batch_size, sequence_length], true);
///
/// let (output, past, hidden_states, attentions) = no_grad(|| {
/// let model_output = no_grad(|| {
/// gpt2_model
/// .forward_t(
/// &Some(input_tensor),
@ -375,19 +376,13 @@ impl Gpt2Model {
position_ids: &Option<Tensor>,
input_embeds: &Option<Tensor>,
train: bool,
) -> Result<
(
Tensor,
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
),
&'static str,
> {
) -> Result<Gpt2ModelOutput, RustBertError> {
let (input_embeddings, seq_length) = match input_ids {
Some(input_value) => match input_embeds {
Some(_) => {
return Err("Only one of input ids or input embeddings may be set");
return Err(RustBertError::ValueError(
"Only one of input ids or input embeddings may be set".into(),
));
}
None => (
input_value.apply(&self.wte),
@ -397,7 +392,9 @@ impl Gpt2Model {
None => match input_embeds {
Some(embeds) => (embeds.copy(), embeds.size()[1]),
None => {
return Err("At least one of input ids or input embeddings must be set");
return Err(RustBertError::ValueError(
"At least one of input ids or input embeddings must be set".into(),
));
}
},
};
@ -466,34 +463,29 @@ impl Gpt2Model {
None
};
let mut layer_iter = self.h.iter().zip(layer_past);
loop {
match layer_iter.next() {
Some(layer_values) => {
let (layer, past) = layer_values;
if let Some(hidden_states) = all_hidden_states.borrow_mut() {
hidden_states.push(hidden_state.as_ref().copy());
};
let layer_iter = self.h.iter().zip(layer_past);
for layer_values in layer_iter {
let (layer, past) = layer_values;
if let Some(hidden_states) = all_hidden_states.borrow_mut() {
hidden_states.push(hidden_state.as_ref().copy());
};
let temp = layer.forward_t(&hidden_state, &past, &attention_mask, train);
hidden_state = temp.0;
if let Some(presents) = all_presents.borrow_mut() {
presents.push(temp.1.as_ref().copy());
};
if let Some(attentions) = all_attentions.borrow_mut() {
attentions.push(temp.2.as_ref().unwrap().copy());
};
}
None => break,
let temp = layer.forward_t(&hidden_state, &past, &attention_mask, train);
hidden_state = temp.0;
if let Some(presents) = all_presents.borrow_mut() {
presents.push(temp.1.as_ref().copy());
};
if let Some(attentions) = all_attentions.borrow_mut() {
attentions.push(temp.2.as_ref().unwrap().copy());
};
}
Ok((
hidden_state.apply(&self.ln_f),
all_presents,
Ok(Gpt2ModelOutput {
output: hidden_state.apply(&self.ln_f),
cache: all_presents,
all_hidden_states,
all_attentions,
))
})
}
}
@ -567,11 +559,12 @@ impl LMHeadModel for GPT2LMHeadModel {
///
/// # Returns
///
/// * `output` - `Tensor` of shape (*batch size*, *sequence_length*, *vocab_size*) representing the logits for each vocab item and position
/// * `encoder_hidden_states` - None
/// * `past` - `Option<Vec<Tensor>>` of length *n_layer* containing the past keys and values of each layer of shape (*2*, *batch size*, *number of heads*, *past_sequence_length*, *hidden size per head*)
/// * `hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// * `attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// * `LMModelOutput` containing:
/// - `lm_logits` - `Tensor` of shape (*batch size*, *sequence_length*, *vocab_size*) representing the logits for each vocab item and position
/// - `cache` - `Gpt2Cache` made of `Option<Vec<Tensor>>` of length *n_layer* containing the past keys and values of each layer of shape (*2*, *batch size*, *number of heads*, *past_sequence_length*, *hidden size per head*)
/// - `encoder_hidden_states` - None
/// - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// - `all_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
///
/// # Example
///
@ -608,7 +601,7 @@ impl LMHeadModel for GPT2LMHeadModel {
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
/// .expand(&[batch_size, sequence_length], true);
///
/// let (output, _, past, hidden_states, attentions) = no_grad(|| {
/// let model_output = no_grad(|| {
/// gpt2_model
/// .forward_t(
/// &Some(input_tensor),
@ -635,18 +628,9 @@ impl LMHeadModel for GPT2LMHeadModel {
_encoder_outputs: Option<&Tensor>,
_decoder_input_ids: &Option<Tensor>,
train: bool,
) -> Result<
(
Tensor,
Option<Tensor>,
Cache,
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
),
&'static str,
> {
let (output, past, all_hidden_states, all_attentions) = match layer_past {
Cache::GPT2Cache(layer_past) => Ok(self.transformer.forward_t(
) -> Result<LMModelOutput, RustBertError> {
let base_model_output = match layer_past {
Cache::GPT2Cache(layer_past) => self.transformer.forward_t(
input_ids,
&layer_past,
attention_mask,
@ -654,8 +638,8 @@ impl LMHeadModel for GPT2LMHeadModel {
position_ids,
input_embeds,
train,
)?),
Cache::None => Ok(self.transformer.forward_t(
),
Cache::None => self.transformer.forward_t(
input_ids,
&None,
attention_mask,
@ -663,17 +647,34 @@ impl LMHeadModel for GPT2LMHeadModel {
position_ids,
input_embeds,
train,
)?),
_ => Err("Cache not compatible with GPT2 model"),
),
_ => {
return Err(RustBertError::ValueError(
"Cache not compatible with GPT2 Model".into(),
));
}
}?;
let lm_logits = output.apply(&self.lm_head);
Ok((
let lm_logits = base_model_output.output.apply(&self.lm_head);
Ok(LMModelOutput {
lm_logits,
None,
Cache::GPT2Cache(past),
all_hidden_states,
all_attentions,
))
encoder_hidden_state: None,
cache: Cache::GPT2Cache(base_model_output.cache),
all_hidden_states: base_model_output.all_hidden_states,
all_attentions: base_model_output.all_attentions,
})
}
}
/// Container for the GPT2 model output.
pub struct Gpt2ModelOutput {
/// Hidden state of the last layer of the decoder, or logits for a custom head
/// module after the decoder (e.g. vocabulary logits for language modeling tasks)
pub output: Tensor,
/// Cached attention layers keys and values if the model is used for generation
pub cache: Option<Vec<Tensor>>,
/// Hidden states for all intermediate layers
pub all_hidden_states: Option<Vec<Tensor>>,
/// Attention weights for all intermediate layers
pub all_attentions: Option<Vec<Tensor>>,
}
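A minimal sketch of consuming the new struct return value in place of the former tuple; `gpt2_model` and `input_tensor` are assumed to already exist, built as in the doc example above.

// Sketch only: `gpt2_model` and `input_tensor` are assumed to exist (see the doc example above).
let model_output = gpt2_model
    .forward_t(&Some(input_tensor), &None, &None, &None, &None, &None, false)
    .expect("GPT-2 forward pass failed");
// Named fields replace positional destructuring of the old tuple.
println!("last hidden state: {:?}", model_output.output.size());
let cached_layers = model_output.cache.as_ref().map_or(0, |layers| layers.len());
println!("cached layers: {}", cached_layers);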

View File

@ -56,10 +56,10 @@
//! ```
pub(crate) mod attention;
mod gpt2;
mod gpt2_model;
pub(crate) mod transformer;
pub use gpt2::{
pub use gpt2_model::{
GPT2LMHeadModel, Gpt2Config, Gpt2ConfigResources, Gpt2MergesResources, Gpt2Model,
Gpt2ModelResources, Gpt2VocabResources, GptActivation,
Gpt2ModelOutput, Gpt2ModelResources, Gpt2VocabResources, GptActivation,
};

View File

@ -15,7 +15,7 @@
use crate::common::activations::{_gelu_new, _relu, _swish};
use crate::common::dropout::Dropout;
use crate::gpt2::attention::{Attention, GPTConv1D};
use crate::gpt2::gpt2::{Gpt2Config, GptActivation};
use crate::gpt2::gpt2_model::{Gpt2Config, GptActivation};
use std::borrow::Borrow;
use tch::{nn, Tensor};

View File

@ -11,8 +11,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use crate::bart::{BartConfig, BartModel, LayerState};
use crate::pipelines::generation::{Cache, LMHeadModel};
use crate::bart::{BartConfig, BartEncoderOutput, BartModel, BartModelOutput, LayerState};
use crate::pipelines::generation::{Cache, LMHeadModel, LMModelOutput};
use crate::RustBertError;
use std::borrow::Borrow;
use tch::nn::Init;
use tch::{nn, Tensor};
@ -295,12 +296,14 @@ impl MarianForConditionalGeneration {
///
/// # Returns
///
/// * `lm_logits` - `Tensor` of shape (*batch size*, *target_sequence_length*, *vocab_size*) representing the logits for each vocab item and position
/// * `encoder_hidden_states` - `Tensor` of shape (*batch size*, *source_sequence_length*, *hidden_size*) representing the activations of the last encoder hidden state
/// * `all_encoder_hidden_states` - `Option<Vec<Tensor>>` of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*)
/// * `all_encoder_attentions` - `Option<Vec<Tensor>>` of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*)
/// * `all_decoder_hidden_states` - `Option<Vec<Tensor>>` of length *num_decoder_layers* with shape (*batch size*, *target_sequence_length*, *hidden_size*)
/// * `all_decoder_attentions` - `Option<Vec<Tensor>>` of length *num_decoder_layers* with shape (*batch size*, *target_sequence_length*, *hidden_size*)
/// * `BartModelOutput` containing:
/// - `decoder_output` - `Tensor` of shape (*batch size*, *target_sequence_length*, *vocab_size*) representing the logits for each vocabulary item and position
/// - `encoder_hidden_states` - `Tensor` of shape (*batch size*, *source_sequence_length*, *hidden_size*) representing the activations of the last encoder hidden state
/// - `cache` - `(Option<Tensor>, Option<Vec<(&LayerState, &LayerState)>>)` of length *n_layer* containing the encoder padding mask and past keys and values for both the self attention and the encoder cross attention of each layer of the decoder.
/// - `all_encoder_hidden_states` - `Option<Vec<Tensor>>` of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*)
/// - `all_encoder_attentions` - `Option<Vec<Tensor>>` of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*)
/// - `all_decoder_hidden_states` - `Option<Vec<Tensor>>` of length *num_decoder_layers* with shape (*batch size*, *target_sequence_length*, *hidden_size*)
/// - `all_decoder_attentions` - `Option<Vec<Tensor>>` of length *num_decoder_layers* with shape (*batch size*, *target_sequence_length*, *hidden_size*)
///
/// # Example
///
@ -325,15 +328,7 @@ impl MarianForConditionalGeneration {
/// let decoder_attention_mask =
/// Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
///
/// let (
/// decoder_output,
/// encoder_hidden_states,
/// cache,
/// all_encoder_hidden_states,
/// all_encoder_attentions,
/// all_decoder_hidden_states,
/// all_decoder_attentions,
/// ) = no_grad(|| {
/// let model_output = no_grad(|| {
/// marian_model.forward_t(
/// Some(&input_tensor),
/// Some(&encoder_attention_mask),
@ -349,29 +344,13 @@ impl MarianForConditionalGeneration {
&self,
input_ids: Option<&Tensor>,
attention_mask: Option<&Tensor>,
encoder_outputs: Option<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>)>,
encoder_outputs: Option<BartEncoderOutput>,
decoder_input_ids: Option<&Tensor>,
decoder_attention_mask: Option<&Tensor>,
old_layer_states: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
train: bool,
) -> (
Tensor,
Tensor,
Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
) {
let (
decoder_outputs,
encoder_hidden_states,
decoder_cache,
all_decoder_hidden_states,
all_decoder_attentions,
all_encoder_hidden_states,
all_encoder_attentions,
) = self.base_model.forward_t(
) -> BartModelOutput {
let base_model_output = self.base_model.forward_t(
input_ids,
attention_mask,
decoder_input_ids,
@ -381,26 +360,25 @@ impl MarianForConditionalGeneration {
train,
);
let lm_logits = decoder_outputs.linear::<Tensor>(&self.base_model.embeddings.ws, None);
(
lm_logits,
encoder_hidden_states,
decoder_cache,
all_decoder_hidden_states,
all_decoder_attentions,
all_encoder_hidden_states,
all_encoder_attentions,
)
let lm_logits = base_model_output
.decoder_output
.linear::<Tensor>(&self.base_model.embeddings.ws, None);
BartModelOutput {
decoder_output: lm_logits,
..base_model_output
}
}
pub fn encode(&self, input_ids: &Tensor, attention_mask: Option<&Tensor>) -> Tensor {
let (encoder_hidden_states, _, _) = self.base_model.encoder.forward_t(
input_ids,
attention_mask,
&self.base_model.embeddings,
false,
);
encoder_hidden_states
self.base_model
.encoder
.forward_t(
input_ids,
attention_mask,
&self.base_model.embeddings,
false,
)
.hidden_state
}
}
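The forward pass now hands back a single `BartModelOutput`; a minimal sketch of reading it, assuming `marian_model`, `input_tensor` and `encoder_attention_mask` exist as in the doc example above.

// Sketch only: inputs are assumed to exist (see the doc example above).
let model_output = marian_model.forward_t(
    Some(&input_tensor),
    Some(&encoder_attention_mask),
    None,  // pre-computed encoder output
    None,  // decoder input ids
    None,  // decoder attention mask
    None,  // cached layer states
    false, // train
);
// Logits now live in `decoder_output`; the encoder activations travel alongside them.
println!("logits: {:?}", model_output.decoder_output.size());
println!(
    "encoder hidden state: {:?}",
    model_output.encoder_hidden_state.size()
);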
@ -423,11 +401,13 @@ impl LMHeadModel for MarianForConditionalGeneration {
///
/// # Returns
///
/// * `lm_logits` - `Tensor` of shape (*batch size*, *sequence_length*, *vocab_size*) representing the logits for each vocab item and position
/// * `past` - None
/// * `encoder_hidden_states` - `Option<Tensor>` Hidden states for the encoder
/// * `hidden_states` - None
/// * `attentions` - None
/// * `LMModelOutput` containing:
/// - `lm_logits` - `Tensor` of shape (*batch size*, *sequence_length*, *vocab_size*) representing the logits for each vocab item and position
/// - `cache` - `BartCache` made of `Option<Vec<(Option<LayerState>, Option<LayerState>)>>` of length *n_layer* containing the past keys and values for
/// both the self attention and the encoder cross attention of each layer of the decoder.
/// - `encoder_hidden_states` - `Option<Tensor>` Hidden states for the encoder
/// - `all_hidden_states` - None
/// - `all_attentions` - None
///
/// # Example
///
@ -452,15 +432,7 @@ impl LMHeadModel for MarianForConditionalGeneration {
/// let decoder_attention_mask =
/// Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
///
/// let (
/// decoder_output,
/// encoder_hidden_states,
/// cache,
/// all_encoder_hidden_states,
/// all_encoder_attentions,
/// all_decoder_hidden_states,
/// all_decoder_attentions,
/// ) = no_grad(|| {
/// let model_output = no_grad(|| {
/// marian_model.forward_t(
/// Some(&input_tensor),
/// Some(&encoder_attention_mask),
@ -483,22 +455,17 @@ impl LMHeadModel for MarianForConditionalGeneration {
encoder_outputs: Option<&Tensor>,
decoder_input_ids: &Option<Tensor>,
train: bool,
) -> Result<
(
Tensor,
Option<Tensor>,
Cache,
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
),
&'static str,
> {
let (decoder_output, encoder_hidden_states, new_cache, _, _, _, _) = match cache {
) -> Result<LMModelOutput, RustBertError> {
let base_model_output = match cache {
Cache::BARTCache(cached_layer_states) => self.base_model.forward_t(
input_ids.as_ref(),
attention_mask.as_ref(),
decoder_input_ids.as_ref(),
Some((encoder_outputs.as_ref().unwrap().copy(), None, None)),
Some(BartEncoderOutput {
hidden_state: encoder_outputs.as_ref().unwrap().copy(),
all_hidden_states: None,
all_attentions: None,
}),
None,
cached_layer_states,
train,
@ -507,22 +474,32 @@ impl LMHeadModel for MarianForConditionalGeneration {
input_ids.as_ref(),
attention_mask.as_ref(),
decoder_input_ids.as_ref(),
Some((encoder_outputs.as_ref().unwrap().copy(), None, None)),
Some(BartEncoderOutput {
hidden_state: encoder_outputs.as_ref().unwrap().copy(),
all_hidden_states: None,
all_attentions: None,
}),
None,
None,
train,
),
_ => Err("Cache not compatible with Marian Model")?,
_ => {
return Err(RustBertError::ValueError(
"Cache not compatible with Marian Model".into(),
));
}
};
let lm_logits = decoder_output.linear::<Tensor>(&self.base_model.embeddings.ws, None)
let lm_logits = base_model_output
.decoder_output
.linear::<Tensor>(&self.base_model.embeddings.ws, None)
+ &self.final_logits_bias;
Ok((
Ok(LMModelOutput {
lm_logits,
Some(encoder_hidden_states),
Cache::BARTCache(new_cache),
None,
None,
))
encoder_hidden_state: Some(base_model_output.encoder_hidden_state),
cache: Cache::BARTCache(base_model_output.cache),
all_hidden_states: None,
all_attentions: None,
})
}
}
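When the encoder state has already been computed (for example via `encode`), it can be wrapped back into a `BartEncoderOutput` before the next decoding step, mirroring the cache branch above. A minimal sketch, assuming `marian_model`, `input_tensor`, `attention_mask` and `decoder_input_ids` already exist and that `BartEncoderOutput` is in scope via the `bart` module.

// Sketch only: re-using a pre-computed encoder state between decoding steps.
let encoder_hidden_state = marian_model.encode(&input_tensor, Some(&attention_mask));
let encoder_output = BartEncoderOutput {
    hidden_state: encoder_hidden_state,
    all_hidden_states: None,
    all_attentions: None,
};
let model_output = marian_model.forward_t(
    None, // input ids are not re-embedded when the encoder output is supplied
    Some(&attention_mask),
    Some(encoder_output),
    Some(&decoder_input_ids), // assumed to exist
    None,
    None,
    false,
);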

View File

@ -57,9 +57,9 @@
//! # }
//! ```
mod marian;
mod marian_model;
pub use marian::{
pub use marian_model::{
MarianConfigResources, MarianForConditionalGeneration, MarianModelResources, MarianPrefix,
MarianSpmResources, MarianVocabResources,
};

View File

@ -6,7 +6,7 @@
//!
//! # Model set-up and pre-trained weights loading
//!
//! A full working example is provided in `examples/openai_gpt.rs`, run with `cargo run --example openai_gpt`.
//! A full working example is provided in `examples/openai_gpt`, run with `cargo run --example openai_gpt`.
//! All models expect the following resources:
//! - Configuration file expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers)
//! - Model weights are expected to have a structure and parameter names following the [Transformers library](https://github.com/huggingface/transformers). A conversion using the Python utility scripts is required to convert the `.bin` weights to the `.ot` format.
@ -55,10 +55,10 @@
//! # }
//! ```
mod openai_gpt;
mod openai_gpt_model;
mod transformer;
pub use openai_gpt::{
pub use openai_gpt_model::{
OpenAIGPTLMHeadModel, OpenAiGptConfigResources, OpenAiGptMergesResources, OpenAiGptModel,
OpenAiGptModelResources, OpenAiGptVocabResources,
OpenAiGptModelOutput, OpenAiGptModelResources, OpenAiGptVocabResources,
};

View File

@ -16,7 +16,8 @@ use crate::common::dropout::Dropout;
use crate::common::linear::{linear_no_bias, LinearNoBias};
use crate::gpt2::Gpt2Config;
use crate::openai_gpt::transformer::Block;
use crate::pipelines::generation::{Cache, LMHeadModel};
use crate::pipelines::generation::{Cache, LMHeadModel, LMModelOutput};
use crate::RustBertError;
use std::borrow::{Borrow, BorrowMut};
use tch::kind::Kind::Int64;
use tch::nn::embedding;
@ -166,9 +167,10 @@ impl OpenAiGptModel {
///
/// # Returns
///
/// * `output` - `Tensor` of shape (*batch size*, *sequence_length*, *hidden_size*) representing the activations of the last hidden state
/// * `hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// * `attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// * `OpenAiGptModelOutput` containing:
/// - `output` - `Tensor` of shape (*batch size*, *sequence_length*, *hidden_size*) representing the activations of the last hidden state
/// - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// - `all_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
///
/// # Example
///
@ -192,7 +194,7 @@ impl OpenAiGptModel {
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
/// .expand(&[batch_size, sequence_length], true);
///
/// let (output, hidden_states, attentions) = no_grad(|| {
/// let model_output = no_grad(|| {
/// gpt_model
/// .forward_t(
/// &Some(input_tensor),
@ -213,11 +215,13 @@ impl OpenAiGptModel {
position_ids: &Option<Tensor>,
input_embeds: &Option<Tensor>,
train: bool,
) -> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
) -> Result<OpenAiGptModelOutput, RustBertError> {
let (input_embeddings, seq_length) = match input_ids {
Some(input_value) => match input_embeds {
Some(_) => {
return Err("Only one of input ids or input embeddings may be set");
return Err(RustBertError::ValueError(
"Only one of input ids or input embeddings may be set".into(),
));
}
None => (
input_value.apply(&self.tokens_embed),
@ -227,7 +231,9 @@ impl OpenAiGptModel {
None => match input_embeds {
Some(embeds) => (embeds.copy(), embeds.size()[1]),
None => {
return Err("At least one of input ids or input embeddings must be set");
return Err(RustBertError::ValueError(
"At least one of input ids or input embeddings must be set".into(),
));
}
},
};
@ -267,25 +273,23 @@ impl OpenAiGptModel {
None
};
let mut layers = self.h.iter();
loop {
match layers.next() {
Some(layer) => {
if let Some(hidden_states) = all_hidden_states.borrow_mut() {
hidden_states.push(hidden_state.as_ref().copy());
};
for layer in &self.h {
if let Some(hidden_states) = all_hidden_states.borrow_mut() {
hidden_states.push(hidden_state.as_ref().copy());
};
let temp = layer.forward_t(&hidden_state, &attention_mask, train);
hidden_state = temp.0;
if let Some(attentions) = all_attentions.borrow_mut() {
attentions.push(temp.1.as_ref().unwrap().copy());
};
}
None => break,
let temp = layer.forward_t(&hidden_state, &attention_mask, train);
hidden_state = temp.0;
if let Some(attentions) = all_attentions.borrow_mut() {
attentions.push(temp.1.as_ref().unwrap().copy());
};
}
Ok((hidden_state, all_hidden_states, all_attentions))
Ok(OpenAiGptModelOutput {
hidden_state,
all_hidden_states,
all_attentions,
})
}
}
@ -360,11 +364,12 @@ impl LMHeadModel for OpenAIGPTLMHeadModel {
///
/// # Returns
///
/// * `output` - `Tensor` of shape (*batch size*, *sequence_length*, *vocab_size*) representing the logits for each vocab item and position
/// * `encoder_hidden_states` - None
/// * `past` - None
/// * `hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// * `attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// * `LMModelOutput` containing:
/// - `lm_logits` - `Tensor` of shape (*batch size*, *sequence_length*, *vocab_size*) representing the logits for each vocab item and position
/// - `cache` - None
/// - `encoder_hidden_states` - None
/// - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// - `all_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
///
/// # Example
///
@ -388,7 +393,7 @@ impl LMHeadModel for OpenAIGPTLMHeadModel {
/// let token_type_ids = Tensor::ones(&[batch_size, sequence_length], (Int64, device));
/// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true);
///
/// let (output, _, _, hidden_states, attentions) = no_grad(|| {
/// let model_output = no_grad(|| {
/// gpt_model
/// .forward_t(&Some(input_tensor),
/// Cache::None,
@ -412,17 +417,8 @@ impl LMHeadModel for OpenAIGPTLMHeadModel {
_encoder_outputs: Option<&Tensor>,
_decoder_input_ids: &Option<Tensor>,
train: bool,
) -> Result<
(
Tensor,
Option<Tensor>,
Cache,
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
),
&'static str,
> {
let (output, all_hidden_states, all_attentions) = self.transformer.forward_t(
) -> Result<LMModelOutput, RustBertError> {
let base_model_output = self.transformer.forward_t(
input_ids,
attention_mask,
token_type_ids,
@ -431,13 +427,24 @@ impl LMHeadModel for OpenAIGPTLMHeadModel {
train,
)?;
let lm_logits = output.apply(&self.lm_head);
Ok((
let lm_logits = base_model_output.hidden_state.apply(&self.lm_head);
Ok(LMModelOutput {
lm_logits,
None,
Cache::None,
all_hidden_states,
all_attentions,
))
encoder_hidden_state: None,
cache: Cache::None,
all_hidden_states: base_model_output.all_hidden_states,
all_attentions: base_model_output.all_attentions,
})
}
}
/// Container for the OpenAI GPT model output.
pub struct OpenAiGptModelOutput {
/// Hidden state of the last layer of the decoder, or logits for a custom head
/// module after the decoder (e.g. vocabulary logits for language modeling tasks)
pub hidden_state: Tensor,
/// Hidden states for all intermediate layers
pub all_hidden_states: Option<Vec<Tensor>>,
/// Attention weights for all intermediate layers
pub all_attentions: Option<Vec<Tensor>>,
}
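A minimal sketch of reading the new `OpenAiGptModelOutput`, assuming `gpt_model` and `input_tensor` exist as in the doc example above.

// Sketch only: `gpt_model` and `input_tensor` are assumed to exist (see the doc example above).
let model_output = gpt_model
    .forward_t(&Some(input_tensor), &None, &None, &None, &None, false)
    .expect("GPT forward pass failed");
// `hidden_state` replaces the first element of the former tuple.
println!("last hidden state: {:?}", model_output.hidden_state.size());
// Intermediate activations are only populated when the configuration enables them.
if let Some(all_hidden_states) = &model_output.all_hidden_states {
    println!("{} intermediate hidden states", all_hidden_states.len());
}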

View File

@ -16,7 +16,6 @@
//! generic pipelines. The model component is defined in the generic pipeline itself as the
//! pre-processing, forward pass and postprocessing differs between pipelines while basic config and
//! tokenization objects don't.
//!
use crate::albert::AlbertConfig;
use crate::bart::BartConfig;
use crate::bert::BertConfig;
@ -335,87 +334,91 @@ impl TokenizerOption {
original_offsets_2: Option<Vec<Vec<OffsetSize>>>,
mask_1: Vec<Mask>,
mask_2: Option<Vec<Mask>>,
) -> (
Vec<i64>,
Vec<i8>,
Vec<i8>,
Vec<Option<Offset>>,
Vec<Vec<OffsetSize>>,
Vec<Mask>,
) {
match *self {
Self::Bert(ref tokenizer) => tokenizer.build_input_with_special_tokens(
tokens_1,
tokens_2,
offsets_1,
offsets_2,
original_offsets_1,
original_offsets_2,
mask_1,
mask_2,
),
Self::Roberta(ref tokenizer) => tokenizer.build_input_with_special_tokens(
tokens_1,
tokens_2,
offsets_1,
offsets_2,
original_offsets_1,
original_offsets_2,
mask_1,
mask_2,
),
Self::XLMRoberta(ref tokenizer) => tokenizer.build_input_with_special_tokens(
tokens_1,
tokens_2,
offsets_1,
offsets_2,
original_offsets_1,
original_offsets_2,
mask_1,
mask_2,
),
Self::Marian(ref tokenizer) => tokenizer.build_input_with_special_tokens(
tokens_1,
tokens_2,
offsets_1,
offsets_2,
original_offsets_1,
original_offsets_2,
mask_1,
mask_2,
),
Self::T5(ref tokenizer) => tokenizer.build_input_with_special_tokens(
tokens_1,
tokens_2,
offsets_1,
offsets_2,
original_offsets_1,
original_offsets_2,
mask_1,
mask_2,
),
Self::Albert(ref tokenizer) => tokenizer.build_input_with_special_tokens(
tokens_1,
tokens_2,
offsets_1,
offsets_2,
original_offsets_1,
original_offsets_2,
mask_1,
mask_2,
),
) -> TokenizedInput {
let (token_ids, segment_ids, special_tokens_mask, token_offsets, reference_offsets, mask) =
match *self {
Self::Bert(ref tokenizer) => tokenizer.build_input_with_special_tokens(
tokens_1,
tokens_2,
offsets_1,
offsets_2,
original_offsets_1,
original_offsets_2,
mask_1,
mask_2,
),
Self::Roberta(ref tokenizer) => tokenizer.build_input_with_special_tokens(
tokens_1,
tokens_2,
offsets_1,
offsets_2,
original_offsets_1,
original_offsets_2,
mask_1,
mask_2,
),
Self::XLMRoberta(ref tokenizer) => tokenizer.build_input_with_special_tokens(
tokens_1,
tokens_2,
offsets_1,
offsets_2,
original_offsets_1,
original_offsets_2,
mask_1,
mask_2,
),
Self::Marian(ref tokenizer) => tokenizer.build_input_with_special_tokens(
tokens_1,
tokens_2,
offsets_1,
offsets_2,
original_offsets_1,
original_offsets_2,
mask_1,
mask_2,
),
Self::T5(ref tokenizer) => tokenizer.build_input_with_special_tokens(
tokens_1,
tokens_2,
offsets_1,
offsets_2,
original_offsets_1,
original_offsets_2,
mask_1,
mask_2,
),
Self::Albert(ref tokenizer) => tokenizer.build_input_with_special_tokens(
tokens_1,
tokens_2,
offsets_1,
offsets_2,
original_offsets_1,
original_offsets_2,
mask_1,
mask_2,
),
};
TokenizedInput {
token_ids,
segment_ids,
special_tokens_mask,
overflowing_tokens: vec![],
num_truncated_tokens: 0,
token_offsets,
reference_offsets,
mask,
}
}
/// Interface method to convert tokens to ids
pub fn convert_tokens_to_ids(&self, tokens: &Vec<String>) -> Vec<i64> {
pub fn convert_tokens_to_ids(&self, tokens: &[String]) -> Vec<i64> {
match *self {
Self::Bert(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
Self::Roberta(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
Self::Marian(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
Self::T5(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
Self::XLMRoberta(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
Self::Albert(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
Self::Bert(ref tokenizer) => tokenizer.convert_tokens_to_ids(&tokens.into()),
Self::Roberta(ref tokenizer) => tokenizer.convert_tokens_to_ids(&tokens.into()),
Self::Marian(ref tokenizer) => tokenizer.convert_tokens_to_ids(&tokens.into()),
Self::T5(ref tokenizer) => tokenizer.convert_tokens_to_ids(&tokens.into()),
Self::XLMRoberta(ref tokenizer) => tokenizer.convert_tokens_to_ids(&tokens.into()),
Self::Albert(ref tokenizer) => tokenizer.convert_tokens_to_ids(&tokens.into()),
}
}
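With the signature now borrowing a slice, both a `Vec<String>` and a sub-slice can be passed directly; a minimal sketch, assuming `tokenizer` is an already constructed `TokenizerOption`.

// Sketch only: `tokenizer` is assumed to be an existing `TokenizerOption`.
let tokens: Vec<String> = vec!["hello".to_string(), "world".to_string()];
let ids_from_vec = tokenizer.convert_tokens_to_ids(&tokens); // &Vec<String> coerces to &[String]
let ids_from_slice = tokenizer.convert_tokens_to_ids(&tokens[..1]); // explicit sub-slice
println!("{:?} / {:?}", ids_from_vec, ids_from_slice);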

View File

@ -70,21 +70,21 @@ pub struct ConversationConfig {
/// Merges resource (default: DialoGPT-medium)
pub merges_resource: Resource,
/// Minimum sequence length (default: 0)
pub min_length: u64,
pub min_length: i64,
/// Maximum sequence length (default: 20)
pub max_length: u64,
pub max_length: i64,
/// Minimum free length available for generated responses (default: 32)
pub min_length_for_response: u64,
pub min_length_for_response: i64,
/// Sampling flag. If true, will perform top-k and/or nucleus sampling on generated tokens, otherwise greedy (deterministic) decoding (default: true)
pub do_sample: bool,
/// Early stopping flag indicating if the beam search should stop as soon as `num_beam` hypotheses have been generated (default: false)
pub early_stopping: bool,
/// Number of beams for beam search (default: 5)
pub num_beams: u64,
pub num_beams: i64,
/// Temperature setting. Values higher than 1 will improve originality at the risk of reducing relevance (default: 1.0)
pub temperature: f64,
/// Top_k values for sampling tokens. Value higher than 0 will enable the feature (default: 0)
pub top_k: u64,
pub top_k: i64,
/// Top_p value for [Nucleus sampling, Holtzman et al.](http://arxiv.org/abs/1904.09751). Keep top tokens until cumulative probability reaches top_p (default: 0.9)
pub top_p: f64,
/// Repetition penalty (mostly useful for CTRL decoders). Values higher than 1 will penalize tokens that have been already generated. (default: 1.0)
@ -92,9 +92,9 @@ pub struct ConversationConfig {
/// Exponential penalty based on the length of the hypotheses generated (default: 1.0)
pub length_penalty: f64,
/// Number of allowed repetitions of n-grams. Values higher than 0 turn on this feature (default: 3)
pub no_repeat_ngram_size: u64,
pub no_repeat_ngram_size: i64,
/// Number of sequences to return for each prompt text (default: 1)
pub num_return_sequences: u64,
pub num_return_sequences: i64,
/// Device to place the model on (default: CUDA/GPU when available)
pub device: Device,
}
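The length- and count-related fields now use plain `i64`, matching the `tch` tensor APIs. A minimal sketch of overriding a few of them, assuming `ConversationConfig` keeps a `Default` implementation.

// Sketch only: assumes a `Default` impl for `ConversationConfig`.
let config = ConversationConfig {
    max_length: 60,
    min_length_for_response: 24,
    num_beams: 1,
    top_k: 50,
    ..Default::default()
};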
@ -203,9 +203,11 @@ impl Conversation {
/// let mut conversation = Conversation::new_empty();
/// conversation.add_user_input("Hi there!");
/// ```
pub fn add_user_input(&mut self, text: &str) -> Result<(), &'static str> {
pub fn add_user_input(&mut self, text: &str) -> Result<(), RustBertError> {
if self.new_user_input.is_some() {
Err("User input already provided for this conversation")
Err(RustBertError::ValueError(
"User input already provided for this conversation".into(),
))
} else {
self.new_user_input = Some(text.to_string());
Ok(())
@ -306,12 +308,10 @@ impl Conversation {
pub fn get_last_input(&self) -> Option<&str> {
if self.new_user_input.is_some() {
Some(self.new_user_input.as_ref().unwrap().as_str())
} else if !self.past_user_inputs.is_empty() {
Some(self.past_user_inputs.last().unwrap().as_str())
} else {
if self.past_user_inputs.len() > 0 {
Some(self.past_user_inputs.last().unwrap().as_str())
} else {
None
}
None
}
}
@ -566,12 +566,18 @@ impl ConversationManager {
}
}
impl Default for ConversationManager {
fn default() -> Self {
Self::new()
}
}
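With the new `Default` implementation, both constructions below are equivalent; a minimal sketch.

// Sketch only: `Default::default()` simply delegates to `ConversationManager::new()`.
let manager_from_new = ConversationManager::new();
let manager_from_default = ConversationManager::default();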
/// # Conversation model
/// Processes a ConversationManager and generate system responses for active conversations.
pub struct ConversationModel {
model: GPT2Generator,
eos_token_id: i64,
max_allowed_context_length: u64,
max_allowed_context_length: i64,
}
impl ConversationModel {
@ -617,7 +623,7 @@ impl ConversationModel {
let model = GPT2Generator::new(generate_config)?;
let eos_token_id = *model.get_eos_ids().as_ref().unwrap().first().unwrap();
let max_allowed_length =
conversation_config.max_length as u64 - conversation_config.min_length_for_response;
conversation_config.max_length - conversation_config.min_length_for_response;
Ok(ConversationModel {
model,
eos_token_id,

View File

@ -73,7 +73,9 @@ use crate::openai_gpt::{
OpenAIGPTLMHeadModel, OpenAiGptConfigResources, OpenAiGptMergesResources,
OpenAiGptModelResources, OpenAiGptVocabResources,
};
use crate::pipelines::generation::private_generation_utils::PrivateLanguageGenerator;
use crate::pipelines::generation::private_generation_utils::{
GenerateOptions, PrivateLanguageGenerator,
};
use crate::t5::{
LayerState as T5LayerState, T5Config, T5ConfigResources, T5ForConditionalGeneration,
T5ModelResources, T5VocabResources,
@ -104,19 +106,19 @@ pub struct GenerateConfig {
/// Merges resource (default: pretrained GPT2 model)
pub merges_resource: Resource,
/// Minimum sequence length (default: 0)
pub min_length: u64,
pub min_length: i64,
/// Maximum sequence length (default: 20)
pub max_length: u64,
pub max_length: i64,
/// Sampling flag. If true, will perform top-k and/or nucleus sampling on generated tokens, otherwise greedy (deterministic) decoding (default: true)
pub do_sample: bool,
/// Early stopping flag indicating if the beam search should stop as soon as `num_beam` hypotheses have been generated (default: false)
pub early_stopping: bool,
/// Number of beams for beam search (default: 5)
pub num_beams: u64,
pub num_beams: i64,
/// Temperature setting. Values higher than 1 will improve originality at the risk of reducing relevance (default: 1.0)
pub temperature: f64,
/// Top_k values for sampling tokens. Value higher than 0 will enable the feature (default: 0)
pub top_k: u64,
pub top_k: i64,
/// Top_p value for [Nucleus sampling, Holtzman et al.](http://arxiv.org/abs/1904.09751). Keep top tokens until cumulative probability reaches top_p (default: 0.9)
pub top_p: f64,
/// Repetition penalty (mostly useful for CTRL decoders). Values higher than 1 will penalize tokens that have been already generated. (default: 1.0)
@ -124,9 +126,9 @@ pub struct GenerateConfig {
/// Exponential penalty based on the length of the hypotheses generated (default: 1.0)
pub length_penalty: f64,
/// Number of allowed repetitions of n-grams. Values higher than 0 turn on this feature (default: 3)
pub no_repeat_ngram_size: u64,
pub no_repeat_ngram_size: i64,
/// Number of sequences to return for each prompt text (default: 1)
pub num_return_sequences: u64,
pub num_return_sequences: i64,
/// Device to place the model on (default: CUDA/GPU when available)
pub device: Device,
}
@ -179,11 +181,11 @@ impl GenerateConfig {
"length_penalty must be strictly greater than 0"
);
assert!(
self.num_return_sequences > 0u64,
self.num_return_sequences > 0i64,
"num_return_sequences must be strictly greater than 0"
);
assert!(
self.num_beams > 0u64,
self.num_beams > 0i64,
"num_beams must be strictly greater than 0"
);
@ -245,8 +247,8 @@ impl OpenAIGenerator {
generate_config.validate();
// The following allow keeping the same GenerationConfig Default for GPT, GPT2 and BART models
let model_resource = if &generate_config.model_resource
== &Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2))
let model_resource = if generate_config.model_resource
== Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2))
{
Resource::Remote(RemoteResource::from_pretrained(
OpenAiGptModelResources::GPT,
@ -255,8 +257,8 @@ impl OpenAIGenerator {
generate_config.model_resource.clone()
};
let config_resource = if &generate_config.config_resource
== &Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2))
let config_resource = if generate_config.config_resource
== Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2))
{
Resource::Remote(RemoteResource::from_pretrained(
OpenAiGptConfigResources::GPT,
@ -265,8 +267,8 @@ impl OpenAIGenerator {
generate_config.config_resource.clone()
};
let vocab_resource = if &generate_config.vocab_resource
== &Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2))
let vocab_resource = if generate_config.vocab_resource
== Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2))
{
Resource::Remote(RemoteResource::from_pretrained(
OpenAiGptVocabResources::GPT,
@ -275,8 +277,8 @@ impl OpenAIGenerator {
generate_config.vocab_resource.clone()
};
let merges_resource = if &generate_config.merges_resource
== &Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2))
let merges_resource = if generate_config.merges_resource
== Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2))
{
Resource::Remote(RemoteResource::from_pretrained(
OpenAiGptMergesResources::GPT,
@ -575,32 +577,32 @@ impl BartGenerator {
/// ```
pub fn new(generate_config: GenerateConfig) -> Result<BartGenerator, RustBertError> {
// The following allow keeping the same GenerationConfig Default for GPT, GPT2 and BART models
let model_resource = if &generate_config.model_resource
== &Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2))
let model_resource = if generate_config.model_resource
== Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2))
{
Resource::Remote(RemoteResource::from_pretrained(BartModelResources::BART))
} else {
generate_config.model_resource.clone()
};
let config_resource = if &generate_config.config_resource
== &Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2))
let config_resource = if generate_config.config_resource
== Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2))
{
Resource::Remote(RemoteResource::from_pretrained(BartConfigResources::BART))
} else {
generate_config.config_resource.clone()
};
let vocab_resource = if &generate_config.vocab_resource
== &Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2))
let vocab_resource = if generate_config.vocab_resource
== Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2))
{
Resource::Remote(RemoteResource::from_pretrained(BartVocabResources::BART))
} else {
generate_config.vocab_resource.clone()
};
let merges_resource = if &generate_config.merges_resource
== &Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2))
let merges_resource = if generate_config.merges_resource
== Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2))
{
Resource::Remote(RemoteResource::from_pretrained(BartMergesResources::BART))
} else {
@ -702,7 +704,7 @@ impl PrivateLanguageGenerator<BartForConditionalGeneration, RobertaVocab, Robert
max_length: i64,
) {
if current_length == 1 {
self.force_token_id_generation(scores, &vec![self.get_bos_id().unwrap()]);
self.force_token_id_generation(scores, &[self.get_bos_id().unwrap()]);
} else if current_length == max_length - 1 {
self.force_token_id_generation(scores, self.get_eos_ids().as_ref().unwrap());
}
@ -747,7 +749,7 @@ impl PrivateLanguageGenerator<BartForConditionalGeneration, RobertaVocab, Robert
fn encode_prompt_text(
&self,
prompt_text: Vec<&str>,
max_len: u64,
max_len: i64,
pad_token_id: Option<i64>,
) -> Tensor {
let tokens = self.get_tokenizer().encode_list(
@ -797,7 +799,7 @@ impl PrivateLanguageGenerator<BartForConditionalGeneration, RobertaVocab, Robert
match past {
Cache::BARTCache(old_cache_option) => match old_cache_option {
Some(old_cache) => {
for (self_layer_state, encoder_layer_state) in old_cache.into_iter() {
for (self_layer_state, encoder_layer_state) in old_cache.iter_mut() {
if self_layer_state.is_some() {
self_layer_state
.as_mut()
@ -1022,7 +1024,7 @@ impl PrivateLanguageGenerator<MarianForConditionalGeneration, MarianVocab, Maria
fn encode_prompt_text(
&self,
prompt_text: Vec<&str>,
max_len: u64,
max_len: i64,
pad_token_id: Option<i64>,
) -> Tensor {
let tokens = self.get_tokenizer().encode_list(
@ -1072,7 +1074,7 @@ impl PrivateLanguageGenerator<MarianForConditionalGeneration, MarianVocab, Maria
match past {
Cache::BARTCache(old_cache_option) => match old_cache_option {
Some(old_cache) => {
for (self_layer_state, encoder_layer_state) in old_cache.into_iter() {
for (self_layer_state, encoder_layer_state) in old_cache.iter_mut() {
if self_layer_state.is_some() {
self_layer_state
.as_mut()
@ -1119,24 +1121,24 @@ pub struct T5Generator {
impl T5Generator {
pub fn new(generate_config: GenerateConfig) -> Result<T5Generator, RustBertError> {
// The following allow keeping the same GenerationConfig Default for GPT, GPT2 and BART models
let model_resource = if &generate_config.model_resource
== &Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2))
let model_resource = if generate_config.model_resource
== Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2))
{
Resource::Remote(RemoteResource::from_pretrained(T5ModelResources::T5_SMALL))
} else {
generate_config.model_resource.clone()
};
let config_resource = if &generate_config.config_resource
== &Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2))
let config_resource = if generate_config.config_resource
== Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2))
{
Resource::Remote(RemoteResource::from_pretrained(T5ConfigResources::T5_SMALL))
} else {
generate_config.config_resource.clone()
};
let vocab_resource = if &generate_config.vocab_resource
== &Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2))
let vocab_resource = if generate_config.vocab_resource
== Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2))
{
Resource::Remote(RemoteResource::from_pretrained(T5VocabResources::T5_SMALL))
} else {
@ -1254,7 +1256,7 @@ impl PrivateLanguageGenerator<T5ForConditionalGeneration, T5Vocab, T5Tokenizer>
fn encode_prompt_text(
&self,
prompt_text: Vec<&str>,
max_len: u64,
max_len: i64,
pad_token_id: Option<i64>,
) -> Tensor {
let tokens = self.get_tokenizer().encode_list(
@ -1304,7 +1306,7 @@ impl PrivateLanguageGenerator<T5ForConditionalGeneration, T5Vocab, T5Tokenizer>
match past {
Cache::T5Cache(old_cache_option) => match old_cache_option {
Some(old_cache) => {
for (self_layer_state, encoder_layer_state) in old_cache.into_iter() {
for (self_layer_state, encoder_layer_state) in old_cache.iter_mut() {
if self_layer_state.is_some() {
self_layer_state
.as_mut()
@ -1351,6 +1353,23 @@ pub(crate) mod private_generation_utils {
use tch::kind::Kind::{Bool, Float, Int64};
use tch::{nn, Device, Tensor};
pub struct GenerateOptions {
pub min_length: i64,
pub max_length: i64,
pub do_sample: bool,
pub temperature: f64,
pub top_k: i64,
pub top_p: f64,
pub repetition_penalty: f64,
pub no_repeat_ngram_size: i64,
pub pad_token_id: Option<i64>,
pub eos_token_ids: Option<Vec<i64>>,
pub num_return_sequences: i64,
pub early_stopping: bool,
pub num_beams: i64,
pub length_penalty: f64,
}
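`GenerateOptions` simply bundles the sampling and beam-search parameters that were previously threaded through as separate arguments; it stays internal to the generation module. A minimal construction sketch with illustrative values only.

// Sketch only: all values below are illustrative, not recommendations.
let gen_opt = GenerateOptions {
    min_length: 0,
    max_length: 20,
    do_sample: true,
    temperature: 1.0,
    top_k: 50,
    top_p: 0.9,
    repetition_penalty: 1.0,
    no_repeat_ngram_size: 3,
    pad_token_id: Some(50256),        // illustrative token id
    eos_token_ids: Some(vec![50256]), // illustrative token id
    num_return_sequences: 1,
    early_stopping: false,
    num_beams: 1,
    length_penalty: 1.0,
};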
pub trait PrivateLanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>> {
fn get_model(&self) -> &T;
fn get_tokenizer(&self) -> &U;
@ -1394,7 +1413,7 @@ pub(crate) mod private_generation_utils {
fn encode_prompt_text(
&self,
prompt_text: Vec<&str>,
max_len: u64,
max_len: i64,
pad_token_id: Option<i64>,
) -> Tensor {
let tokens = self.get_tokenizer().tokenize_list(prompt_text);
@ -1460,7 +1479,7 @@ pub(crate) mod private_generation_utils {
&self,
next_token_logits: &mut Tensor,
batch_size: i64,
num_beams: u64,
num_beams: i64,
prev_output_tokens: &Tensor,
repetition_penalty: f64,
) {
@ -1469,7 +1488,7 @@ pub(crate) mod private_generation_utils {
let token = prev_output_tokens.get(i).int64_value(&[token_position]);
let updated_value = &next_token_logits.double_value(&[i, token]);
if updated_value < &0f64 {
&next_token_logits.get(i).index_fill_(
let _ = next_token_logits.get(i).index_fill_(
0,
&Tensor::of_slice(&[token])
.to_kind(Int64)
@ -1477,7 +1496,7 @@ pub(crate) mod private_generation_utils {
updated_value * repetition_penalty,
);
} else {
&next_token_logits.get(i).index_fill_(
let _ = next_token_logits.get(i).index_fill_(
0,
&Tensor::of_slice(&[token])
.to_kind(Int64)
@ -1521,11 +1540,10 @@ pub(crate) mod private_generation_utils {
let ngram = &hypothesis_input_ids[ngram.0 as usize..ngram.1 as usize + 1];
let key = ngram[..no_repeat_ngram_size as usize - 1].to_vec();
let value = *ngram.last().unwrap();
if generated_ngram.contains_key(&key) {
generated_ngram.get_mut(&key).unwrap().push(value)
} else {
generated_ngram.insert(key, vec![value]);
}
generated_ngram
    .entry(key)
    .or_insert_with(Vec::new)
    .push(value);
}
let hypothesis_banned_tokens = match generated_ngram.get(query) {
Some(banned_tokens) => banned_tokens.clone(),
@ -1551,21 +1569,20 @@ pub(crate) mod private_generation_utils {
let top_k = vocab_size - min(max(top_k, min_tokens_to_keep), vocab_size);
let (_, indices_to_remove) = logits.topk(top_k, -1, false, false);
for index in 0..*logits.size().first().unwrap() {
&logits.get(index).index_fill_(
let _ = logits.get(index).index_fill_(
0,
&indices_to_remove.get(index),
std::f64::NEG_INFINITY,
);
}
}
if top_p < 1f64 {
let (sorted_logits, sorted_indices) = logits.sort(-1, true);
let cumulative_probabilities = sorted_logits.softmax(-1, Float).cumsum(-1, Float);
let mut sorted_indices_to_remove =
cumulative_probabilities.ge(top_p).to_kind(Int64);
if min_tokens_to_keep > 1 {
&sorted_indices_to_remove.index_fill_(
let _ = sorted_indices_to_remove.index_fill_(
1,
&Tensor::arange1(0, min_tokens_to_keep + 1, (Int64, logits.device())),
0,
@ -1597,31 +1614,22 @@ pub(crate) mod private_generation_utils {
input_ids: Tensor,
encoder_outputs: Option<Tensor>,
cur_len: i64,
min_length: i64,
max_length: i64,
do_sample: bool,
temperature: f64,
top_k: i64,
top_p: f64,
repetition_penalty: f64,
no_repeat_ngram_size: i64,
pad_token_id: Option<i64>,
eos_token_ids: Option<Vec<i64>>,
batch_size: i64,
attention_mask: Tensor,
gen_opt: GenerateOptions,
) -> Tensor {
let mut unfinished_sentences =
Tensor::ones(&[batch_size], (Int64, self.get_var_store().device()));
let mut sentence_lengths: Tensor =
Tensor::ones(&[batch_size], (Int64, self.get_var_store().device()))
* max_length as i64;
* gen_opt.max_length as i64;
let mut attention_mask = attention_mask.copy();
let mut input_ids = input_ids.copy();
let mut past: Cache = Cache::None;
let mut outputs: Tensor;
let mut current_length = cur_len;
while current_length < max_length {
while current_length < gen_opt.max_length {
let (
prepared_input,
prepared_attention_mask,
@ -1648,31 +1656,31 @@ pub(crate) mod private_generation_utils {
false,
)
.unwrap();
outputs = temp.0;
past = temp.2;
outputs = temp.lm_logits;
past = temp.cache;
let mut next_token_logits = outputs.select(1, -1);
// Reduce probability for repeated inputs
if repetition_penalty > 1f64 {
if gen_opt.repetition_penalty > 1f64 {
self.enforce_repetition_penalty(
&mut next_token_logits,
batch_size,
1,
&input_ids,
repetition_penalty,
gen_opt.repetition_penalty,
)
}
// Get banned tokens and set their probability to 0
if no_repeat_ngram_size > 0 {
if gen_opt.no_repeat_ngram_size > 0 {
let banned_tokens = self.get_banned_tokens(
&input_ids,
no_repeat_ngram_size as i64,
gen_opt.no_repeat_ngram_size as i64,
current_length as i64,
);
for (batch_index, index_banned_token) in
(0..banned_tokens.len() as i64).zip(banned_tokens)
{
&next_token_logits.get(batch_index).index_fill_(
let _ = next_token_logits.get(batch_index).index_fill_(
0,
&Tensor::of_slice(&index_banned_token)
.to_device(next_token_logits.device()),
@ -1682,21 +1690,26 @@ pub(crate) mod private_generation_utils {
}
// Do not allow eos token if min length is not reached
if (&eos_token_ids.is_some()) & (current_length < min_length) {
&next_token_logits.index_fill_(
if (gen_opt.eos_token_ids.is_some()) & (current_length < gen_opt.min_length) {
let _ = next_token_logits.index_fill_(
1,
&Tensor::of_slice(eos_token_ids.as_ref().unwrap())
&Tensor::of_slice(gen_opt.eos_token_ids.as_ref().unwrap())
.to(next_token_logits.device()),
std::f64::NEG_INFINITY,
);
}
// Top-k and top-p sampling
let next_token = if do_sample {
if temperature > 1f64 {
next_token_logits = next_token_logits / temperature;
let next_token = if gen_opt.do_sample {
if gen_opt.temperature > 1f64 {
next_token_logits /= gen_opt.temperature;
}
self.top_k_top_p_filtering(&mut next_token_logits, top_k as i64, top_p, 1);
self.top_k_top_p_filtering(
&mut next_token_logits,
gen_opt.top_k as i64,
gen_opt.top_p,
1,
);
let probabilities = next_token_logits.softmax(-1, Float);
probabilities.multinomial(1, false).squeeze1(1)
} else {
@ -1704,17 +1717,17 @@ pub(crate) mod private_generation_utils {
};
// Add tokens to unfinished sentences
let tokens_to_add = match &eos_token_ids {
let tokens_to_add = match &gen_opt.eos_token_ids {
Some(_) => {
next_token * &unfinished_sentences
- pad_token_id.unwrap() * (&unfinished_sentences - 1)
- gen_opt.pad_token_id.unwrap() * (&unfinished_sentences - 1)
}
None => next_token,
};
input_ids = Tensor::cat(&[input_ids, tokens_to_add.unsqueeze(-1)], -1);
if eos_token_ids.is_some() {
for eos_token_id in eos_token_ids.as_ref().unwrap() {
if gen_opt.eos_token_ids.is_some() {
for eos_token_id in gen_opt.eos_token_ids.as_ref().unwrap() {
let sentence_with_eos = tokens_to_add.eq(*eos_token_id).to_kind(Int64);
let sentence_with_eos: Tensor = sentence_with_eos * &unfinished_sentences;
let _ = sentence_lengths.masked_fill_(
@ -1746,7 +1759,7 @@ pub(crate) mod private_generation_utils {
}
let decoded = if i64::from(&sentence_lengths.min().ne1(&sentence_lengths.max())) > 0 {
match pad_token_id {
match gen_opt.pad_token_id {
Some(pad_value) => {
let decoded: Tensor = Tensor::ones(
&[batch_size, i64::from(sentence_lengths.max())],
@ -1783,33 +1796,27 @@ pub(crate) mod private_generation_utils {
input_ids: Tensor,
encoder_outputs: Option<Tensor>,
cur_len: i64,
min_length: i64,
max_length: i64,
do_sample: bool,
early_stopping: bool,
temperature: f64,
top_k: i64,
top_p: f64,
repetition_penalty: f64,
no_repeat_ngram_size: i64,
pad_token_id: Option<i64>,
eos_token_ids: Option<Vec<i64>>,
batch_size: i64,
num_return_sequences: i64,
length_penalty: f64,
num_beams: i64,
attention_mask: Tensor,
gen_opt: GenerateOptions,
) -> Tensor {
let mut hypotheses = (0..batch_size)
.map(|_| BeamHypotheses::new(num_beams, max_length, length_penalty, early_stopping))
.map(|_| {
BeamHypotheses::new(
gen_opt.num_beams,
gen_opt.max_length,
gen_opt.length_penalty,
gen_opt.early_stopping,
)
})
.collect::<Vec<BeamHypotheses>>();
let vocab_size = self.get_vocab_size();
let beam_scores = Tensor::zeros(
&[batch_size, num_beams],
&[batch_size, gen_opt.num_beams],
(Float, self.get_var_store().device()),
);
if !do_sample {
if !gen_opt.do_sample {
let _ = beam_scores
.slice(1, 1, *beam_scores.size().last().unwrap(), 1)
.fill_(-1e9);
@ -1827,7 +1834,7 @@ pub(crate) mod private_generation_utils {
let mut encoder_outputs = encoder_outputs;
let mut current_length = cur_len;
while current_length < max_length {
while current_length < gen_opt.max_length {
let (
prepared_input,
prepared_attention_mask,
@ -1854,48 +1861,53 @@ pub(crate) mod private_generation_utils {
false,
)
.unwrap();
outputs = temp.0;
past = temp.2;
outputs = temp.lm_logits;
past = temp.cache;
let mut next_token_logits = outputs.select(1, -1);
// Reduce probability for repeated inputs
if repetition_penalty > 1f64 {
if gen_opt.repetition_penalty > 1f64 {
self.enforce_repetition_penalty(
&mut next_token_logits,
batch_size,
1,
&input_ids,
repetition_penalty,
gen_opt.repetition_penalty,
)
}
if temperature > 1f64 {
next_token_logits = next_token_logits / temperature;
if gen_opt.temperature > 1f64 {
next_token_logits /= gen_opt.temperature;
}
let mut scores = next_token_logits.log_softmax(-1, Float);
if self.is_encoder_decoder() & !do_sample {
self.prepare_scores_for_generation(&mut scores, current_length, max_length);
if self.is_encoder_decoder() & !gen_opt.do_sample {
self.prepare_scores_for_generation(
&mut scores,
current_length,
gen_opt.max_length,
);
}
// Do not allow eos token if min length is not reached
if (&eos_token_ids.is_some()) & (current_length < min_length) {
&scores.index_fill_(
if (gen_opt.eos_token_ids.is_some()) & (current_length < gen_opt.min_length) {
let _ = scores.index_fill_(
1,
&Tensor::of_slice(eos_token_ids.as_ref().unwrap()).to(scores.device()),
&Tensor::of_slice(gen_opt.eos_token_ids.as_ref().unwrap())
.to(scores.device()),
std::f64::NEG_INFINITY,
);
}
// Get banned tokens and set their probability to 0
if no_repeat_ngram_size > 0 {
if gen_opt.no_repeat_ngram_size > 0 {
let banned_tokens = self.get_banned_tokens(
&input_ids,
no_repeat_ngram_size as i64,
current_length as i64,
gen_opt.no_repeat_ngram_size,
current_length,
);
for (batch_index, index_banned_token) in
(0..banned_tokens.len() as i64).zip(banned_tokens)
{
&scores.get(batch_index).index_fill_(
let _ = scores.get(batch_index).index_fill_(
0,
&Tensor::of_slice(&index_banned_token)
.to_device(next_token_logits.device()),
@ -1904,16 +1916,16 @@ pub(crate) mod private_generation_utils {
}
}
let (next_scores, next_tokens) = if do_sample {
let (next_scores, next_tokens) = if gen_opt.do_sample {
let mut _scores: Tensor =
&scores + &beam_scores.unsqueeze(-1).expand_as(&scores);
self.top_k_top_p_filtering(&mut _scores, top_k as i64, top_p, 2);
self.top_k_top_p_filtering(&mut _scores, gen_opt.top_k, gen_opt.top_p, 2);
let _scores = _scores
.contiguous()
.view((batch_size, num_beams * vocab_size));
.view((batch_size, gen_opt.num_beams * vocab_size));
let probabilities = _scores.softmax(-1, Float);
let next_tokens = probabilities.multinomial(2 * num_beams, false);
let next_tokens = probabilities.multinomial(2 * gen_opt.num_beams, false);
let next_scores = _scores.gather(-1, &next_tokens, false);
let (next_scores, next_scores_indices) = next_scores.sort(1, true);
let next_tokens = next_tokens.gather(-1, &next_scores_indices, false);
@ -1923,25 +1935,25 @@ pub(crate) mod private_generation_utils {
&scores + &beam_scores.unsqueeze(-1).expand_as(&scores);
let next_scores = next_scores
.contiguous()
.view((batch_size, num_beams * vocab_size));
next_scores.topk(2 * num_beams, 1, true, true)
.view((batch_size, gen_opt.num_beams * vocab_size));
next_scores.topk(2 * gen_opt.num_beams, 1, true, true)
};
let mut next_batch_beam: Vec<(f64, i64, i64)> = vec![];
for batch_index in 0..batch_size {
if done[batch_index as usize] {
assert!(
hypotheses[batch_index as usize].len() >= num_beams,
hypotheses[batch_index as usize].len() >= gen_opt.num_beams,
"Batch cannot be completed if all beams have not been generated"
);
assert!(
eos_token_ids.is_some() & pad_token_id.is_some(),
gen_opt.eos_token_ids.is_some() & gen_opt.pad_token_id.is_some(),
"EOS and Padding tokens need to be defined if the number of generated \
beams is greater than the target number of beams"
);
next_batch_beam.append(
&mut (0..num_beams)
.map(|_| (0f64, pad_token_id.unwrap(), 0i64))
&mut (0..gen_opt.num_beams)
.map(|_| (0f64, gen_opt.pad_token_id.unwrap(), 0i64))
.collect::<Vec<(f64, i64, i64)>>(),
);
continue;
@ -1960,11 +1972,11 @@ pub(crate) mod private_generation_utils {
let beam_id = beam_token_id / vocab_size;
let token_id = beam_token_id % vocab_size;
let effective_beam_id = batch_index * num_beams + beam_id;
let effective_beam_id = batch_index * gen_opt.num_beams + beam_id;
if eos_token_ids.as_ref().is_some() {
if eos_token_ids.as_ref().unwrap().contains(&token_id) {
if beam_token_rank >= num_beams {
if gen_opt.eos_token_ids.as_ref().is_some() {
if gen_opt.eos_token_ids.as_ref().unwrap().contains(&token_id) {
if beam_token_rank >= gen_opt.num_beams {
beam_token_rank += 1;
continue;
}
@ -1985,7 +1997,7 @@ pub(crate) mod private_generation_utils {
));
}
if (next_sentence_beam.len() as i64 == num_beams)
if (next_sentence_beam.len() as i64 == gen_opt.num_beams)
| (beam_token_rank == beam_token_rank_max_value)
{
break;
@ -1993,15 +2005,14 @@ pub(crate) mod private_generation_utils {
beam_token_rank += 1;
}
done[batch_index as usize] = done[batch_index as usize]
| hypotheses[batch_index as usize].is_done(
f64::from(next_scores.get(batch_index).max()),
current_length,
);
done[batch_index as usize] |= hypotheses[batch_index as usize].is_done(
f64::from(next_scores.get(batch_index).max()),
current_length,
);
assert_eq!(
next_sentence_beam.len() as i64,
num_beams,
gen_opt.num_beams,
"Beam incomplete"
);
next_batch_beam.append(&mut next_sentence_beam);
@ -2062,8 +2073,8 @@ pub(crate) mod private_generation_utils {
batch_index += 1;
continue;
}
for beam_index in 0..num_beams {
let effective_beam_id = batch_index * num_beams + beam_index;
for beam_index in 0..gen_opt.num_beams {
let effective_beam_id = batch_index * gen_opt.num_beams + beam_index;
let final_score = f64::from(beam_scores.get(effective_beam_id));
let final_tokens = input_ids.get(effective_beam_id);
hypotheses[batch_index as usize].add(final_tokens, final_score);
@ -2071,10 +2082,13 @@ pub(crate) mod private_generation_utils {
batch_index += 1;
}
let (output_batch_size, output_num_return_sequences_per_batch) = if do_sample {
let (output_batch_size, output_num_return_sequences_per_batch) = if gen_opt.do_sample {
(batch_size, 1)
} else {
(batch_size * num_return_sequences, num_return_sequences)
(
batch_size * gen_opt.num_return_sequences,
gen_opt.num_return_sequences,
)
};
let mut sentence_lengths =
@ -2083,7 +2097,7 @@ pub(crate) mod private_generation_utils {
for (hypothesis_index, hypothesis) in hypotheses.iter().enumerate() {
let mut sorted_hypotheses = hypothesis.clone();
&sorted_hypotheses
sorted_hypotheses
.beams
.sort_by_key(|(score, _)| OrderedFloat(*score));
for j in 0..output_num_return_sequences_per_batch {
@ -2101,12 +2115,13 @@ pub(crate) mod private_generation_utils {
let decoded = if i64::from(sentence_lengths.max()) != i64::from(sentence_lengths.min())
{
let sentence_max_length = min(i64::from(sentence_lengths.max()) + 1, max_length);
let sentence_max_length =
min(i64::from(sentence_lengths.max()) + 1, gen_opt.max_length);
let decoded: Tensor = Tensor::ones(
&[output_batch_size, sentence_max_length],
(Int64, input_ids.device()),
) * pad_token_id.unwrap();
for hypothesis_index in 0..best_ids.len() {
) * gen_opt.pad_token_id.unwrap();
for (hypothesis_index, best_id) in best_ids.iter().enumerate() {
let _ = decoded.get(hypothesis_index as i64).index_copy_(
0,
&Tensor::arange1(
@ -2114,14 +2129,14 @@ pub(crate) mod private_generation_utils {
i64::from(sentence_lengths.get(hypothesis_index as i64)),
(Int64, input_ids.device()),
),
&best_ids[hypothesis_index],
&best_id,
);
let sentence_length = i64::from(sentence_lengths.get(hypothesis_index as i64));
if sentence_length < max_length {
if sentence_length < gen_opt.max_length {
let _ = decoded.get(hypothesis_index as i64).index_fill_(
0,
&Tensor::of_slice(&[sentence_length]).to_device(input_ids.device()),
eos_token_ids.as_ref().unwrap()[0],
gen_opt.eos_token_ids.as_ref().unwrap()[0],
);
}
}
@ -2231,7 +2246,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
let config = PrivateLanguageGenerator::get_config(self);
let max_length = config.max_length;
let encoding_max_len = if self.is_encoder_decoder() {
1024u64
1024i64
} else {
max_length
};
@ -2377,49 +2392,45 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
(input_ids, attention_mask)
};
let gen_opt = GenerateOptions {
min_length,
max_length,
do_sample,
temperature,
top_k,
top_p,
repetition_penalty,
no_repeat_ngram_size,
pad_token_id,
eos_token_ids,
num_return_sequences,
early_stopping,
num_beams,
length_penalty,
};
let decoded = no_grad(|| {
if num_beams > 1 {
self.generate_beam_search(
input_ids,
encoder_outputs,
cur_len,
min_length as i64,
max_length as i64,
do_sample,
early_stopping,
temperature,
top_k as i64,
top_p,
repetition_penalty,
no_repeat_ngram_size as i64,
pad_token_id,
eos_token_ids,
effective_batch_size,
num_return_sequences as i64,
length_penalty,
num_beams as i64,
attention_mask,
gen_opt,
)
} else {
self.generate_no_beam_search(
input_ids,
encoder_outputs,
cur_len,
min_length as i64,
max_length as i64,
do_sample,
temperature,
top_k as i64,
top_p,
repetition_penalty,
no_repeat_ngram_size as i64,
pad_token_id,
eos_token_ids,
effective_batch_size,
attention_mask,
gen_opt,
)
}
});
let num_sequences = *decoded.size().first().unwrap();
let mut output_ids = Vec::with_capacity(num_sequences as usize);
for sequence_index in 0..num_sequences {
@ -2575,7 +2586,7 @@ pub trait LMHeadModel {
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
/// .expand(&[batch_size, sequence_length], true);
///
/// let (output, encoder_output, past, hidden_states, attentions) = no_grad(|| {
/// let model_output = no_grad(|| {
/// gpt2_model
/// .forward_t(
/// &Some(input_tensor),
@ -2602,14 +2613,19 @@ pub trait LMHeadModel {
encoder_outputs: Option<&Tensor>,
decoder_input_ids: &Option<Tensor>,
train: bool,
) -> Result<
(
Tensor,
Option<Tensor>,
Cache,
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
),
&'static str,
>;
) -> Result<LMModelOutput, RustBertError>;
}
/// Container holding a language model output for generation tasks
pub struct LMModelOutput {
    /// Logits for each vocabulary item at each sequence position
pub lm_logits: Tensor,
/// Encoder hidden state (re-used for encoder/decoder architectures)
pub encoder_hidden_state: Option<Tensor>,
    /// Cached state for improved efficiency during decoding
pub cache: Cache,
/// Hidden states for all intermediate model layers
pub all_hidden_states: Option<Vec<Tensor>>,
/// Attention weights for all intermediate model layers
pub all_attentions: Option<Vec<Tensor>>,
}
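Editorial note (not part of this commit): a minimal sketch of how downstream generation code can read the new struct instead of indexing a five-element tuple. The helper name is hypothetical; the field names are the ones defined above.

use tch::Tensor;

// Hypothetical helper: logits for the last generated position.
// `lm_logits` has shape (batch size, sequence length, vocabulary size).
fn last_position_logits(output: &LMModelOutput) -> Tensor {
    output.lm_logits.select(1, -1)
}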

View File

@ -160,7 +160,7 @@ impl QaExample {
}
if !current_word.is_empty() {
doc_tokens.push(current_word.clone());
doc_tokens.push(current_word);
}
(doc_tokens, char_to_word_offset)
}
@ -356,21 +356,21 @@ impl QuestionAnsweringOption {
match *self {
Self::Bert(ref model) => {
let outputs = model.forward_t(input_ids, mask, None, None, input_embeds, train);
(outputs.0, outputs.1)
(outputs.start_logits, outputs.end_logits)
}
Self::DistilBert(ref model) => {
let outputs = model
.forward_t(input_ids, mask, input_embeds, train)
.expect("Error in distilbert forward_t");
(outputs.0, outputs.1)
(outputs.start_logits, outputs.end_logits)
}
Self::Roberta(ref model) | Self::XLMRoberta(ref model) => {
let outputs = model.forward_t(input_ids, mask, None, None, input_embeds, train);
(outputs.0, outputs.1)
(outputs.start_logits, outputs.end_logits)
}
Self::Albert(ref model) => {
let outputs = model.forward_t(input_ids, mask, None, None, input_embeds, train);
(outputs.0, outputs.1)
(outputs.start_logits, outputs.end_logits)
}
}
}
@ -437,12 +437,9 @@ impl QuestionAnsweringModel {
let mut var_store = VarStore::new(device);
let mut model_config =
ConfigOption::from_file(question_answering_config.model_type, config_path);
match model_config {
// The config for the current pre-trained question answering model indicates position embeddings which does not seem accurate
ConfigOption::DistilBert(ref mut config) => {
config.sinusoidal_pos_embds = false;
}
_ => (),
if let ConfigOption::DistilBert(ref mut config) = model_config {
config.sinusoidal_pos_embds = false;
};
let qa_model = QuestionAnsweringOption::new(
@ -605,7 +602,7 @@ impl QuestionAnsweringModel {
feature_id_start = max_feature_id;
let example_answers = example_top_k_answers_map
.entry(example_id)
.or_insert(vec![]);
.or_insert_with(Vec::new);
example_answers.extend(answers);
}
});
@ -684,7 +681,7 @@ impl QuestionAnsweringModel {
vec![],
None,
)
.0
.token_ids
.len()
+ 1
}
@ -700,7 +697,7 @@ impl QuestionAnsweringModel {
vec![],
None,
)
.0
.token_ids
.len(),
};
@ -716,7 +713,7 @@ impl QuestionAnsweringModel {
vec![],
Some(vec![]),
)
.0
.token_ids
.len();
let mut spans: Vec<QaFeature> = vec![];
@ -792,8 +789,8 @@ impl QuestionAnsweringModel {
fn encode_qa_pair(
&self,
truncated_query: &Vec<i64>,
spans_token_ids: &Vec<i64>,
truncated_query: &[i64],
spans_token_ids: &[i64],
max_seq_length: usize,
doc_stride: usize,
sequence_pair_added_tokens: usize,
@ -809,8 +806,8 @@ impl QuestionAnsweringModel {
let (truncated_query, truncated_context, _, _, _, _, _, _, overflowing_tokens, _) =
truncate_sequences(
truncated_query.clone(),
Some(spans_token_ids.clone()),
truncated_query.into(),
Some(spans_token_ids.into()),
vec![],
None,
vec![],
@ -823,14 +820,7 @@ impl QuestionAnsweringModel {
)
.unwrap();
let (
mut token_ids,
mut segment_ids,
special_tokens_mask,
mut token_offsets,
mut reference_offsets,
mut mask,
) = self.tokenizer.build_input_with_special_tokens(
let mut tokenized_input = self.tokenizer.build_input_with_special_tokens(
truncated_query,
truncated_context,
vec![],
@ -840,25 +830,43 @@ impl QuestionAnsweringModel {
vec![],
None,
);
let mut attention_mask = vec![1; token_ids.len()];
if token_ids.len() < max_seq_length {
token_ids.append(&mut vec![self.pad_idx; max_seq_length - token_ids.len()]);
segment_ids.append(&mut vec![0; max_seq_length - segment_ids.len()]);
let mut attention_mask = vec![1; tokenized_input.token_ids.len()];
if tokenized_input.token_ids.len() < max_seq_length {
tokenized_input.token_ids.append(&mut vec![
self.pad_idx;
max_seq_length
- tokenized_input.token_ids.len()
]);
tokenized_input.segment_ids.append(&mut vec![
0;
max_seq_length
- tokenized_input.segment_ids.len()
]);
attention_mask.append(&mut vec![0; max_seq_length - attention_mask.len()]);
token_offsets.append(&mut vec![None; max_seq_length - token_offsets.len()]);
reference_offsets.append(&mut vec![vec!(); max_seq_length - token_offsets.len()]);
mask.append(&mut vec![Mask::Special; max_seq_length - mask.len()]);
tokenized_input.token_offsets.append(&mut vec![
None;
max_seq_length
- tokenized_input
.token_offsets
.len()
]);
tokenized_input.reference_offsets.append(&mut vec![
vec!();
max_seq_length
- tokenized_input
.token_offsets
.len()
]);
tokenized_input.mask.append(&mut vec![
Mask::Special;
max_seq_length - tokenized_input.mask.len()
]);
}
(
TokenizedInput {
token_ids,
segment_ids,
special_tokens_mask,
overflowing_tokens,
num_truncated_tokens,
token_offsets,
reference_offsets,
mask,
..tokenized_input
},
attention_mask,
)
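Editorial aside (not from the diff): the `..tokenized_input` functional update syntax used above overrides only the fields listed explicitly and moves the remaining ones from the source value. A self-contained illustration with hypothetical types:

struct Example { a: i64, b: i64, c: i64 }

let base = Example { a: 1, b: 2, c: 3 };
let updated = Example { a: 10, ..base }; // b and c carried over unchanged
assert_eq!((updated.a, updated.b, updated.c), (10, 2, 3));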

View File

@ -301,7 +301,7 @@ impl SequenceClassificationOption {
None,
train,
)
.0
.decoder_output
}
Self::Bert(ref model) => {
model
@ -313,13 +313,13 @@ impl SequenceClassificationOption {
input_embeds,
train,
)
.0
.logits
}
Self::DistilBert(ref model) => {
model
.forward_t(input_ids, mask, input_embeds, train)
.expect("Error in distilbert forward_t")
.0
.logits
}
Self::Roberta(ref model) | Self::XLMRoberta(ref model) => {
model
@ -331,7 +331,7 @@ impl SequenceClassificationOption {
input_embeds,
train,
)
.0
.logits
}
Self::Albert(ref model) => {
model
@ -343,7 +343,7 @@ impl SequenceClassificationOption {
input_embeds,
train,
)
.0
.logits
}
}
}
@ -569,7 +569,7 @@ impl SequenceClassificationModel {
};
sequence_labels.push(label);
}
if sequence_labels.len() > 0 {
if !sequence_labels.is_empty() {
labels.push(sequence_labels);
}
Ok(labels)

View File

@ -83,19 +83,19 @@ pub struct SummarizationConfig {
/// Merges resource (default: pretrained BART model on CNN-DM)
pub merges_resource: Resource,
/// Minimum sequence length (default: 0)
pub min_length: u64,
pub min_length: i64,
/// Maximum sequence length (default: 20)
pub max_length: u64,
pub max_length: i64,
/// Sampling flag. If true, will perform top-k and/or nucleus sampling on generated tokens, otherwise greedy (deterministic) decoding (default: true)
pub do_sample: bool,
/// Early stopping flag indicating if the beam search should stop as soon as `num_beam` hypotheses have been generated (default: false)
pub early_stopping: bool,
/// Number of beams for beam search (default: 5)
pub num_beams: u64,
pub num_beams: i64,
/// Temperature setting. Values higher than 1 will improve originality at the risk of reducing relevance (default: 1.0)
pub temperature: f64,
/// Top_k values for sampling tokens. Value higher than 0 will enable the feature (default: 0)
pub top_k: u64,
pub top_k: i64,
/// Top_p value for [Nucleus sampling, Holtzman et al.](http://arxiv.org/abs/1904.09751). Keep top tokens until cumulative probability reaches top_p (default: 0.9)
pub top_p: f64,
/// Repetition penalty (mostly useful for CTRL decoders). Values higher than 1 will penalize tokens that have been already generated. (default: 1.0)
@ -103,9 +103,9 @@ pub struct SummarizationConfig {
/// Exponential penalty based on the length of the hypotheses generated (default: 1.0)
pub length_penalty: f64,
    /// Size of n-grams that cannot be repeated in the generated output. Values higher than 0 turn on this feature (default: 3)
pub no_repeat_ngram_size: u64,
pub no_repeat_ngram_size: i64,
/// Number of sequences to return for each prompt text (default: 1)
pub num_return_sequences: u64,
pub num_return_sequences: i64,
/// Device to place the model on (default: CUDA/GPU when available)
pub device: Device,
}
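Editorial sketch (not part of this commit) of setting the now-signed generation parameters; it assumes the crate's existing `Default` implementation for `SummarizationConfig`:

let config = SummarizationConfig {
    min_length: 30,
    max_length: 100,
    num_beams: 1,
    do_sample: true,
    top_k: 50,  // keep only the 50 highest-probability tokens at each step
    top_p: 0.9, // nucleus sampling threshold
    ..Default::default()
};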

View File

@ -408,13 +408,13 @@ impl TokenClassificationOption {
input_embeds,
train,
)
.0
.logits
}
Self::DistilBert(ref model) => {
model
.forward_t(input_ids, mask, input_embeds, train)
.expect("Error in distilbert forward_t")
.0
.logits
}
Self::Roberta(ref model) | Self::XLMRoberta(ref model) => {
model
@ -426,7 +426,7 @@ impl TokenClassificationOption {
input_embeds,
train,
)
.0
.logits
}
Self::Electra(ref model) => {
model
@ -438,7 +438,7 @@ impl TokenClassificationOption {
input_embeds,
train,
)
.0
.logits
}
Self::Albert(ref model) => {
model
@ -450,7 +450,7 @@ impl TokenClassificationOption {
input_embeds,
train,
)
.0
.logits
}
}
}
@ -632,7 +632,7 @@ impl TokenClassificationModel {
fn decode_token(
&self,
original_sentence_chars: &Vec<char>,
original_sentence_chars: &[char],
sentence_tokens: &TokenizedInput,
input_tensor: &Tensor,
labels: &Tensor,
@ -700,16 +700,16 @@ impl TokenClassificationModel {
label_aggregation_function: &LabelAggregationOption,
) {
let mut tokens_to_replace = vec![];
let mut token_iter = tokens.iter_consolidate_tokens();
let token_iter = tokens.iter_consolidate_tokens();
let mut cursor = 0;
while let Some(sub_tokens) = token_iter.next() {
for sub_tokens in token_iter {
if sub_tokens.len() > 1 {
let (label_index, label) =
self.consolidate_labels(sub_tokens, label_aggregation_function);
let sentence = (&sub_tokens[0]).sentence;
let index = (&sub_tokens[0]).index;
let word_index = (&sub_tokens[0]).word_index;
let sentence = (sub_tokens[0]).sentence;
let index = (sub_tokens[0]).index;
let word_index = (sub_tokens[0]).word_index;
let offset_start = match &sub_tokens.first().unwrap().offset {
Some(offset) => Some(offset.begin),
None => None,
@ -718,14 +718,15 @@ impl TokenClassificationModel {
Some(offset) => Some(offset.end),
None => None,
};
let offset = if offset_start.is_some() & offset_end.is_some() {
Some(Offset::new(offset_start.unwrap(), offset_end.unwrap()))
} else {
None
};
let offset =
if let (Some(offset_start), Some(offset_end)) = (offset_start, offset_end) {
Some(Offset::new(offset_start, offset_end))
} else {
None
};
let mut text = String::new();
let mut score = 1f64;
for current_sub_token in sub_tokens.into_iter() {
for current_sub_token in sub_tokens.iter() {
text.push_str(current_sub_token.text.as_str());
score *= if current_sub_token.label_index == label_index {
current_sub_token.score

View File

@ -91,311 +91,176 @@ pub enum Language {
GermanToFrench,
}
struct RemoteTranslationResources;
struct RemoteTranslationResources {
model_resource: (&'static str, &'static str),
config_resource: (&'static str, &'static str),
vocab_resource: (&'static str, &'static str),
merges_resource: (&'static str, &'static str),
prefix: Option<&'static str>,
model_type: ModelType,
}
impl RemoteTranslationResources {
pub const ENGLISH2FRENCH: (
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
ModelType,
) = (
MarianModelResources::ENGLISH2ROMANCE,
MarianConfigResources::ENGLISH2ROMANCE,
MarianVocabResources::ENGLISH2ROMANCE,
MarianSpmResources::ENGLISH2ROMANCE,
MarianPrefix::ENGLISH2FRENCH,
ModelType::Marian,
);
pub const ENGLISH2FRENCH_V2: (
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
ModelType,
) = (
T5ModelResources::T5_BASE,
T5ConfigResources::T5_BASE,
T5VocabResources::T5_BASE,
T5VocabResources::T5_BASE,
T5Prefix::ENGLISH2FRENCH,
ModelType::T5,
);
pub const ENGLISH2GERMAN_V2: (
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
ModelType,
) = (
T5ModelResources::T5_BASE,
T5ConfigResources::T5_BASE,
T5VocabResources::T5_BASE,
T5VocabResources::T5_BASE,
T5Prefix::ENGLISH2GERMAN,
ModelType::T5,
);
pub const ENGLISH2CATALAN: (
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
ModelType,
) = (
MarianModelResources::ENGLISH2ROMANCE,
MarianConfigResources::ENGLISH2ROMANCE,
MarianVocabResources::ENGLISH2ROMANCE,
MarianSpmResources::ENGLISH2ROMANCE,
MarianPrefix::ENGLISH2CATALAN,
ModelType::Marian,
);
pub const ENGLISH2SPANISH: (
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
ModelType,
) = (
MarianModelResources::ENGLISH2ROMANCE,
MarianConfigResources::ENGLISH2ROMANCE,
MarianVocabResources::ENGLISH2ROMANCE,
MarianSpmResources::ENGLISH2ROMANCE,
MarianPrefix::ENGLISH2SPANISH,
ModelType::Marian,
);
pub const ENGLISH2PORTUGUESE: (
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
ModelType,
) = (
MarianModelResources::ENGLISH2ROMANCE,
MarianConfigResources::ENGLISH2ROMANCE,
MarianVocabResources::ENGLISH2ROMANCE,
MarianSpmResources::ENGLISH2ROMANCE,
MarianPrefix::ENGLISH2PORTUGUESE,
ModelType::Marian,
);
pub const ENGLISH2ITALIAN: (
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
ModelType,
) = (
MarianModelResources::ENGLISH2ROMANCE,
MarianConfigResources::ENGLISH2ROMANCE,
MarianVocabResources::ENGLISH2ROMANCE,
MarianSpmResources::ENGLISH2ROMANCE,
MarianPrefix::ENGLISH2ITALIAN,
ModelType::Marian,
);
pub const ENGLISH2ROMANIAN: (
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
ModelType,
) = (
MarianModelResources::ENGLISH2ROMANCE,
MarianConfigResources::ENGLISH2ROMANCE,
MarianVocabResources::ENGLISH2ROMANCE,
MarianSpmResources::ENGLISH2ROMANCE,
MarianPrefix::ENGLISH2ROMANIAN,
ModelType::Marian,
);
pub const ENGLISH2GERMAN: (
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
ModelType,
) = (
MarianModelResources::ENGLISH2GERMAN,
MarianConfigResources::ENGLISH2GERMAN,
MarianVocabResources::ENGLISH2GERMAN,
MarianSpmResources::ENGLISH2GERMAN,
MarianPrefix::ENGLISH2GERMAN,
ModelType::Marian,
);
pub const ENGLISH2RUSSIAN: (
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
ModelType,
) = (
MarianModelResources::ENGLISH2RUSSIAN,
MarianConfigResources::ENGLISH2RUSSIAN,
MarianVocabResources::ENGLISH2RUSSIAN,
MarianSpmResources::ENGLISH2RUSSIAN,
MarianPrefix::ENGLISH2RUSSIAN,
ModelType::Marian,
);
pub const FRENCH2ENGLISH: (
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
ModelType,
) = (
MarianModelResources::ROMANCE2ENGLISH,
MarianConfigResources::ROMANCE2ENGLISH,
MarianVocabResources::ROMANCE2ENGLISH,
MarianSpmResources::ROMANCE2ENGLISH,
MarianPrefix::FRENCH2ENGLISH,
ModelType::Marian,
);
pub const CATALAN2ENGLISH: (
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
ModelType,
) = (
MarianModelResources::ROMANCE2ENGLISH,
MarianConfigResources::ROMANCE2ENGLISH,
MarianVocabResources::ROMANCE2ENGLISH,
MarianSpmResources::ROMANCE2ENGLISH,
MarianPrefix::CATALAN2ENGLISH,
ModelType::Marian,
);
pub const SPANISH2ENGLISH: (
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
ModelType,
) = (
MarianModelResources::ROMANCE2ENGLISH,
MarianConfigResources::ROMANCE2ENGLISH,
MarianVocabResources::ROMANCE2ENGLISH,
MarianSpmResources::ROMANCE2ENGLISH,
MarianPrefix::SPANISH2ENGLISH,
ModelType::Marian,
);
pub const PORTUGUESE2ENGLISH: (
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
ModelType,
) = (
MarianModelResources::ROMANCE2ENGLISH,
MarianConfigResources::ROMANCE2ENGLISH,
MarianVocabResources::ROMANCE2ENGLISH,
MarianSpmResources::ROMANCE2ENGLISH,
MarianPrefix::PORTUGUESE2ENGLISH,
ModelType::Marian,
);
pub const ITALIAN2ENGLISH: (
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
ModelType,
) = (
MarianModelResources::ROMANCE2ENGLISH,
MarianConfigResources::ROMANCE2ENGLISH,
MarianVocabResources::ROMANCE2ENGLISH,
MarianSpmResources::ROMANCE2ENGLISH,
MarianPrefix::ITALIAN2ENGLISH,
ModelType::Marian,
);
pub const ROMANIAN2ENGLISH: (
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
ModelType,
) = (
MarianModelResources::ROMANCE2ENGLISH,
MarianConfigResources::ROMANCE2ENGLISH,
MarianVocabResources::ROMANCE2ENGLISH,
MarianSpmResources::ROMANCE2ENGLISH,
MarianPrefix::ROMANIAN2ENGLISH,
ModelType::Marian,
);
pub const GERMAN2ENGLISH: (
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
ModelType,
) = (
MarianModelResources::GERMAN2ENGLISH,
MarianConfigResources::GERMAN2ENGLISH,
MarianVocabResources::GERMAN2ENGLISH,
MarianSpmResources::GERMAN2ENGLISH,
MarianPrefix::GERMAN2ENGLISH,
ModelType::Marian,
);
pub const RUSSIAN2ENGLISH: (
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
ModelType,
) = (
MarianModelResources::RUSSIAN2ENGLISH,
MarianConfigResources::RUSSIAN2ENGLISH,
MarianVocabResources::RUSSIAN2ENGLISH,
MarianSpmResources::RUSSIAN2ENGLISH,
MarianPrefix::RUSSIAN2ENGLISH,
ModelType::Marian,
);
pub const FRENCH2GERMAN: (
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
ModelType,
) = (
MarianModelResources::FRENCH2GERMAN,
MarianConfigResources::FRENCH2GERMAN,
MarianVocabResources::FRENCH2GERMAN,
MarianSpmResources::FRENCH2GERMAN,
MarianPrefix::FRENCH2GERMAN,
ModelType::Marian,
);
pub const GERMAN2FRENCH: (
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
ModelType,
) = (
MarianModelResources::GERMAN2FRENCH,
MarianConfigResources::GERMAN2FRENCH,
MarianVocabResources::GERMAN2FRENCH,
MarianSpmResources::GERMAN2FRENCH,
MarianPrefix::GERMAN2FRENCH,
ModelType::Marian,
);
pub const ENGLISH2FRENCH: RemoteTranslationResources = Self {
model_resource: MarianModelResources::ENGLISH2ROMANCE,
config_resource: MarianConfigResources::ENGLISH2ROMANCE,
vocab_resource: MarianVocabResources::ENGLISH2ROMANCE,
merges_resource: MarianSpmResources::ENGLISH2ROMANCE,
prefix: MarianPrefix::ENGLISH2FRENCH,
model_type: ModelType::Marian,
};
pub const ENGLISH2FRENCH_V2: RemoteTranslationResources = Self {
model_resource: T5ModelResources::T5_BASE,
config_resource: T5ConfigResources::T5_BASE,
vocab_resource: T5VocabResources::T5_BASE,
merges_resource: T5VocabResources::T5_BASE,
prefix: T5Prefix::ENGLISH2FRENCH,
model_type: ModelType::T5,
};
pub const ENGLISH2GERMAN_V2: RemoteTranslationResources = Self {
model_resource: T5ModelResources::T5_BASE,
config_resource: T5ConfigResources::T5_BASE,
vocab_resource: T5VocabResources::T5_BASE,
merges_resource: T5VocabResources::T5_BASE,
prefix: T5Prefix::ENGLISH2GERMAN,
model_type: ModelType::T5,
};
pub const ENGLISH2CATALAN: RemoteTranslationResources = Self {
model_resource: MarianModelResources::ENGLISH2ROMANCE,
config_resource: MarianConfigResources::ENGLISH2ROMANCE,
vocab_resource: MarianVocabResources::ENGLISH2ROMANCE,
merges_resource: MarianSpmResources::ENGLISH2ROMANCE,
prefix: MarianPrefix::ENGLISH2CATALAN,
model_type: ModelType::Marian,
};
pub const ENGLISH2SPANISH: RemoteTranslationResources = Self {
model_resource: MarianModelResources::ENGLISH2ROMANCE,
config_resource: MarianConfigResources::ENGLISH2ROMANCE,
vocab_resource: MarianVocabResources::ENGLISH2ROMANCE,
merges_resource: MarianSpmResources::ENGLISH2ROMANCE,
prefix: MarianPrefix::ENGLISH2SPANISH,
model_type: ModelType::Marian,
};
pub const ENGLISH2PORTUGUESE: RemoteTranslationResources = Self {
model_resource: MarianModelResources::ENGLISH2ROMANCE,
config_resource: MarianConfigResources::ENGLISH2ROMANCE,
vocab_resource: MarianVocabResources::ENGLISH2ROMANCE,
merges_resource: MarianSpmResources::ENGLISH2ROMANCE,
prefix: MarianPrefix::ENGLISH2PORTUGUESE,
model_type: ModelType::Marian,
};
pub const ENGLISH2ITALIAN: RemoteTranslationResources = Self {
model_resource: MarianModelResources::ENGLISH2ROMANCE,
config_resource: MarianConfigResources::ENGLISH2ROMANCE,
vocab_resource: MarianVocabResources::ENGLISH2ROMANCE,
merges_resource: MarianSpmResources::ENGLISH2ROMANCE,
prefix: MarianPrefix::ENGLISH2ITALIAN,
model_type: ModelType::Marian,
};
pub const ENGLISH2ROMANIAN: RemoteTranslationResources = Self {
model_resource: MarianModelResources::ENGLISH2ROMANCE,
config_resource: MarianConfigResources::ENGLISH2ROMANCE,
vocab_resource: MarianVocabResources::ENGLISH2ROMANCE,
merges_resource: MarianSpmResources::ENGLISH2ROMANCE,
prefix: MarianPrefix::ENGLISH2ROMANIAN,
model_type: ModelType::Marian,
};
pub const ENGLISH2GERMAN: RemoteTranslationResources = Self {
model_resource: MarianModelResources::ENGLISH2GERMAN,
config_resource: MarianConfigResources::ENGLISH2GERMAN,
vocab_resource: MarianVocabResources::ENGLISH2GERMAN,
merges_resource: MarianSpmResources::ENGLISH2GERMAN,
prefix: MarianPrefix::ENGLISH2GERMAN,
model_type: ModelType::Marian,
};
pub const ENGLISH2RUSSIAN: RemoteTranslationResources = Self {
model_resource: MarianModelResources::ENGLISH2RUSSIAN,
config_resource: MarianConfigResources::ENGLISH2RUSSIAN,
vocab_resource: MarianVocabResources::ENGLISH2RUSSIAN,
merges_resource: MarianSpmResources::ENGLISH2RUSSIAN,
prefix: MarianPrefix::ENGLISH2RUSSIAN,
model_type: ModelType::Marian,
};
pub const FRENCH2ENGLISH: RemoteTranslationResources = Self {
model_resource: MarianModelResources::ROMANCE2ENGLISH,
config_resource: MarianConfigResources::ROMANCE2ENGLISH,
vocab_resource: MarianVocabResources::ROMANCE2ENGLISH,
merges_resource: MarianSpmResources::ROMANCE2ENGLISH,
prefix: MarianPrefix::FRENCH2ENGLISH,
model_type: ModelType::Marian,
};
pub const CATALAN2ENGLISH: RemoteTranslationResources = Self {
model_resource: MarianModelResources::ROMANCE2ENGLISH,
config_resource: MarianConfigResources::ROMANCE2ENGLISH,
vocab_resource: MarianVocabResources::ROMANCE2ENGLISH,
merges_resource: MarianSpmResources::ROMANCE2ENGLISH,
prefix: MarianPrefix::CATALAN2ENGLISH,
model_type: ModelType::Marian,
};
pub const SPANISH2ENGLISH: RemoteTranslationResources = Self {
model_resource: MarianModelResources::ROMANCE2ENGLISH,
config_resource: MarianConfigResources::ROMANCE2ENGLISH,
vocab_resource: MarianVocabResources::ROMANCE2ENGLISH,
merges_resource: MarianSpmResources::ROMANCE2ENGLISH,
prefix: MarianPrefix::SPANISH2ENGLISH,
model_type: ModelType::Marian,
};
pub const PORTUGUESE2ENGLISH: RemoteTranslationResources = Self {
model_resource: MarianModelResources::ROMANCE2ENGLISH,
config_resource: MarianConfigResources::ROMANCE2ENGLISH,
vocab_resource: MarianVocabResources::ROMANCE2ENGLISH,
merges_resource: MarianSpmResources::ROMANCE2ENGLISH,
prefix: MarianPrefix::PORTUGUESE2ENGLISH,
model_type: ModelType::Marian,
};
pub const ITALIAN2ENGLISH: RemoteTranslationResources = Self {
model_resource: MarianModelResources::ROMANCE2ENGLISH,
config_resource: MarianConfigResources::ROMANCE2ENGLISH,
vocab_resource: MarianVocabResources::ROMANCE2ENGLISH,
merges_resource: MarianSpmResources::ROMANCE2ENGLISH,
prefix: MarianPrefix::ITALIAN2ENGLISH,
model_type: ModelType::Marian,
};
pub const ROMANIAN2ENGLISH: RemoteTranslationResources = Self {
model_resource: MarianModelResources::ROMANCE2ENGLISH,
config_resource: MarianConfigResources::ROMANCE2ENGLISH,
vocab_resource: MarianVocabResources::ROMANCE2ENGLISH,
merges_resource: MarianSpmResources::ROMANCE2ENGLISH,
prefix: MarianPrefix::ROMANIAN2ENGLISH,
model_type: ModelType::Marian,
};
pub const GERMAN2ENGLISH: RemoteTranslationResources = Self {
model_resource: MarianModelResources::GERMAN2ENGLISH,
config_resource: MarianConfigResources::GERMAN2ENGLISH,
vocab_resource: MarianVocabResources::GERMAN2ENGLISH,
merges_resource: MarianSpmResources::GERMAN2ENGLISH,
prefix: MarianPrefix::GERMAN2ENGLISH,
model_type: ModelType::Marian,
};
pub const RUSSIAN2ENGLISH: RemoteTranslationResources = Self {
model_resource: MarianModelResources::RUSSIAN2ENGLISH,
config_resource: MarianConfigResources::RUSSIAN2ENGLISH,
vocab_resource: MarianVocabResources::RUSSIAN2ENGLISH,
merges_resource: MarianSpmResources::RUSSIAN2ENGLISH,
prefix: MarianPrefix::RUSSIAN2ENGLISH,
model_type: ModelType::Marian,
};
pub const FRENCH2GERMAN: RemoteTranslationResources = Self {
model_resource: MarianModelResources::FRENCH2GERMAN,
config_resource: MarianConfigResources::FRENCH2GERMAN,
vocab_resource: MarianVocabResources::FRENCH2GERMAN,
merges_resource: MarianSpmResources::FRENCH2GERMAN,
prefix: MarianPrefix::FRENCH2GERMAN,
model_type: ModelType::Marian,
};
pub const GERMAN2FRENCH: RemoteTranslationResources = Self {
model_resource: MarianModelResources::GERMAN2FRENCH,
config_resource: MarianConfigResources::GERMAN2FRENCH,
vocab_resource: MarianVocabResources::GERMAN2FRENCH,
merges_resource: MarianSpmResources::GERMAN2FRENCH,
prefix: MarianPrefix::GERMAN2FRENCH,
model_type: ModelType::Marian,
};
}
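Editorial sketch (not in the diff): with named fields, call sites read the resource they need directly instead of destructuring a six-element tuple, mirroring the updated `TranslationConfig::new` below.

let resources = RemoteTranslationResources::ENGLISH2FRENCH;
let model_resource = Resource::Remote(RemoteResource::from_pretrained(resources.model_resource));
let prefix: Option<String> = resources.prefix.map(|v| v.to_string());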
/// # Configuration for text translation
@ -411,19 +276,19 @@ pub struct TranslationConfig {
/// Merges resource (default: pretrained BART model on CNN-DM)
pub merges_resource: Resource,
/// Minimum sequence length (default: 0)
pub min_length: u64,
pub min_length: i64,
/// Maximum sequence length (default: 20)
pub max_length: u64,
pub max_length: i64,
/// Sampling flag. If true, will perform top-k and/or nucleus sampling on generated tokens, otherwise greedy (deterministic) decoding (default: true)
pub do_sample: bool,
/// Early stopping flag indicating if the beam search should stop as soon as `num_beam` hypotheses have been generated (default: false)
pub early_stopping: bool,
/// Number of beams for beam search (default: 5)
pub num_beams: u64,
pub num_beams: i64,
/// Temperature setting. Values higher than 1 will improve originality at the risk of reducing relevance (default: 1.0)
pub temperature: f64,
/// Top_k values for sampling tokens. Value higher than 0 will enable the feature (default: 0)
pub top_k: u64,
pub top_k: i64,
/// Top_p value for [Nucleus sampling, Holtzman et al.](http://arxiv.org/abs/1904.09751). Keep top tokens until cumulative probability reaches top_p (default: 0.9)
pub top_p: f64,
/// Repetition penalty (mostly useful for CTRL decoders). Values higher than 1 will penalize tokens that have been already generated. (default: 1.0)
@ -431,9 +296,9 @@ pub struct TranslationConfig {
/// Exponential penalty based on the length of the hypotheses generated (default: 1.0)
pub length_penalty: f64,
    /// Size of n-grams that cannot be repeated in the generated output. Values higher than 0 turn on this feature (default: 3)
pub no_repeat_ngram_size: u64,
pub no_repeat_ngram_size: i64,
/// Number of sequences to return for each prompt text (default: 1)
pub num_return_sequences: u64,
pub num_return_sequences: i64,
/// Device to place the model on (default: CUDA/GPU when available)
pub device: Device,
/// Prefix to append translation inputs with
@ -463,37 +328,44 @@ impl TranslationConfig {
/// # }
/// ```
pub fn new(language: Language, device: Device) -> TranslationConfig {
let (model_resource, config_resource, vocab_resource, merges_resource, prefix, model_type) =
match language {
Language::EnglishToFrench => RemoteTranslationResources::ENGLISH2FRENCH,
Language::EnglishToCatalan => RemoteTranslationResources::ENGLISH2CATALAN,
Language::EnglishToSpanish => RemoteTranslationResources::ENGLISH2SPANISH,
Language::EnglishToPortuguese => RemoteTranslationResources::ENGLISH2PORTUGUESE,
Language::EnglishToItalian => RemoteTranslationResources::ENGLISH2ITALIAN,
Language::EnglishToRomanian => RemoteTranslationResources::ENGLISH2ROMANIAN,
Language::EnglishToGerman => RemoteTranslationResources::ENGLISH2GERMAN,
Language::EnglishToRussian => RemoteTranslationResources::ENGLISH2RUSSIAN,
let translation_resource = match language {
Language::EnglishToFrench => RemoteTranslationResources::ENGLISH2FRENCH,
Language::EnglishToCatalan => RemoteTranslationResources::ENGLISH2CATALAN,
Language::EnglishToSpanish => RemoteTranslationResources::ENGLISH2SPANISH,
Language::EnglishToPortuguese => RemoteTranslationResources::ENGLISH2PORTUGUESE,
Language::EnglishToItalian => RemoteTranslationResources::ENGLISH2ITALIAN,
Language::EnglishToRomanian => RemoteTranslationResources::ENGLISH2ROMANIAN,
Language::EnglishToGerman => RemoteTranslationResources::ENGLISH2GERMAN,
Language::EnglishToRussian => RemoteTranslationResources::ENGLISH2RUSSIAN,
Language::FrenchToEnglish => RemoteTranslationResources::FRENCH2ENGLISH,
Language::CatalanToEnglish => RemoteTranslationResources::CATALAN2ENGLISH,
Language::SpanishToEnglish => RemoteTranslationResources::SPANISH2ENGLISH,
Language::PortugueseToEnglish => RemoteTranslationResources::PORTUGUESE2ENGLISH,
Language::ItalianToEnglish => RemoteTranslationResources::ITALIAN2ENGLISH,
Language::RomanianToEnglish => RemoteTranslationResources::ROMANIAN2ENGLISH,
Language::GermanToEnglish => RemoteTranslationResources::GERMAN2ENGLISH,
Language::RussianToEnglish => RemoteTranslationResources::RUSSIAN2ENGLISH,
Language::FrenchToEnglish => RemoteTranslationResources::FRENCH2ENGLISH,
Language::CatalanToEnglish => RemoteTranslationResources::CATALAN2ENGLISH,
Language::SpanishToEnglish => RemoteTranslationResources::SPANISH2ENGLISH,
Language::PortugueseToEnglish => RemoteTranslationResources::PORTUGUESE2ENGLISH,
Language::ItalianToEnglish => RemoteTranslationResources::ITALIAN2ENGLISH,
Language::RomanianToEnglish => RemoteTranslationResources::ROMANIAN2ENGLISH,
Language::GermanToEnglish => RemoteTranslationResources::GERMAN2ENGLISH,
Language::RussianToEnglish => RemoteTranslationResources::RUSSIAN2ENGLISH,
Language::EnglishToFrenchV2 => RemoteTranslationResources::ENGLISH2FRENCH_V2,
Language::EnglishToGermanV2 => RemoteTranslationResources::ENGLISH2GERMAN_V2,
Language::EnglishToFrenchV2 => RemoteTranslationResources::ENGLISH2FRENCH_V2,
Language::EnglishToGermanV2 => RemoteTranslationResources::ENGLISH2GERMAN_V2,
Language::FrenchToGerman => RemoteTranslationResources::FRENCH2GERMAN,
Language::GermanToFrench => RemoteTranslationResources::GERMAN2FRENCH,
};
let model_resource = Resource::Remote(RemoteResource::from_pretrained(model_resource));
let config_resource = Resource::Remote(RemoteResource::from_pretrained(config_resource));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(vocab_resource));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(merges_resource));
let prefix = match prefix {
Language::FrenchToGerman => RemoteTranslationResources::FRENCH2GERMAN,
Language::GermanToFrench => RemoteTranslationResources::GERMAN2FRENCH,
};
let model_resource = Resource::Remote(RemoteResource::from_pretrained(
translation_resource.model_resource,
));
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
translation_resource.config_resource,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
translation_resource.vocab_resource,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
translation_resource.merges_resource,
));
let prefix = match translation_resource.prefix {
Some(value) => Some(value.to_string()),
None => None,
};
@ -516,7 +388,7 @@ impl TranslationConfig {
num_return_sequences: 1,
device,
prefix,
model_type,
model_type: translation_resource.model_type,
}
}
@ -740,10 +612,10 @@ impl TranslationModel {
pub fn translate(&self, texts: &[&str]) -> Vec<String> {
match &self.prefix {
Some(value) => {
let texts: Vec<String> = texts
.into_iter()
let texts = texts
.iter()
.map(|&v| format!("{} {}", value, v))
.collect();
.collect::<Vec<String>>();
self.model
.generate(Some(texts.iter().map(AsRef::as_ref).collect()), None)
}

View File

@ -332,7 +332,7 @@ impl ZeroShotClassificationOption {
None,
train,
)
.0
.decoder_output
}
Self::Bert(ref model) => {
model
@ -344,13 +344,13 @@ impl ZeroShotClassificationOption {
input_embeds,
train,
)
.0
.logits
}
Self::DistilBert(ref model) => {
model
.forward_t(input_ids, mask, input_embeds, train)
.expect("Error in distilbert forward_t")
.0
.logits
}
Self::Roberta(ref model) | Self::XLMRoberta(ref model) => {
model
@ -362,7 +362,7 @@ impl ZeroShotClassificationOption {
input_embeds,
train,
)
.0
.logits
}
Self::Albert(ref model) => {
model
@ -374,7 +374,7 @@ impl ZeroShotClassificationOption {
input_embeds,
train,
)
.0
.logits
}
}
}
@ -447,13 +447,13 @@ impl ZeroShotClassificationModel {
let label_sentences: Vec<String> = match template {
Some(function) => labels.iter().map(|label| function(label)).collect(),
None => labels
.into_iter()
.iter()
.map(|label| format!("This example is about {}.", label))
.collect(),
};
let text_pair_list = inputs
.into_iter()
.iter()
.cartesian_product(label_sentences.iter())
.map(|(&s, label)| (s, label.as_str()))
.collect();
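Editorial illustration (not from the commit) of the pairing above: every input is crossed with every candidate label hypothesis, so N inputs and M labels produce N × M premise/hypothesis pairs scored in a single batch.

use itertools::Itertools;

let inputs = ["The new movie was a box office success", "Parliament passed the bill"];
let label_sentences: Vec<String> = ["cinema", "politics"]
    .iter()
    .map(|label| format!("This example is about {}.", label))
    .collect();
let text_pair_list: Vec<(&str, &str)> = inputs
    .iter()
    .cartesian_product(label_sentences.iter())
    .map(|(&s, label)| (s, label.as_str()))
    .collect();
assert_eq!(text_pair_list.len(), 4); // 2 inputs x 2 candidate labels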

View File

@ -13,6 +13,7 @@
use crate::bert::{BertConfig, BertEmbedding};
use crate::common::dropout::Dropout;
use crate::RustBertError;
use std::borrow::Borrow;
use tch::nn::{embedding, EmbeddingConfig};
use tch::{nn, Kind, Tensor};
@ -174,11 +175,13 @@ impl BertEmbedding for RobertaEmbeddings {
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool,
) -> Result<Tensor, &'static str> {
) -> Result<Tensor, RustBertError> {
let (input_embeddings, input_shape) = match &input_ids {
Some(input_value) => match &input_embeds {
Some(_) => {
return Err("Only one of input ids or input embeddings may be set");
return Err(RustBertError::ValueError(
"Only one of input ids or input embeddings may be set".into(),
));
}
None => (
input_value.apply_t(&self.word_embeddings, train),
@ -188,7 +191,9 @@ impl BertEmbedding for RobertaEmbeddings {
None => match &input_embeds {
Some(embeds) => (embeds.copy(), vec![embeds.size()[0], embeds.size()[1]]),
None => {
return Err("Only one of input ids or input embeddings may be set");
return Err(RustBertError::ValueError(
"At least one of input ids or input embeddings must be set".into(),
));
}
},
};

View File

@ -10,7 +10,7 @@
//!
//! # Model set-up and pre-trained weights loading
//!
//! A full working example is provided in `examples/robert.rs`, run with `cargo run --example roberta`.
//! A full working example is provided in `examples/roberta.rs`, run with `cargo run --example roberta`.
//! The example below illustrates a masked language model; the structure is similar for other models.
//! All models expect the following resources:
//! - Configuration file expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers)
@ -63,10 +63,10 @@
//! ```
mod embeddings;
mod roberta;
mod roberta_model;
pub use embeddings::RobertaEmbeddings;
pub use roberta::{
pub use roberta_model::{
RobertaConfigResources, RobertaForMaskedLM, RobertaForMultipleChoice,
RobertaForQuestionAnswering, RobertaForSequenceClassification, RobertaForTokenClassification,
RobertaMergesResources, RobertaModelResources, RobertaVocabResources,

View File

@ -282,7 +282,7 @@ impl RobertaForMaskedLM {
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
/// .expand(&[batch_size, sequence_length], true);
///
/// let (output, all_hidden_states, all_attentions) = no_grad(|| {
/// let model_output = no_grad(|| {
/// roberta_model.forward_t(
/// Some(input_tensor),
/// Some(mask),
@ -305,8 +305,8 @@ impl RobertaForMaskedLM {
encoder_hidden_states: &Option<Tensor>,
encoder_mask: &Option<Tensor>,
train: bool,
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
let (hidden_state, _, all_hidden_states, all_attentions) = self
) -> RobertaMaskedLMOutput {
let base_model_output = self
.roberta
.forward_t(
input_ids,
@ -320,8 +320,12 @@ impl RobertaForMaskedLM {
)
.unwrap();
let prediction_scores = self.lm_head.forward(&hidden_state);
(prediction_scores, all_hidden_states, all_attentions)
let prediction_scores = self.lm_head.forward(&base_model_output.hidden_state);
RobertaMaskedLMOutput {
prediction_scores,
all_hidden_states: base_model_output.all_hidden_states,
all_attentions: base_model_output.all_attentions,
}
}
}
@ -434,9 +438,10 @@ impl RobertaForSequenceClassification {
///
/// # Returns
///
/// * `labels` - `Tensor` of shape (*batch size*, *num_labels*)
/// * `hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// * `attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// * `RobertaSequenceClassificationOutput` containing:
/// - `logits` - `Tensor` of shape (*batch size*, *num_labels*)
/// - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// - `all_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
///
/// # Example
///
@ -460,7 +465,7 @@ impl RobertaForSequenceClassification {
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
/// .expand(&[batch_size, sequence_length], true);
///
/// let (labels, all_hidden_states, all_attentions) = no_grad(|| {
/// let model_output = no_grad(|| {
/// roberta_model.forward_t(
/// Some(input_tensor),
/// Some(mask),
@ -479,8 +484,8 @@ impl RobertaForSequenceClassification {
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool,
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
let (hidden_state, _, all_hidden_states, all_attentions) = self
) -> RobertaSequenceClassificationOutput {
let base_model_output = self
.roberta
.forward_t(
input_ids,
@ -494,8 +499,14 @@ impl RobertaForSequenceClassification {
)
.unwrap();
let output = self.classifier.forward_t(&hidden_state, train);
(output, all_hidden_states, all_attentions)
let logits = self
.classifier
.forward_t(&base_model_output.hidden_state, train);
RobertaSequenceClassificationOutput {
logits,
all_hidden_states: base_model_output.all_hidden_states,
all_attentions: base_model_output.all_attentions,
}
}
}
@ -563,9 +574,10 @@ impl RobertaForMultipleChoice {
///
/// # Returns
///
/// * `output` - `Tensor` of shape (*1*, *batch size*) containing the logits for each of the alternatives given
/// * `hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// * `attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// * `RobertaSequenceClassificationOutput` containing:
/// - `logits` - `Tensor` of shape (*1*, *batch size*) containing the logits for each of the alternatives given
/// - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// - `all_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
///
/// # Example
///
@ -589,7 +601,7 @@ impl RobertaForMultipleChoice {
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
/// .expand(&[num_choices, sequence_length], true);
///
/// let (choices, all_hidden_states, all_attentions) = no_grad(|| {
/// let model_output = no_grad(|| {
/// roberta_model.forward_t(
/// input_tensor,
/// Some(mask),
@ -606,7 +618,7 @@ impl RobertaForMultipleChoice {
token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>,
train: bool,
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
) -> RobertaSequenceClassificationOutput {
let num_choices = input_ids.size()[1];
let flat_input_ids = Some(input_ids.view((-1i64, *input_ids.size().last().unwrap())));
@ -623,7 +635,7 @@ impl RobertaForMultipleChoice {
None => None,
};
let (_, pooled_output, all_hidden_states, all_attentions) = self
let base_model_output = self
.roberta
.forward_t(
flat_input_ids,
@ -637,11 +649,16 @@ impl RobertaForMultipleChoice {
)
.unwrap();
let output = pooled_output
let logits = base_model_output
.pooled_output
.apply_t(&self.dropout, train)
.apply(&self.classifier)
.view((-1, num_choices));
(output, all_hidden_states, all_attentions)
RobertaSequenceClassificationOutput {
logits,
all_hidden_states: base_model_output.all_hidden_states,
all_attentions: base_model_output.all_attentions,
}
}
}
@ -719,9 +736,10 @@ impl RobertaForTokenClassification {
///
/// # Returns
///
/// * `output` - `Tensor` of shape (*batch size*, *sequence_length*, *num_labels*) containing the logits for each of the input tokens and classes
/// * `hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// * `attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// * `RobertaTokenClassificationOutput` containing:
/// - `logits` - `Tensor` of shape (*batch size*, *sequence_length*, *num_labels*) containing the logits for each of the input tokens and classes
/// - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// - `all_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
///
/// # Example
///
@ -745,7 +763,7 @@ impl RobertaForTokenClassification {
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
/// .expand(&[batch_size, sequence_length], true);
///
/// let (token_labels, all_hidden_states, all_attentions) = no_grad(|| {
/// let model_output = no_grad(|| {
/// roberta_model.forward_t(
/// Some(input_tensor),
/// Some(mask),
@ -764,8 +782,8 @@ impl RobertaForTokenClassification {
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool,
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
let (hidden_state, _, all_hidden_states, all_attentions) = self
) -> RobertaTokenClassificationOutput {
let base_model_output = self
.roberta
.forward_t(
input_ids,
@ -779,10 +797,16 @@ impl RobertaForTokenClassification {
)
.unwrap();
let sequence_output = hidden_state
let logits = base_model_output
.hidden_state
.apply_t(&self.dropout, train)
.apply(&self.classifier);
(sequence_output, all_hidden_states, all_attentions)
RobertaTokenClassificationOutput {
logits,
all_hidden_states: base_model_output.all_hidden_states,
all_attentions: base_model_output.all_attentions,
}
}
}
@ -854,10 +878,11 @@ impl RobertaForQuestionAnswering {
///
/// # Returns
///
/// * `start_scores` - `Tensor` of shape (*batch size*, *sequence_length*) containing the logits for start of the answer
/// * `end_scores` - `Tensor` of shape (*batch size*, *sequence_length*) containing the logits for end of the answer
/// * `hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// * `attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// * `RobertaQuestionAnsweringOutput` containing:
    /// - `start_logits` - `Tensor` of shape (*batch size*, *sequence_length*) containing the logits for the start of the answer
    /// - `end_logits` - `Tensor` of shape (*batch size*, *sequence_length*) containing the logits for the end of the answer
    /// - `all_hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
    /// - `all_attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
///
/// # Example
///
@ -881,7 +906,7 @@ impl RobertaForQuestionAnswering {
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
/// .expand(&[batch_size, sequence_length], true);
///
/// let (start_scores, end_scores, all_hidden_states, all_attentions) = no_grad(|| {
/// let model_output = no_grad(|| {
/// roberta_model.forward_t(
/// Some(input_tensor),
/// Some(mask),
@ -900,8 +925,8 @@ impl RobertaForQuestionAnswering {
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool,
) -> (Tensor, Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
let (hidden_state, _, all_hidden_states, all_attentions) = self
) -> RobertaQuestionAnsweringOutput {
let base_model_output = self
.roberta
.forward_t(
input_ids,
@ -915,12 +940,59 @@ impl RobertaForQuestionAnswering {
)
.unwrap();
let sequence_output = hidden_state.apply(&self.qa_outputs);
let sequence_output = base_model_output.hidden_state.apply(&self.qa_outputs);
let logits = sequence_output.split(1, -1);
let (start_logits, end_logits) = (&logits[0], &logits[1]);
let start_logits = start_logits.squeeze1(-1);
let end_logits = end_logits.squeeze1(-1);
(start_logits, end_logits, all_hidden_states, all_attentions)
RobertaQuestionAnsweringOutput {
start_logits,
end_logits,
all_hidden_states: base_model_output.all_hidden_states,
all_attentions: base_model_output.all_attentions,
}
}
}
/// Container for the RoBERTa masked LM model output.
pub struct RobertaMaskedLMOutput {
/// Logits for the vocabulary items at each sequence position
pub prediction_scores: Tensor,
/// Hidden states for all intermediate layers
pub all_hidden_states: Option<Vec<Tensor>>,
/// Attention weights for all intermediate layers
pub all_attentions: Option<Vec<Tensor>>,
}
/// Container for the RoBERTa sequence classification model output.
pub struct RobertaSequenceClassificationOutput {
/// Logits for each input (sequence) for each target class
pub logits: Tensor,
/// Hidden states for all intermediate layers
pub all_hidden_states: Option<Vec<Tensor>>,
/// Attention weights for all intermediate layers
pub all_attentions: Option<Vec<Tensor>>,
}
/// Container for the RoBERTa token classification model output.
pub struct RobertaTokenClassificationOutput {
/// Logits for each sequence item (token) for each target class
pub logits: Tensor,
/// Hidden states for all intermediate layers
pub all_hidden_states: Option<Vec<Tensor>>,
/// Attention weights for all intermediate layers
pub all_attentions: Option<Vec<Tensor>>,
}
/// Container for the RoBERTa question answering model output.
pub struct RobertaQuestionAnsweringOutput {
    /// Logits for the start position for each token of each input sequence
pub start_logits: Tensor,
    /// Logits for the end position for each token of each input sequence
pub end_logits: Tensor,
/// Hidden states for all intermediate layers
pub all_hidden_states: Option<Vec<Tensor>>,
/// Attention weights for all intermediate layers
pub all_attentions: Option<Vec<Tensor>>,
}
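A minimal usage sketch (editorial, not part of the commit): the argument order of `forward_t` follows the question answering pipeline above, and `roberta_model`, `input_tensor` and `mask` are assumed to be prepared as in the doc example.

let model_output = no_grad(|| {
    roberta_model.forward_t(Some(input_tensor), Some(mask), None, None, None, false)
});
let start_position = model_output.start_logits.argmax(-1, false);
let end_position = model_output.end_logits.argmax(-1, false);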

View File

@ -55,10 +55,10 @@ pub struct T5Attention {
inner_dim: i64,
output_attentions: bool,
store_cache: bool,
q: nn::Linear,
k: nn::Linear,
v: nn::Linear,
o: nn::Linear,
query: nn::Linear,
key: nn::Linear,
value: nn::Linear,
output: nn::Linear,
relative_attention_bias: Option<nn::Embedding>,
}
@ -82,10 +82,10 @@ impl T5Attention {
};
let inner_dim = config.num_heads * config.d_kv;
let k = nn::linear(p / "k", config.d_model, inner_dim, linear_config);
let v = nn::linear(p / "v", config.d_model, inner_dim, linear_config);
let q = nn::linear(p / "q", config.d_model, inner_dim, linear_config);
let o = nn::linear(p / "o", inner_dim, config.d_model, linear_config);
let key = nn::linear(p / "k", config.d_model, inner_dim, linear_config);
let value = nn::linear(p / "v", config.d_model, inner_dim, linear_config);
let query = nn::linear(p / "q", config.d_model, inner_dim, linear_config);
let output = nn::linear(p / "o", inner_dim, config.d_model, linear_config);
let dropout = Dropout::new(config.dropout_rate);
let relative_attention_bias = if has_relative_attention_bias {
@ -110,10 +110,10 @@ impl T5Attention {
inner_dim,
output_attentions,
store_cache,
q,
k,
v,
o,
query,
key,
value,
output,
relative_attention_bias,
}
}
@ -155,17 +155,17 @@ impl T5Attention {
None => real_query_length,
};
let q: Tensor = self.shape(hidden_states.as_ref().apply(&self.q), bs);
let q: Tensor = self.shape(hidden_states.as_ref().apply(&self.query), bs);
let (mut k, mut v) = if kv.is_none() {
(
self.shape(hidden_states.apply(&self.k), bs),
self.shape(hidden_states.apply(&self.v), bs),
self.shape(hidden_states.apply(&self.key), bs),
self.shape(hidden_states.apply(&self.value), bs),
)
} else {
(
self.shape(kv.as_ref().unwrap().apply(&self.k), bs),
self.shape(kv.as_ref().unwrap().apply(&self.v), bs),
self.shape(kv.as_ref().unwrap().apply(&self.key), bs),
self.shape(kv.as_ref().unwrap().apply(&self.value), bs),
)
};
@ -198,18 +198,18 @@ impl T5Attention {
let length = temp_value.size()[2];
temp_value = temp_value.slice(2, length - 1, length, 1);
};
if attention_mask.is_some() {
temp_value = temp_value + attention_mask.unwrap();
if let Some(attention_mask) = attention_mask {
temp_value = temp_value + attention_mask
};
Some(temp_value)
} else {
None
};
let position_bias = if position_bias.is_none() {
calculated_position_bias.as_ref().unwrap()
let position_bias = if let Some(position_bias) = position_bias {
position_bias
} else {
position_bias.unwrap()
calculated_position_bias.as_ref().unwrap()
};
scores += position_bias;
@ -219,7 +219,7 @@ impl T5Attention {
.apply_t(&self.dropout, train);
let context = self
.unshape(attention_weights.matmul(&v), bs)
.apply(&self.o);
.apply(&self.output);
let attention_weights = if self.output_attentions {
Some(attention_weights)
@ -247,7 +247,7 @@ impl T5Attention {
let mut num_buckets = num_buckets;
let mut ret = n.zeros_like();
let n = if bidirectional {
num_buckets = num_buckets / 2;
num_buckets /= 2;
ret += n.lt(0).to_kind(Kind::Int64) * num_buckets;
n.abs()
} else {

View File

@ -14,6 +14,7 @@ use crate::common::dropout::Dropout;
use crate::t5::attention::{LayerState, T5LayerCrossAttention, T5LayerSelfAttention};
use crate::t5::layer_norm::T5LayerNorm;
use crate::t5::T5Config;
use crate::RustBertError;
use std::borrow::{Borrow, BorrowMut};
use tch::nn::LinearConfig;
use tch::{nn, Kind, Tensor};
@ -148,12 +149,7 @@ impl T5Block {
encoder_decoder_position_bias: Option<&Tensor>,
mut layer_states: (Option<LayerState>, Option<LayerState>),
train: bool,
) -> (
Tensor,
(Option<Tensor>, Option<Tensor>),
(Option<Tensor>, Option<Tensor>),
(Option<LayerState>, Option<LayerState>),
) {
) -> T5BlockOutput {
let (
hidden_states,
self_attention_weights,
@ -190,17 +186,17 @@ impl T5Block {
(hidden_states, None, None, None)
};
let attention_weights = (self_attention_weights, cross_attention_weights);
let position_bias = (self_attention_position_bias, cross_attention_position_bias);
layer_states = (self_attention_layer_past, cross_attention_layer_past);
let hidden_states = self.ff_layer.forward_t(&hidden_states, train);
(
T5BlockOutput {
hidden_states,
attention_weights,
position_bias,
layer_states,
)
self_attention_weights,
cross_attention_weights,
self_attention_position_bias,
cross_attention_position_bias,
cache: layer_states,
}
}
}
@ -269,19 +265,13 @@ impl T5Stack {
embeddings: &nn::Embedding,
old_layer_states: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
train: bool,
) -> Result<
(
Tensor,
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
),
&'static str,
> {
) -> Result<T5StackOutput, RustBertError> {
let (input_embeddings, input_shape) = match input_ids {
Some(input_ids_value) => match input_embeds {
Some(_) => {
return Err("Only one of input ids or input embeddings may be set");
return Err(RustBertError::ValueError(
"Only one of input ids or input embeddings may be set".into(),
));
}
None => (input_ids_value.apply(embeddings), input_ids_value.size()),
},
@ -291,7 +281,9 @@ impl T5Stack {
(embeds, size)
}
None => {
return Err("Only one of input ids or input embeddings may be set");
return Err(RustBertError::ValueError(
"At least one of input ids or input embeddings must be set".into(),
));
}
},
};
@ -332,7 +324,7 @@ impl T5Stack {
if self.is_decoder {
let seq_ids =
Tensor::arange(input_shape[1], (Kind::Float, input_embeddings.device()));
let causal_mask = seq_ids.unsqueeze(0).unsqueeze(0).repeat(&vec![
let causal_mask = seq_ids.unsqueeze(0).unsqueeze(0).repeat(&[
input_shape[0],
input_shape[1],
1,
@ -344,7 +336,9 @@ impl T5Stack {
}
}
_ => {
return Err("Invalid attention mask dimension, must be 2 or 3");
return Err(RustBertError::ValueError(
"Invalid attention mask dimension, must be 2 or 3".into(),
));
}
};
@ -368,7 +362,9 @@ impl T5Stack {
2 => encoder_mask.unsqueeze(1).unsqueeze(1),
3 => encoder_mask.unsqueeze(1),
_ => {
return Err("Invalid encoder attention mask dimension, must be 2 or 3");
return Err(RustBertError::ValueError(
"Invalid attention mask dimension, must be 2 or 3".into(),
));
}
};
Some((encoder_mask.ones_like() - encoder_mask) * -1e9)
@ -386,7 +382,7 @@ impl T5Stack {
} else {
None
};
let mut next_decoder_cache: Option<Vec<(Option<LayerState>, Option<LayerState>)>> =
let mut next_cache: Option<Vec<(Option<LayerState>, Option<LayerState>)>> =
if self.store_cache {
if old_layer_states.is_some() {
old_layer_states
@ -400,42 +396,36 @@ impl T5Stack {
let mut encoder_decoder_position_bias = None;
let mut attention_weights: Option<Tensor>;
let mut hidden_state = input_embeddings.apply_t(&self.dropout, train);
let mut blocks = self.blocks.iter().enumerate();
loop {
match blocks.next() {
Some((layer_idx, layer)) => {
let layer_state = match &next_decoder_cache {
Some(values) => values[layer_idx].to_owned(),
None => (None, None),
};
let temp = layer.forward_t(
&hidden_state,
position_bias.as_ref(),
extended_attention_mask.as_ref(),
encoder_hidden_states,
extended_encoder_attention_mask.as_ref(),
encoder_decoder_position_bias.as_ref(),
layer_state,
train,
);
if layer_idx == 0 {
position_bias = (temp.2).0;
encoder_decoder_position_bias = (temp.2).1;
}
hidden_state = temp.0;
attention_weights = (temp.1).1;
if let Some(hidden_states) = all_hidden_states.borrow_mut() {
hidden_states.push(hidden_state.as_ref().copy().transpose(0, 1));
};
if let Some(attentions) = all_attentions.borrow_mut() {
attentions.push(attention_weights.as_ref().unwrap().copy());
};
if let Some(value) = &mut next_decoder_cache {
value[layer_idx] = temp.3
};
}
None => break,
for (layer_idx, layer) in self.blocks.iter().enumerate() {
let layer_state = match &next_cache {
Some(values) => values[layer_idx].to_owned(),
None => (None, None),
};
let block_output = layer.forward_t(
&hidden_state,
position_bias.as_ref(),
extended_attention_mask.as_ref(),
encoder_hidden_states,
extended_encoder_attention_mask.as_ref(),
encoder_decoder_position_bias.as_ref(),
layer_state,
train,
);
if layer_idx == 0 {
position_bias = block_output.self_attention_position_bias;
encoder_decoder_position_bias = block_output.cross_attention_position_bias;
}
hidden_state = block_output.hidden_states;
attention_weights = block_output.cross_attention_weights;
if let Some(hidden_states) = all_hidden_states.borrow_mut() {
hidden_states.push(hidden_state.as_ref().copy().transpose(0, 1));
};
if let Some(attentions) = all_attentions.borrow_mut() {
attentions.push(attention_weights.as_ref().unwrap().copy());
};
if let Some(value) = &mut next_cache {
value[layer_idx] = block_output.cache
};
}
@ -443,11 +433,27 @@ impl T5Stack {
.apply(&self.final_layer_norm)
.apply_t(&self.dropout, train);
Ok((
Ok(T5StackOutput {
hidden_state,
all_hidden_states,
all_attentions,
next_decoder_cache,
))
next_cache,
})
}
}
pub struct T5BlockOutput {
pub hidden_states: Tensor,
pub self_attention_weights: Option<Tensor>,
pub cross_attention_weights: Option<Tensor>,
pub self_attention_position_bias: Option<Tensor>,
pub cross_attention_position_bias: Option<Tensor>,
pub cache: (Option<LayerState>, Option<LayerState>),
}
pub struct T5StackOutput {
pub hidden_state: Tensor,
pub all_hidden_states: Option<Vec<Tensor>>,
pub all_attentions: Option<Vec<Tensor>>,
pub next_cache: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
}
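
For orientation while reading the refactor above, here is a minimal in-crate sketch (the helper name is illustrative, and `T5StackOutput` is not re-exported outside the crate) of wrapping a pre-computed encoder hidden state into the new struct so it can be passed back as `encoder_outputs`, mirroring the construction used in the `LMHeadModel` implementation later in this diff.

use tch::Tensor;

fn wrap_encoder_state(encoder_hidden_state: &Tensor) -> T5StackOutput {
    // Only the last encoder hidden state is required; optional traces and cache stay empty.
    T5StackOutput {
        hidden_state: encoder_hidden_state.copy(),
        all_hidden_states: None,
        all_attentions: None,
        next_cache: None,
    }
}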

View File

@ -6,7 +6,7 @@
//!
//! # Model set-up and pre-trained weights loading
//!
//! A full working example (translation) is provided in `examples/t5.rs`, run with `cargo run --example t5`.
//! A full working example (translation) is provided in `examples/t5`, run with `cargo run --example t5`.
//! All models expect the following resources:
//! - Configuration file expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers)
//! - Model weights are expected to have a structure and parameter names following the [Transformers library](https://github.com/huggingface/transformers). A conversion using the Python utility scripts is required to convert the `.bin` weights to the `.ot` format.
@ -51,10 +51,10 @@
mod attention;
mod encoder;
mod layer_norm;
mod t5;
mod t5_model;
pub use attention::LayerState;
pub use t5::{
T5Config, T5ConfigResources, T5ForConditionalGeneration, T5Model, T5ModelResources, T5Prefix,
T5VocabResources,
pub use t5_model::{
T5Config, T5ConfigResources, T5ForConditionalGeneration, T5Model, T5ModelOutput,
T5ModelResources, T5Prefix, T5VocabResources,
};
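
Downstream imports are unaffected by the rename of `t5.rs` to `t5_model.rs`, since the items are still re-exported from the `t5` module; `T5ModelOutput` is newly part of the public surface. A sketch of the unchanged import path:

use rust_bert::t5::{
    T5Config, T5ConfigResources, T5ForConditionalGeneration, T5Model, T5ModelOutput,
    T5ModelResources, T5Prefix, T5VocabResources,
};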

View File

@ -9,10 +9,10 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use crate::pipelines::generation::{Cache, LMHeadModel};
use crate::pipelines::generation::{Cache, LMHeadModel, LMModelOutput};
use crate::t5::attention::LayerState;
use crate::t5::encoder::T5Stack;
use crate::Config;
use crate::t5::encoder::{T5Stack, T5StackOutput};
use crate::{Config, RustBertError};
use serde::{Deserialize, Serialize};
use std::borrow::Borrow;
use tch::nn::embedding;
@ -251,14 +251,14 @@ impl T5Model {
///
/// # Returns
///
/// * `decoder_output` - `Tensor` of shape (*batch size*, *target_sequence_length*, *hidden_size*) representing the activations of the last decoder hidden state
/// * `encoder_hidden_states` - `Tensor` of shape (*batch size*, *source_sequence_length*, *hidden_size*) representing the activations of the last encoder hidden state
/// * `decoder_cache` - `Option<Vec<(Option<Vec<&LayerState, &LayerState>>)>>` of length *n_layer* containing the encoder past keys and values for
/// both the self attention and the encoder cross attention of each layer of the decoder.
/// * `all_encoder_hidden_states` - `Option<Vec<Tensor>>` of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*)
/// * `all_encoder_attentions` - `Option<Vec<Tensor>>` of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*)
/// * `all_decoder_hidden_states` - `Option<Vec<Tensor>>` of length *num_decoder_layers* with shape (*batch size*, *target_sequence_length*, *hidden_size*)
/// * `all_decoder_attentions` - `Option<Vec<Tensor>>` of length *num_decoder_layers* with shape (*batch size*, *target_sequence_length*, *hidden_size*)
/// * `T5ModelOutput` containing:
/// - `decoder_output` - `Tensor` of shape (*batch size*, *target_sequence_length*, *hidden_size*) representing the activations of the last decoder hidden state
/// - `encoder_hidden_states` - `Tensor` of shape (*batch size*, *source_sequence_length*, *hidden_size*) representing the activations of the last encoder hidden state
    /// - `cache` - `Option<Vec<(Option<LayerState>, Option<LayerState>)>>` of length *n_layer* containing the past keys and values for both the self attention and the encoder cross attention of each layer of the decoder.
/// - `all_encoder_hidden_states` - `Option<Vec<Tensor>>` of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*)
/// - `all_encoder_attentions` - `Option<Vec<Tensor>>` of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*)
/// - `all_decoder_hidden_states` - `Option<Vec<Tensor>>` of length *num_decoder_layers* with shape (*batch size*, *target_sequence_length*, *hidden_size*)
/// - `all_decoder_attentions` - `Option<Vec<Tensor>>` of length *num_decoder_layers* with shape (*batch size*, *target_sequence_length*, *hidden_size*)
///
/// # Example
///
@ -282,15 +282,7 @@ impl T5Model {
/// let decoder_attention_mask =
/// Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
///
/// let (
/// decoder_output,
/// encoder_hidden_states,
/// decoder_cache,
/// all_encoder_hidden_states,
/// all_encoder_attentions,
/// all_decoder_hidden_states,
/// all_decoder_attentions,
/// ) = no_grad(|| {
/// let model_output = no_grad(|| {
/// t5_model.forward_t(
/// Some(&input_tensor),
/// Some(&encoder_attention_mask),
@ -308,52 +300,30 @@ impl T5Model {
&self,
input_ids: Option<&Tensor>,
attention_mask: Option<&Tensor>,
encoder_outputs: Option<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>)>,
encoder_outputs: Option<T5StackOutput>,
decoder_input_ids: Option<&Tensor>,
decoder_attention_mask: Option<&Tensor>,
input_embeds: Option<Tensor>,
decoder_input_embeds: Option<Tensor>,
old_layer_states: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
train: bool,
) -> (
Tensor,
Tensor,
Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
) {
let (encoder_hidden_states, all_encoder_hidden_states, all_encoder_attentions) =
match encoder_outputs {
Some(value) => value,
None => {
let (
encoder_hidden_states,
all_encoder_hidden_states,
all_encoder_attentions,
_,
) = self
.encoder
.forward_t(
input_ids,
attention_mask,
None,
None,
input_embeds,
&self.embeddings,
None,
train,
)
.unwrap();
(
encoder_hidden_states,
all_encoder_hidden_states,
all_encoder_attentions,
)
}
};
) -> T5ModelOutput {
let encoder_output = match encoder_outputs {
Some(value) => value,
None => self
.encoder
.forward_t(
input_ids,
attention_mask,
None,
None,
input_embeds,
&self.embeddings,
None,
train,
)
.unwrap(),
};
let (calculated_decoder_input_ids, calculated_decoder_input_embeds) =
if old_layer_states.is_some() {
let decoder_input_ids = match decoder_input_ids {
@ -377,29 +347,28 @@ impl T5Model {
(decoder_input_ids, decoder_input_embeds)
};
let (decoder_outputs, all_decoder_hidden_states, all_decoder_attentions, decoder_cache) =
self.decoder
.forward_t(
decoder_input_ids,
decoder_attention_mask,
Some(&encoder_hidden_states),
attention_mask,
decoder_input_embeds,
&self.embeddings,
old_layer_states,
train,
)
.unwrap();
(
decoder_outputs,
encoder_hidden_states,
decoder_cache,
all_decoder_hidden_states,
all_decoder_attentions,
all_encoder_hidden_states,
all_encoder_attentions,
)
let decoder_output = self
.decoder
.forward_t(
decoder_input_ids,
decoder_attention_mask,
Some(&encoder_output.hidden_state),
attention_mask,
decoder_input_embeds,
&self.embeddings,
old_layer_states,
train,
)
.unwrap();
T5ModelOutput {
decoder_output: decoder_output.hidden_state,
encoder_hidden_state: encoder_output.hidden_state,
next_cache: decoder_output.next_cache,
all_decoder_hidden_states: decoder_output.all_hidden_states,
all_decoder_attentions: decoder_output.all_attentions,
all_encoder_hidden_states: encoder_output.all_hidden_states,
all_encoder_attentions: encoder_output.all_attentions,
}
}
}
@ -480,14 +449,14 @@ impl T5ForConditionalGeneration {
///
/// # Returns
///
/// * `decoder_output` - `Tensor` of shape (*batch size*, *target_sequence_length*, *hidden_size*) representing the activations of the last decoder hidden state
/// * `encoder_hidden_states` - `Tensor` of shape (*batch size*, *source_sequence_length*, *hidden_size*) representing the activations of the last encoder hidden state
/// * `decoder_cache` - `Option<Vec<(Option<Vec<&LayerState, &LayerState>>)>>` of length *n_layer* containing the encoder past keys and values for
/// both the self attention and the encoder cross attention of each layer of the decoder.
/// * `all_encoder_hidden_states` - `Option<Vec<Tensor>>` of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*)
/// * `all_encoder_attentions` - `Option<Vec<Tensor>>` of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*)
/// * `all_decoder_hidden_states` - `Option<Vec<Tensor>>` of length *num_decoder_layers* with shape (*batch size*, *target_sequence_length*, *hidden_size*)
/// * `all_decoder_attentions` - `Option<Vec<Tensor>>` of length *num_decoder_layers* with shape (*batch size*, *target_sequence_length*, *hidden_size*)
/// * `T5ModelOutput` containing:
/// - `decoder_output` - `Tensor` of shape (*batch size*, *target_sequence_length*, *vocab_size*) representing the logits for each sequence position and vocabulary item
/// - `encoder_hidden_states` - `Tensor` of shape (*batch size*, *source_sequence_length*, *hidden_size*) representing the activations of the last encoder hidden state
    /// - `cache` - `Option<Vec<(Option<LayerState>, Option<LayerState>)>>` of length *n_layer* containing the past keys and values for both the self attention and the encoder cross attention of each layer of the decoder.
/// - `all_encoder_hidden_states` - `Option<Vec<Tensor>>` of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*)
/// - `all_encoder_attentions` - `Option<Vec<Tensor>>` of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*)
/// - `all_decoder_hidden_states` - `Option<Vec<Tensor>>` of length *num_decoder_layers* with shape (*batch size*, *target_sequence_length*, *hidden_size*)
/// - `all_decoder_attentions` - `Option<Vec<Tensor>>` of length *num_decoder_layers* with shape (*batch size*, *target_sequence_length*, *hidden_size*)
///
/// # Example
///
@ -511,15 +480,7 @@ impl T5ForConditionalGeneration {
/// let decoder_attention_mask =
/// Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
///
/// let (
/// decoder_output,
/// encoder_hidden_states,
/// decoder_cache,
/// all_encoder_hidden_states,
/// all_encoder_attentions,
/// all_decoder_hidden_states,
/// all_decoder_attentions,
/// ) = no_grad(|| {
/// let model_output = no_grad(|| {
/// t5_model.forward_t(
/// Some(&input_tensor),
/// Some(&encoder_attention_mask),
@ -537,31 +498,15 @@ impl T5ForConditionalGeneration {
&self,
input_ids: Option<&Tensor>,
attention_mask: Option<&Tensor>,
encoder_outputs: Option<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>)>,
encoder_outputs: Option<T5StackOutput>,
decoder_input_ids: Option<&Tensor>,
decoder_attention_mask: Option<&Tensor>,
input_embeds: Option<Tensor>,
decoder_input_embeds: Option<Tensor>,
old_layer_states: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
train: bool,
) -> (
Tensor,
Tensor,
Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
) {
let (
decoder_outputs,
encoder_hidden_states,
decoder_cache,
all_decoder_hidden_states,
all_decoder_attentions,
all_encoder_hidden_states,
all_encoder_attentions,
) = self.base_model.forward_t(
) -> T5ModelOutput {
let base_model_output = self.base_model.forward_t(
input_ids,
attention_mask,
encoder_outputs,
@ -572,23 +517,19 @@ impl T5ForConditionalGeneration {
old_layer_states,
train,
);
let lm_logits = decoder_outputs.linear::<Tensor>(&self.base_model.embeddings.ws, None)
let lm_logits = base_model_output
.decoder_output
.linear::<Tensor>(&self.base_model.embeddings.ws, None)
* (self.model_dim.powf(-0.5));
(
lm_logits,
encoder_hidden_states,
decoder_cache,
all_decoder_hidden_states,
all_decoder_attentions,
all_encoder_hidden_states,
all_encoder_attentions,
)
T5ModelOutput {
decoder_output: lm_logits,
..base_model_output
}
}
pub fn encode(&self, input_ids: &Tensor, attention_mask: Option<&Tensor>) -> Tensor {
let (encoder_hidden_states, _, _, _) = self
.base_model
self.base_model
.encoder
.forward_t(
Some(input_ids),
@ -600,8 +541,8 @@ impl T5ForConditionalGeneration {
None,
false,
)
.unwrap();
encoder_hidden_states
.unwrap()
.hidden_state
}
}
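
With this refactor, `encode` returns the encoder hidden state tensor directly rather than the first element of a tuple. A minimal sketch of calling it (the helper name is illustrative and not part of this change):

use rust_bert::t5::T5ForConditionalGeneration;
use tch::Tensor;

fn encoder_state(
    model: &T5ForConditionalGeneration,
    input_ids: &Tensor,
    attention_mask: Option<&Tensor>,
) -> Tensor {
    // Returns a tensor of shape (batch size, source_sequence_length, hidden_size).
    model.encode(input_ids, attention_mask)
}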
@ -622,13 +563,14 @@ impl LMHeadModel for T5ForConditionalGeneration {
///
/// # Returns
///
/// * `lm_logits` - `Tensor` of shape (*batch size*, *sequence_length*, *vocab_size*) representing the logits for each vocab item and position
/// * `past` - `T5Cache` made of `Option<Vec<(Option<Vec<&LayerState, &LayerState>>)>>` of length *n_layer* containing the encoder past keys and values for
/// both the self attention and the encoder cross attention of each layer of the decoder.
/// * `encoder_hidden_states` - `Option<Tensor>` Hidden states for the encoder
/// * `hidden_states` - None
/// * `attentions` - None
/// * `LMModelOutput` containing:
/// - `lm_logits` - `Tensor` of shape (*batch size*, *sequence_length*, *vocab_size*) representing the logits for each vocab item and position
    /// - `cache` - `T5Cache` made of `Option<Vec<(Option<LayerState>, Option<LayerState>)>>` of length *n_layer* containing the past keys and values for
    /// both the self attention and the encoder cross attention of each layer of the decoder.
/// - `encoder_hidden_states` - `Option<Tensor>` Hidden states for the encoder
/// - `all_hidden_states` - None
/// - `all_attentions` - None
///
/// # Example
///
/// ```no_run
@ -651,15 +593,7 @@ impl LMHeadModel for T5ForConditionalGeneration {
/// let decoder_attention_mask =
/// Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
///
/// let (
/// decoder_output,
/// encoder_hidden_states,
/// decoder_cache,
/// all_encoder_hidden_states,
/// all_encoder_attentions,
/// all_decoder_hidden_states,
/// all_decoder_attentions,
/// ) = no_grad(|| {
/// let model_output = no_grad(|| {
/// t5_model.forward_t(
/// Some(&input_tensor),
/// Some(&encoder_attention_mask),
@ -684,21 +618,17 @@ impl LMHeadModel for T5ForConditionalGeneration {
encoder_outputs: Option<&Tensor>,
decoder_input_ids: &Option<Tensor>,
train: bool,
) -> Result<
(
Tensor,
Option<Tensor>,
Cache,
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
),
&'static str,
> {
let (decoder_output, encoder_hidden_states, new_cache, _, _, _, _) = match cache {
) -> Result<LMModelOutput, RustBertError> {
let base_model_output = match cache {
Cache::T5Cache(cached_layer_states) => self.base_model.forward_t(
input_ids.as_ref(),
attention_mask.as_ref(),
Some((encoder_outputs.as_ref().unwrap().copy(), None, None)),
Some(T5StackOutput {
hidden_state: encoder_outputs.as_ref().unwrap().copy(),
all_hidden_states: None,
all_attentions: None,
next_cache: None,
}),
Option::from(decoder_input_ids),
None,
None,
@ -709,7 +639,12 @@ impl LMHeadModel for T5ForConditionalGeneration {
Cache::None => self.base_model.forward_t(
input_ids.as_ref(),
attention_mask.as_ref(),
Some((encoder_outputs.as_ref().unwrap().copy(), None, None)),
Some(T5StackOutput {
hidden_state: encoder_outputs.as_ref().unwrap().copy(),
all_hidden_states: None,
all_attentions: None,
next_cache: None,
}),
Option::from(decoder_input_ids),
None,
None,
@ -717,18 +652,45 @@ impl LMHeadModel for T5ForConditionalGeneration {
None,
train,
),
_ => Err("Cache not compatible with T5 Model")?,
_ => {
return Err(RustBertError::ValueError(
"Cache not compatible with T5 Model".into(),
));
}
};
let lm_logits = decoder_output.linear::<Tensor>(&self.base_model.embeddings.ws, None)
let lm_logits = base_model_output
.decoder_output
.linear::<Tensor>(&self.base_model.embeddings.ws, None)
* (self.model_dim.powf(-0.5));
Ok((
Ok(LMModelOutput {
lm_logits,
Some(encoder_hidden_states),
Cache::T5Cache(new_cache),
None,
None,
))
encoder_hidden_state: Some(base_model_output.encoder_hidden_state),
cache: Cache::T5Cache(base_model_output.next_cache),
all_hidden_states: None,
all_attentions: None,
})
}
}
/// Container holding a T5 model output. The decoder output may hold the hidden state of
/// the last layer of the decoder, or may hold logits for a custom head module after the
/// decoder (e.g. for language modeling tasks)
pub struct T5ModelOutput {
/// Hidden state of the last layer of the decoder, or logits for a custom head
/// module after the decoder (e.g. for language modeling tasks)
pub decoder_output: Tensor,
/// Hidden state for the last layer of the encoder
pub encoder_hidden_state: Tensor,
/// Cached outputs of the model (attention layers keys and values) if the model is used for generation
pub next_cache: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
/// Hidden states for all layers of the decoder
pub all_decoder_hidden_states: Option<Vec<Tensor>>,
/// Attention weights for all layers of the decoder
pub all_decoder_attentions: Option<Vec<Tensor>>,
/// Hidden states for all layers of the encoder
pub all_encoder_hidden_states: Option<Vec<Tensor>>,
/// Attention weights for all layers of the encoder
pub all_encoder_attentions: Option<Vec<Tensor>>,
}
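
As a usage note, a minimal sketch (the helper name is illustrative) of consuming the new named output struct, replacing positional destructuring of the former seven-element tuple:

use rust_bert::t5::T5ModelOutput;

fn inspect_t5_output(model_output: &T5ModelOutput) {
    // Last decoder hidden state, or LM logits when produced by T5ForConditionalGeneration.
    println!("decoder output size: {:?}", model_output.decoder_output.size());
    // Last encoder hidden state, returned alongside the decoder output.
    println!("encoder hidden state size: {:?}", model_output.encoder_hidden_state.size());
    // Per-layer traces are only populated when the model is configured to output them.
    if let Some(decoder_hidden_states) = &model_output.all_decoder_hidden_states {
        println!("decoder layers traced: {}", decoder_hidden_states.len());
    }
    // The cache is populated when the model is used for generation.
    if let Some(cache) = &model_output.next_cache {
        println!("cached decoder layers: {}", cache.len());
    }
}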

View File

@ -61,18 +61,26 @@ fn albert_masked_lm() -> anyhow::Result<()> {
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let (output, _, _) =
let model_output =
no_grad(|| albert_model.forward_t(Some(input_tensor), None, None, None, None, false));
// Print masked tokens
let index_1 = output.get(0).get(4).argmax(0, false);
let index_2 = output.get(1).get(6).argmax(0, false);
let index_1 = model_output
.prediction_scores
.get(0)
.get(4)
.argmax(0, false);
let index_2 = model_output
.prediction_scores
.get(1)
.get(6)
.argmax(0, false);
let word_1 = tokenizer.vocab().id_to_token(&index_1.int64_value(&[]));
let word_2 = tokenizer.vocab().id_to_token(&index_2.int64_value(&[]));
assert_eq!("▁them", word_1); // Outputs "_them" : "Looks like one [them] is missing (? this is identical with the original implementation)"
assert_eq!("▁grapes", word_2); // Outputs "grapes" : "It\'s like comparing [grapes] to apples"
assert!((output.double_value(&[0, 0, 0]) - 4.6143).abs() < 1e-4);
assert!((model_output.prediction_scores.double_value(&[0, 0, 0]) - 4.6143).abs() < 1e-4);
Ok(())
}
@ -127,17 +135,17 @@ fn albert_for_sequence_classification() -> anyhow::Result<()> {
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let (output, all_hidden_states, all_attentions) =
let model_output =
no_grad(|| albert_model.forward_t(Some(input_tensor), None, None, None, None, false));
assert_eq!(output.size(), &[2, 3]);
assert_eq!(model_output.logits.size(), &[2, 3]);
assert_eq!(
config.num_hidden_layers as usize,
all_hidden_states.unwrap().len()
model_output.all_hidden_states.unwrap().len()
);
assert_eq!(
config.num_hidden_layers as usize,
all_attentions.unwrap().len()
model_output.all_attentions.unwrap().len()
);
Ok(())
@ -191,20 +199,20 @@ fn albert_for_multiple_choice() -> anyhow::Result<()> {
.unsqueeze(0);
// Forward pass
let (output, all_hidden_states, all_attentions) = no_grad(|| {
let model_output = no_grad(|| {
albert_model
.forward_t(Some(input_tensor), None, None, None, None, false)
.unwrap()
});
assert_eq!(output.size(), &[1, 2]);
assert_eq!(model_output.logits.size(), &[1, 2]);
assert_eq!(
config.num_hidden_layers as usize,
all_hidden_states.unwrap().len()
model_output.all_hidden_states.unwrap().len()
);
assert_eq!(
config.num_hidden_layers as usize,
all_attentions.unwrap().len()
model_output.all_attentions.unwrap().len()
);
Ok(())
@ -262,17 +270,17 @@ fn albert_for_token_classification() -> anyhow::Result<()> {
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let (output, all_hidden_states, all_attentions) =
let model_output =
no_grad(|| bert_model.forward_t(Some(input_tensor), None, None, None, None, false));
assert_eq!(output.size(), &[2, 12, 4]);
assert_eq!(model_output.logits.size(), &[2, 12, 4]);
assert_eq!(
config.num_hidden_layers as usize,
all_hidden_states.unwrap().len()
model_output.all_hidden_states.unwrap().len()
);
assert_eq!(
config.num_hidden_layers as usize,
all_attentions.unwrap().len()
model_output.all_attentions.unwrap().len()
);
Ok(())
@ -324,18 +332,18 @@ fn albert_for_question_answering() -> anyhow::Result<()> {
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let (start_scores, end_scores, all_hidden_states, all_attentions) =
let model_output =
no_grad(|| albert_model.forward_t(Some(input_tensor), None, None, None, None, false));
assert_eq!(start_scores.size(), &[2, 12]);
assert_eq!(end_scores.size(), &[2, 12]);
assert_eq!(model_output.start_logits.size(), &[2, 12]);
assert_eq!(model_output.end_logits.size(), &[2, 12]);
assert_eq!(
config.num_hidden_layers as usize,
all_hidden_states.unwrap().len()
model_output.all_hidden_states.unwrap().len()
);
assert_eq!(
config.num_hidden_layers as usize,
all_attentions.unwrap().len()
model_output.all_attentions.unwrap().len()
);
Ok(())

View File

@ -62,12 +62,12 @@ fn bart_lm_model() -> anyhow::Result<()> {
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let (output, encoder_outputs, _, _, _, _, _) =
let model_output =
bart_model.forward_t(Some(&input_tensor), None, None, None, None, None, false);
assert_eq!(output.size(), vec!(1, 6, 1024));
assert_eq!(encoder_outputs.size(), vec!(1, 6, 1024));
assert!((output.double_value(&[0, 0, 0]) - 0.7877).abs() < 1e-4);
assert_eq!(model_output.decoder_output.size(), vec!(1, 6, 1024));
assert_eq!(model_output.encoder_hidden_state.size(), vec!(1, 6, 1024));
assert!((model_output.decoder_output.double_value(&[0, 0, 0]) - 0.7877).abs() < 1e-4);
Ok(())
}

View File

@ -70,7 +70,7 @@ fn bert_masked_lm() -> anyhow::Result<()> {
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let (output, _, _) = no_grad(|| {
let model_output = no_grad(|| {
bert_model.forward_t(
Some(input_tensor),
None,
@ -84,8 +84,16 @@ fn bert_masked_lm() -> anyhow::Result<()> {
});
// Print masked tokens
let index_1 = output.get(0).get(4).argmax(0, false);
let index_2 = output.get(1).get(6).argmax(0, false);
let index_1 = model_output
.prediction_scores
.get(0)
.get(4)
.argmax(0, false);
let index_2 = model_output
.prediction_scores
.get(1)
.get(6)
.argmax(0, false);
let word_1 = tokenizer.vocab().id_to_token(&index_1.int64_value(&[]));
let word_2 = tokenizer.vocab().id_to_token(&index_2.int64_value(&[]));
@ -144,17 +152,17 @@ fn bert_for_sequence_classification() -> anyhow::Result<()> {
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let (output, all_hidden_states, all_attentions) =
let model_output =
no_grad(|| bert_model.forward_t(Some(input_tensor), None, None, None, None, false));
assert_eq!(output.size(), &[2, 3]);
assert_eq!(model_output.logits.size(), &[2, 3]);
assert_eq!(
config.num_hidden_layers as usize,
all_hidden_states.unwrap().len()
model_output.all_hidden_states.unwrap().len()
);
assert_eq!(
config.num_hidden_layers as usize,
all_attentions.unwrap().len()
model_output.all_attentions.unwrap().len()
);
Ok(())
@ -206,17 +214,16 @@ fn bert_for_multiple_choice() -> anyhow::Result<()> {
.unsqueeze(0);
// Forward pass
let (output, all_hidden_states, all_attentions) =
no_grad(|| bert_model.forward_t(input_tensor, None, None, None, false));
let model_output = no_grad(|| bert_model.forward_t(input_tensor, None, None, None, false));
assert_eq!(output.size(), &[1, 2]);
assert_eq!(model_output.logits.size(), &[1, 2]);
assert_eq!(
config.num_hidden_layers as usize,
all_hidden_states.unwrap().len()
model_output.all_hidden_states.unwrap().len()
);
assert_eq!(
config.num_hidden_layers as usize,
all_attentions.unwrap().len()
model_output.all_attentions.unwrap().len()
);
Ok(())
@ -272,17 +279,17 @@ fn bert_for_token_classification() -> anyhow::Result<()> {
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let (output, all_hidden_states, all_attentions) =
let model_output =
no_grad(|| bert_model.forward_t(Some(input_tensor), None, None, None, None, false));
assert_eq!(output.size(), &[2, 11, 4]);
assert_eq!(model_output.logits.size(), &[2, 11, 4]);
assert_eq!(
config.num_hidden_layers as usize,
all_hidden_states.unwrap().len()
model_output.all_hidden_states.unwrap().len()
);
assert_eq!(
config.num_hidden_layers as usize,
all_attentions.unwrap().len()
model_output.all_attentions.unwrap().len()
);
Ok(())
@ -332,18 +339,18 @@ fn bert_for_question_answering() -> anyhow::Result<()> {
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let (start_scores, end_scores, all_hidden_states, all_attentions) =
let model_output =
no_grad(|| bert_model.forward_t(Some(input_tensor), None, None, None, None, false));
assert_eq!(start_scores.size(), &[2, 11]);
assert_eq!(end_scores.size(), &[2, 11]);
assert_eq!(model_output.start_logits.size(), &[2, 11]);
assert_eq!(model_output.end_logits.size(), &[2, 11]);
assert_eq!(
config.num_hidden_layers as usize,
all_hidden_states.unwrap().len()
model_output.all_hidden_states.unwrap().len()
);
assert_eq!(
config.num_hidden_layers as usize,
all_attentions.unwrap().len()
model_output.all_attentions.unwrap().len()
);
Ok(())
@ -407,7 +414,7 @@ fn bert_question_answering() -> anyhow::Result<()> {
let context = String::from("Amy lives in Amsterdam");
let qa_input = QaInput { question, context };
let answers = qa_model.predict(&vec![qa_input], 1, 32);
let answers = qa_model.predict(&[qa_input], 1, 32);
assert_eq!(answers.len(), 1 as usize);
assert_eq!(answers[0].len(), 1 as usize);

View File

@ -96,15 +96,23 @@ fn distilbert_masked_lm() -> anyhow::Result<()> {
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let (output, _, _) = no_grad(|| {
let model_output = no_grad(|| {
distil_bert_model
.forward_t(Some(input_tensor), None, None, false)
.unwrap()
});
// Print masked tokens
let index_1 = output.get(0).get(4).argmax(0, false);
let index_2 = output.get(1).get(6).argmax(0, false);
let index_1 = model_output
.prediction_scores
.get(0)
.get(4)
.argmax(0, false);
let index_2 = model_output
.prediction_scores
.get(1)
.get(6)
.argmax(0, false);
let word_1 = tokenizer.vocab().id_to_token(&index_1.int64_value(&[]));
let word_2 = tokenizer.vocab().id_to_token(&index_2.int64_value(&[]));
@ -160,16 +168,22 @@ fn distilbert_for_question_answering() -> anyhow::Result<()> {
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let (start_scores, end_scores, all_hidden_states, all_attentions) = no_grad(|| {
let model_output = no_grad(|| {
distil_bert_model
.forward_t(Some(input_tensor), None, None, false)
.unwrap()
});
assert_eq!(start_scores.size(), &[2, 11]);
assert_eq!(end_scores.size(), &[2, 11]);
assert_eq!(config.n_layers as usize, all_hidden_states.unwrap().len());
assert_eq!(config.n_layers as usize, all_attentions.unwrap().len());
assert_eq!(model_output.start_logits.size(), &[2, 11]);
assert_eq!(model_output.end_logits.size(), &[2, 11]);
assert_eq!(
config.n_layers as usize,
model_output.all_hidden_states.unwrap().len()
);
assert_eq!(
config.n_layers as usize,
model_output.all_attentions.unwrap().len()
);
Ok(())
}
@ -226,15 +240,21 @@ fn distilbert_for_token_classification() -> anyhow::Result<()> {
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let (output, all_hidden_states, all_attentions) = no_grad(|| {
let model_output = no_grad(|| {
distil_bert_model
.forward_t(Some(input_tensor), None, None, false)
.unwrap()
});
assert_eq!(output.size(), &[2, 11, 4]);
assert_eq!(config.n_layers as usize, all_hidden_states.unwrap().len());
assert_eq!(config.n_layers as usize, all_attentions.unwrap().len());
assert_eq!(model_output.logits.size(), &[2, 11, 4]);
assert_eq!(
config.n_layers as usize,
model_output.all_hidden_states.unwrap().len()
);
assert_eq!(
config.n_layers as usize,
model_output.all_attentions.unwrap().len()
);
Ok(())
}
@ -249,7 +269,7 @@ fn distilbert_question_answering() -> anyhow::Result<()> {
let context = String::from("Amy lives in Amsterdam");
let qa_input = QaInput { question, context };
let answers = qa_model.predict(&vec![qa_input], 1, 32);
let answers = qa_model.predict(&[qa_input], 1, 32);
assert_eq!(answers.len(), 1 as usize);
assert_eq!(answers[0].len(), 1 as usize);

View File

@ -61,7 +61,7 @@ fn distilgpt2_lm_model() -> anyhow::Result<()> {
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let (output, _, past, _, _) = gpt2_model
let model_output = gpt2_model
.forward_t(
&Some(input_tensor),
Cache::None,
@ -75,11 +75,16 @@ fn distilgpt2_lm_model() -> anyhow::Result<()> {
)
.unwrap();
let next_word_id = output.get(0).get(-1).argmax(-1, true).int64_value(&[0]);
let next_word_id = model_output
.lm_logits
.get(0)
.get(-1)
.argmax(-1, true)
.int64_value(&[0]);
let next_word = tokenizer.decode(vec![next_word_id], true, true);
assert_eq!(output.size(), vec!(1, 11, 50257));
match past {
assert_eq!(model_output.lm_logits.size(), vec!(1, 11, 50257));
match model_output.cache {
Cache::GPT2Cache(past) => {
assert!(past.is_some());
assert_eq!(past.as_ref().unwrap().len(), config.n_layer as usize);
@ -91,7 +96,13 @@ fn distilgpt2_lm_model() -> anyhow::Result<()> {
_ => panic!("Wrong cache returned for GPT2"),
}
assert!(
(output.double_value(&[0, output.size()[1] - 1, next_word_id]) - (-48.7065)).abs() < 1e-4
(model_output.lm_logits.double_value(&[
0,
model_output.lm_logits.size()[1] - 1,
next_word_id
]) - (-48.7065))
.abs()
< 1e-4
);
assert_eq!(next_word_id, 14104i64);
assert_eq!(next_word, String::from(" twelve"));

View File

@ -58,23 +58,34 @@ fn electra_masked_lm() -> anyhow::Result<()> {
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let (output, all_hidden_states, all_attentions) =
let model_output =
no_grad(|| electra_model.forward_t(Some(input_tensor), None, None, None, None, false));
// Decode output
let index_1 = output.get(0).get(4).argmax(0, false);
let index_2 = output.get(1).get(7).argmax(0, false);
let index_1 = model_output
.prediction_scores
.get(0)
.get(4)
.argmax(0, false);
let index_2 = model_output
.prediction_scores
.get(1)
.get(7)
.argmax(0, false);
let word_1 = tokenizer.vocab().id_to_token(&index_1.int64_value(&[]));
let word_2 = tokenizer.vocab().id_to_token(&index_2.int64_value(&[]));
assert_eq!(output.size(), &[2, 10, config.vocab_size]);
assert_eq!(
config.num_hidden_layers as usize,
all_hidden_states.unwrap().len()
model_output.prediction_scores.size(),
&[2, 10, config.vocab_size]
);
assert_eq!(
config.num_hidden_layers as usize,
all_attentions.unwrap().len()
model_output.all_hidden_states.unwrap().len()
);
assert_eq!(
config.num_hidden_layers as usize,
model_output.all_attentions.unwrap().len()
);
assert_eq!("thing", word_1); // Outputs "person" : "Looks like one [person] is missing"
assert_eq!("sunny", word_2); // Outputs "pear" : "It was a very nice and [sunny] day"
@ -127,16 +138,20 @@ fn electra_discriminator() -> anyhow::Result<()> {
let input_tensor = Tensor::stack(encoded_input.as_slice(), 0).to(device);
// Forward pass
let (output, _, _) =
let model_output =
no_grad(|| electra_model.forward_t(Some(input_tensor), None, None, None, None, false));
// Validate model predictions
let expected_probabilities = vec![
0.0101, 0.0030, 0.0010, 0.0018, 0.9489, 0.0067, 0.0026, 0.0017, 0.0311, 0.0101,
];
let probabilities = output.iter::<f64>().unwrap().collect::<Vec<f64>>();
let probabilities = model_output
.probabilities
.iter::<f64>()
.unwrap()
.collect::<Vec<f64>>();
assert_eq!(output.size(), &[10]);
assert_eq!(model_output.probabilities.size(), &[10]);
for (expected, pred) in probabilities.iter().zip(expected_probabilities) {
assert!((expected - pred).abs() < 1e-4);
}

View File

@ -62,7 +62,7 @@ fn gpt2_lm_model() -> anyhow::Result<()> {
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let (output, _, past, _, _) = gpt2_model
let model_output = gpt2_model
.forward_t(
&Some(input_tensor),
Cache::None,
@ -76,11 +76,16 @@ fn gpt2_lm_model() -> anyhow::Result<()> {
)
.unwrap();
let next_word_id = output.get(0).get(-1).argmax(-1, true).int64_value(&[0]);
let next_word_id = model_output
.lm_logits
.get(0)
.get(-1)
.argmax(-1, true)
.int64_value(&[0]);
let next_word = tokenizer.decode(vec![next_word_id], true, true);
assert_eq!(output.size(), vec!(1, 4, 50257));
match past {
assert_eq!(model_output.lm_logits.size(), vec!(1, 4, 50257));
match model_output.cache {
Cache::GPT2Cache(past) => {
assert!(past.is_some());
assert_eq!(past.as_ref().unwrap().len(), config.n_layer as usize);
@ -92,7 +97,13 @@ fn gpt2_lm_model() -> anyhow::Result<()> {
_ => panic!("Wrong cache returned for GPT2"),
}
assert!(
(output.double_value(&[0, output.size()[1] - 1, next_word_id]) - (-69.4948)).abs() < 1e-4
(model_output.lm_logits.double_value(&[
0,
model_output.lm_logits.size()[1] - 1,
next_word_id
]) - (-69.4948))
.abs()
< 1e-4
);
assert_eq!(next_word_id, 1936i64);
assert_eq!(next_word, String::from(" five"));

View File

@ -64,7 +64,7 @@ fn openai_gpt_lm_model() -> anyhow::Result<()> {
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let (output, _, _, _, _) = openai_gpt
let model_output = openai_gpt
.forward_t(
&Some(input_tensor),
Cache::None,
@ -78,12 +78,23 @@ fn openai_gpt_lm_model() -> anyhow::Result<()> {
)
.unwrap();
let next_word_id = output.get(0).get(-1).argmax(-1, true).int64_value(&[0]);
let next_word_id = model_output
.lm_logits
.get(0)
.get(-1)
.argmax(-1, true)
.int64_value(&[0]);
let next_word = tokenizer.decode(vec![next_word_id], true, true);
assert_eq!(output.size(), vec!(1, 6, 40478));
assert_eq!(model_output.lm_logits.size(), vec!(1, 6, 40478));
assert!(
(output.double_value(&[0, output.size()[1] - 1, next_word_id]) - (9.1056)).abs() < 1e-4
(model_output.lm_logits.double_value(&[
0,
model_output.lm_logits.size()[1] - 1,
next_word_id
]) - (9.1056))
.abs()
< 1e-4
);
assert_eq!(next_word_id, 580i64);
assert_eq!(next_word, String::from("be"));

View File

@ -80,7 +80,7 @@ fn roberta_masked_lm() -> anyhow::Result<()> {
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let (output, _, _) = no_grad(|| {
let model_output = no_grad(|| {
roberta_model.forward_t(
Some(input_tensor),
None,
@ -94,8 +94,16 @@ fn roberta_masked_lm() -> anyhow::Result<()> {
});
// Print masked tokens
let index_1 = output.get(0).get(4).argmax(0, false);
let index_2 = output.get(1).get(5).argmax(0, false);
let index_1 = model_output
.prediction_scores
.get(0)
.get(4)
.argmax(0, false);
let index_2 = model_output
.prediction_scores
.get(1)
.get(5)
.argmax(0, false);
let word_1 = tokenizer.vocab().id_to_token(&index_1.int64_value(&[]));
let word_2 = tokenizer.vocab().id_to_token(&index_2.int64_value(&[]));
@ -164,17 +172,17 @@ fn roberta_for_sequence_classification() -> anyhow::Result<()> {
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let (output, all_hidden_states, all_attentions) =
let model_output =
no_grad(|| roberta_model.forward_t(Some(input_tensor), None, None, None, None, false));
assert_eq!(output.size(), &[2, 3]);
assert_eq!(model_output.logits.size(), &[2, 3]);
assert_eq!(
config.num_hidden_layers as usize,
all_hidden_states.unwrap().len()
model_output.all_hidden_states.unwrap().len()
);
assert_eq!(
config.num_hidden_layers as usize,
all_attentions.unwrap().len()
model_output.all_attentions.unwrap().len()
);
Ok(())
@ -236,17 +244,16 @@ fn roberta_for_multiple_choice() -> anyhow::Result<()> {
.unsqueeze(0);
// Forward pass
let (output, all_hidden_states, all_attentions) =
no_grad(|| roberta_model.forward_t(input_tensor, None, None, None, false));
let model_output = no_grad(|| roberta_model.forward_t(input_tensor, None, None, None, false));
assert_eq!(output.size(), &[1, 2]);
assert_eq!(model_output.logits.size(), &[1, 2]);
assert_eq!(
config.num_hidden_layers as usize,
all_hidden_states.unwrap().len()
model_output.all_hidden_states.unwrap().len()
);
assert_eq!(
config.num_hidden_layers as usize,
all_attentions.unwrap().len()
model_output.all_attentions.unwrap().len()
);
Ok(())
@ -312,17 +319,17 @@ fn roberta_for_token_classification() -> anyhow::Result<()> {
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let (output, all_hidden_states, all_attentions) =
let model_output =
no_grad(|| roberta_model.forward_t(Some(input_tensor), None, None, None, None, false));
assert_eq!(output.size(), &[2, 9, 4]);
assert_eq!(model_output.logits.size(), &[2, 9, 4]);
assert_eq!(
config.num_hidden_layers as usize,
all_hidden_states.unwrap().len()
model_output.all_hidden_states.unwrap().len()
);
assert_eq!(
config.num_hidden_layers as usize,
all_attentions.unwrap().len()
model_output.all_attentions.unwrap().len()
);
Ok(())
@ -357,7 +364,7 @@ fn roberta_question_answering() -> anyhow::Result<()> {
let context = String::from("Amy lives in Amsterdam");
let qa_input = QaInput { question, context };
let answers = qa_model.predict(&vec![qa_input], 1, 32);
let answers = qa_model.predict(&[qa_input], 1, 32);
assert_eq!(answers.len(), 1 as usize);
assert_eq!(answers[0].len(), 1 as usize);

View File

@ -39,14 +39,14 @@ for k, v in weights.items():
k = k.replace("gamma", "weight").replace("beta", "bias")
nps[k] = np.ascontiguousarray(v.cpu().numpy())
# np.savez(target_path / 'model.npz', **nps)
#
# source = str(target_path / 'model.npz')
# target = str(target_path / 'model.ot')
#
# toml_location = (Path(__file__).resolve() / '..' / '..' / 'Cargo.toml').resolve()
#
# subprocess.call(['cargo', 'run', '--bin=convert-tensor', '--manifest-path=%s' % toml_location, '--', source, target])
#
# os.remove(str(target_path / 'model.bin'))
# os.remove(str(target_path / 'model.npz'))
np.savez(target_path / 'model.npz', **nps)
source = str(target_path / 'model.npz')
target = str(target_path / 'model.ot')
toml_location = (Path(__file__).resolve() / '..' / '..' / 'Cargo.toml').resolve()
subprocess.call(['cargo', 'run', '--bin=convert-tensor', '--manifest-path=%s' % toml_location, '--', source, target])
os.remove(str(target_path / 'model.bin'))
os.remove(str(target_path / 'model.npz'))