Updated Bart & Marian (clippy warnings)

Guillaume B 2020-09-12 16:17:58 +02:00
parent daa6dba2d2
commit 7e9d6d7e39
20 changed files with 327 additions and 354 deletions

clippy.toml (new file, +1 line)
View File

@ -0,0 +1 @@
too-many-arguments-threshold = 10
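Context for the new clippy.toml above: clippy's too_many_arguments lint fires on functions whose parameter count exceeds a threshold (7 by default), and several forward_t signatures in this crate legitimately take more. The entry above raises the limit to 10 project-wide; a minimal sketch of the per-item alternative, using a made-up function name for illustration:

// Hypothetical example: 8 parameters would trip the default threshold of 7.
// The attribute silences the lint for this item only, whereas the clippy.toml
// entry above raises the limit for the whole crate.
#[allow(clippy::too_many_arguments)]
fn forward_like(a: i64, b: i64, c: i64, d: i64, e: i64, f: i64, g: i64, h: i64) -> i64 {
    a + b + c + d + e + f + g + h
}

fn main() {
    println!("{}", forward_like(1, 2, 3, 4, 5, 6, 7, 8));
}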

View File

@ -73,12 +73,15 @@ fn main() -> anyhow::Result<()> {
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let (decoder_output, encoder_output, _, _, _, _, _) =
let model_output =
no_grad(|| bart_model.forward_t(Some(&input_tensor), None, None, None, None, None, false));
// Print masked tokens
println!("{:?}", encoder_output);
println!("{:?}", decoder_output);
println!("{:?}", decoder_output.double_value(&[0, 0, 0]));
println!("{:?}", model_output.encoder_hidden_state);
println!("{:?}", model_output.decoder_hidden_state);
println!(
"{:?}",
model_output.decoder_hidden_state.double_value(&[0, 0, 0])
);
Ok(())
}

View File

@ -70,7 +70,7 @@ fn main() -> anyhow::Result<()> {
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let (output, _, _, _, _) = gpt2_model
let model_output = gpt2_model
.forward_t(
&Some(input_tensor),
Cache::None,
@ -84,7 +84,12 @@ fn main() -> anyhow::Result<()> {
)
.unwrap();
let next_word_id = output.get(0).get(-1).argmax(-1, true).int64_value(&[0]);
let next_word_id = model_output
.lm_logits
.get(0)
.get(-1)
.argmax(-1, true)
.int64_value(&[0]);
let next_word = tokenizer.decode(vec![next_word_id], true, true);
println!("Provided input: {}", input[0]);
println!("Next word: {}", next_word);

View File

@ -75,7 +75,7 @@ fn main() -> anyhow::Result<()> {
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let (output, _, _, _, _) = openai_gpt
let model_output = openai_gpt
.forward_t(
&Some(input_tensor),
Cache::None,
@ -89,7 +89,12 @@ fn main() -> anyhow::Result<()> {
)
.unwrap();
let next_word_id = output.get(0).get(-1).argmax(-1, true).int64_value(&[0]);
let next_word_id = model_output
.lm_logits
.get(0)
.get(-1)
.argmax(-1, true)
.int64_value(&[0]);
let next_word = tokenizer.decode(vec![next_word_id], true, true);
println!("Provided input: {}", input[0]);
println!("Next word: {}", next_word);

View File

@ -13,9 +13,9 @@
use crate::bart::attention::LayerState;
use crate::bart::decoder::BartDecoder;
use crate::bart::encoder::BartEncoder;
use crate::bart::encoder::{BartEncoder, BartEncoderOutput};
use crate::common::dropout::Dropout;
use crate::pipelines::generation::{Cache, LMHeadModel};
use crate::pipelines::generation::{Cache, LMHeadModel, LMModelOutput};
use crate::Config;
use serde::{Deserialize, Serialize};
use std::borrow::Borrow;
@ -346,15 +346,7 @@ impl BartModel {
/// let decoder_attention_mask =
/// Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
///
/// let (
/// decoder_output,
/// encoder_hidden_states,
/// decoder_cache,
/// all_encoder_hidden_states,
/// all_encoder_attentions,
/// all_decoder_hidden_states,
/// all_decoder_attentions,
/// ) = no_grad(|| {
/// let model_output = no_grad(|| {
/// bart_model.forward_t(
/// Some(&input_tensor),
/// Some(&encoder_attention_mask),
@ -371,19 +363,11 @@ impl BartModel {
input_ids: Option<&Tensor>,
attention_mask: Option<&Tensor>,
decoder_input_ids: Option<&Tensor>,
encoder_outputs: Option<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>)>,
encoder_output: Option<BartEncoderOutput>,
decoder_attention_mask: Option<&Tensor>,
layer_states: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
train: bool,
) -> (
Tensor,
Tensor,
Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
) {
) -> BartModelOutput {
let (decoder_input_ids, decoder_padding_mask, causal_mask) = if self.generation_mode {
(decoder_input_ids.unwrap().copy(), None, None)
} else {
@ -398,43 +382,37 @@ impl BartModel {
decoder_attention_mask,
)
};
let (encoder_hidden_states, all_encoder_hidden_states, all_encoder_attentions) =
match encoder_outputs {
Some(value) => value,
None => {
assert!(
input_ids.is_some(),
"input_ids must be provided when encoder output is not pre-computed"
);
self.encoder.forward_t(
input_ids.unwrap(),
attention_mask,
&self.embeddings,
train,
)
}
};
let encoder_output = match encoder_output {
Some(value) => value,
None => {
assert!(
input_ids.is_some(),
"input_ids must be provided when encoder output is not pre-computed"
);
self.encoder
.forward_t(input_ids.unwrap(), attention_mask, &self.embeddings, train)
}
};
let (decoder_outputs, decoder_cache, all_decoder_hidden_states, all_decoder_attentions) =
self.decoder.forward_t(
&decoder_input_ids,
&encoder_hidden_states,
attention_mask,
decoder_padding_mask.as_ref(),
causal_mask.as_ref(),
&self.embeddings,
layer_states,
train,
);
(
decoder_outputs,
encoder_hidden_states,
decoder_cache.1,
all_decoder_hidden_states,
all_decoder_attentions,
all_encoder_hidden_states,
all_encoder_attentions,
)
let decoder_output = self.decoder.forward_t(
&decoder_input_ids,
&encoder_output.hidden_state,
attention_mask,
decoder_padding_mask.as_ref(),
causal_mask.as_ref(),
&self.embeddings,
layer_states,
train,
);
BartModelOutput {
decoder_hidden_state: decoder_output.hidden_state,
encoder_hidden_state: encoder_output.hidden_state,
cache: decoder_output.next_decoder_cache,
all_decoder_hidden_states: decoder_output.all_hidden_states,
all_decoder_attentions: decoder_output.all_attentions,
all_encoder_hidden_states: encoder_output.all_hidden_states,
all_encoder_attentions: encoder_output.all_attentions,
}
}
}
@ -525,9 +503,7 @@ impl BartForConditionalGeneration {
/// let encoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
/// let decoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
///
/// let (decoder_output, encoder_hidden_states, cache,
/// all_encoder_hidden_states, all_encoder_attentions,
/// all_decoder_hidden_states, all_decoder_attentions) = no_grad(|| {
/// let model_output = no_grad(|| {
/// bart_model
/// .forward_t(Some(&input_tensor),
/// Some(&encoder_attention_mask),
@ -542,58 +518,41 @@ impl BartForConditionalGeneration {
&self,
input_ids: Option<&Tensor>,
attention_mask: Option<&Tensor>,
encoder_outputs: Option<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>)>,
encoder_output: Option<BartEncoderOutput>,
decoder_input_ids: Option<&Tensor>,
decoder_attention_mask: Option<&Tensor>,
old_layer_states: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
train: bool,
) -> (
Tensor,
Tensor,
Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
) {
let (
decoder_outputs,
encoder_hidden_states,
decoder_cache,
all_decoder_hidden_states,
all_decoder_attentions,
all_encoder_hidden_states,
all_encoder_attentions,
) = self.base_model.forward_t(
) -> BartModelOutput {
let base_model_output = self.base_model.forward_t(
input_ids,
attention_mask,
decoder_input_ids,
encoder_outputs,
encoder_output,
decoder_attention_mask,
old_layer_states,
train,
);
let lm_logits = decoder_outputs.linear::<Tensor>(&self.base_model.embeddings.ws, None);
(
lm_logits,
encoder_hidden_states,
decoder_cache,
all_decoder_hidden_states,
all_decoder_attentions,
all_encoder_hidden_states,
all_encoder_attentions,
)
let lm_logits = base_model_output
.decoder_hidden_state
.linear::<Tensor>(&self.base_model.embeddings.ws, None);
BartModelOutput {
decoder_hidden_state: lm_logits,
..base_model_output
}
}
pub fn encode(&self, input_ids: &Tensor, attention_mask: Option<&Tensor>) -> Tensor {
let (encoder_hidden_states, _, _) = self.base_model.encoder.forward_t(
input_ids,
attention_mask,
&self.base_model.embeddings,
false,
);
encoder_hidden_states
self.base_model
.encoder
.forward_t(
input_ids,
attention_mask,
&self.base_model.embeddings,
false,
)
.hidden_state
}
}
@ -740,9 +699,7 @@ impl BartForSequenceClassification {
/// let encoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
/// let decoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
///
/// let (decoder_output, encoder_hidden_states, cache,
/// all_encoder_hidden_states, all_encoder_attentions,
/// all_decoder_hidden_states, all_decoder_attentions) = no_grad(|| {
/// let model_output = no_grad(|| {
/// bart_model
/// .forward_t(Some(&input_tensor),
/// Some(&encoder_attention_mask),
@ -757,60 +714,51 @@ impl BartForSequenceClassification {
&self,
input_ids: &Tensor,
attention_mask: Option<&Tensor>,
encoder_outputs: Option<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>)>,
encoder_output: Option<BartEncoderOutput>,
decoder_input_ids: Option<&Tensor>,
decoder_attention_mask: Option<&Tensor>,
train: bool,
) -> (
Tensor,
Tensor,
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
) {
let (
decoder_outputs,
encoder_hidden_states,
_,
all_decoder_hidden_states,
all_decoder_attentions,
all_encoder_hidden_states,
all_encoder_attentions,
) = self.base_model.forward_t(
) -> BartModelOutput {
let base_model_output = self.base_model.forward_t(
Some(input_ids),
attention_mask,
decoder_input_ids,
encoder_outputs,
encoder_output,
decoder_attention_mask,
None,
train,
);
let eos_mask = input_ids.eq(self.eos_token_id);
let reshape = eos_mask.sum1(&[1], true, Int64);
let sentence_representation = decoder_outputs
let sentence_representation = base_model_output
.decoder_hidden_state
.permute(&[2, 0, 1])
.masked_select(&eos_mask)
.view((-1, reshape.size()[0] * reshape.int64_value(&[0, 0])))
.transpose(0, 1)
.view((
decoder_outputs.size()[0],
base_model_output.decoder_hidden_state.size()[0],
-1,
*decoder_outputs.size().last().unwrap(),
*base_model_output
.decoder_hidden_state
.size()
.last()
.unwrap(),
))
.select(1, -1);
let logits = self
.classification_head
.forward_t(&sentence_representation, train);
(
logits,
encoder_hidden_states,
all_decoder_hidden_states,
all_decoder_attentions,
all_encoder_hidden_states,
all_encoder_attentions,
)
BartModelOutput {
decoder_hidden_state: logits,
encoder_hidden_state: base_model_output.encoder_hidden_state,
cache: None,
all_decoder_hidden_states: base_model_output.all_decoder_hidden_states,
all_decoder_attentions: base_model_output.all_decoder_attentions,
all_encoder_hidden_states: base_model_output.all_encoder_hidden_states,
all_encoder_attentions: base_model_output.all_encoder_attentions,
}
}
}
@ -860,9 +808,7 @@ impl LMHeadModel for BartForConditionalGeneration {
/// let encoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
/// let decoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
///
/// let (decoder_output, encoder_hidden_states, cache,
/// all_encoder_hidden_states, all_encoder_attentions,
/// all_decoder_hidden_states, all_decoder_attentions) = no_grad(|| {
/// let model_output = no_grad(|| {
/// bart_model
/// .forward_t(Some(&input_tensor),
/// Some(&encoder_attention_mask),
@ -884,22 +830,17 @@ impl LMHeadModel for BartForConditionalGeneration {
encoder_outputs: Option<&Tensor>,
decoder_input_ids: &Option<Tensor>,
train: bool,
) -> Result<
(
Tensor,
Option<Tensor>,
Cache,
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
),
&'static str,
> {
let (decoder_output, encoder_hidden_states, new_cache, _, _, _, _) = match cache {
) -> Result<LMModelOutput, &'static str> {
let base_model_output = match cache {
Cache::BARTCache(cached_layer_states) => self.base_model.forward_t(
input_ids.as_ref(),
attention_mask.as_ref(),
decoder_input_ids.as_ref(),
Some((encoder_outputs.as_ref().unwrap().copy(), None, None)),
Some(BartEncoderOutput {
hidden_state: encoder_outputs.as_ref().unwrap().copy(),
all_hidden_states: None,
all_attentions: None,
}),
None,
cached_layer_states,
train,
@ -909,21 +850,37 @@ impl LMHeadModel for BartForConditionalGeneration {
input_ids.as_ref(),
attention_mask.as_ref(),
decoder_input_ids.as_ref(),
Some((encoder_outputs.as_ref().unwrap().copy(), None, None)),
Some(BartEncoderOutput {
hidden_state: encoder_outputs.as_ref().unwrap().copy(),
all_hidden_states: None,
all_attentions: None,
}),
None,
None,
train,
),
_ => Err("Cache not compatible with BART Model")?,
_ => return Err("Cache not compatible with BART Model"),
};
let lm_logits = decoder_output.linear::<Tensor>(&self.base_model.embeddings.ws, None);
Ok((
let lm_logits = base_model_output
.decoder_hidden_state
.linear::<Tensor>(&self.base_model.embeddings.ws, None);
Ok(LMModelOutput {
lm_logits,
Some(encoder_hidden_states),
Cache::BARTCache(new_cache),
None,
None,
))
encoder_hidden_state: Some(base_model_output.encoder_hidden_state),
cache: Cache::BARTCache(base_model_output.cache),
all_hidden_states: None,
all_attentions: None,
})
}
}
pub struct BartModelOutput {
pub decoder_hidden_state: Tensor,
pub encoder_hidden_state: Tensor,
pub cache: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
pub all_decoder_hidden_states: Option<Vec<Tensor>>,
pub all_decoder_attentions: Option<Vec<Tensor>>,
pub all_encoder_hidden_states: Option<Vec<Tensor>>,
pub all_encoder_attentions: Option<Vec<Tensor>>,
}
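BartForConditionalGeneration::forward_t above fills this struct with `..base_model_output`, Rust's struct update syntax: every field not listed explicitly is moved over from the base value. A simplified, self-contained sketch with stand-in field types (no tch dependency):

struct Output {
    decoder_hidden_state: f64,
    encoder_hidden_state: f64,
    cache: Option<u32>,
}

fn main() {
    let base = Output {
        decoder_hidden_state: 0.0,
        encoder_hidden_state: 1.0,
        cache: None,
    };
    // Only decoder_hidden_state is replaced; the remaining fields are moved from `base`.
    let updated = Output {
        decoder_hidden_state: 42.0,
        ..base
    };
    assert_eq!(updated.encoder_hidden_state, 1.0);
}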

View File

@ -288,15 +288,7 @@ impl BartDecoder {
embeddings: &nn::Embedding,
old_layer_states: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
train: bool,
) -> (
Tensor,
(
Option<Tensor>,
Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
),
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
) {
) -> BartDecoderOutput {
let encoder_padding_mask = match encoder_padding_mask {
Some(mask) => Some(mask.eq(0).to_kind(Bool)),
None => None,
@ -342,45 +334,48 @@ impl BartDecoder {
};
let encoder_hidden_states = encoder_hidden_states.transpose(0, 1);
let mut attention_weights: Option<Tensor>;
let mut layers = self.layers.iter().enumerate();
loop {
match layers.next() {
Some((layer_idx, layer)) => {
let layer_state = match &next_decoder_cache {
Some(values) => values[layer_idx].to_owned(),
None => (None, None),
};
let temp = layer.forward_t(
&hidden_state,
&encoder_hidden_states,
encoder_padding_mask.as_ref(),
decoder_causal_mask,
decoder_padding_mask,
layer_state,
train,
);
hidden_state = temp.0;
attention_weights = temp.1;
if let Some(hidden_states) = all_hidden_states.borrow_mut() {
hidden_states.push(hidden_state.as_ref().copy().transpose(0, 1));
};
if let Some(attentions) = all_attentions.borrow_mut() {
attentions.push(attention_weights.as_ref().unwrap().copy());
};
if let Some(value) = &mut next_decoder_cache {
value[layer_idx] = temp.2
};
}
None => break,
for (layer_idx, layer) in self.layers.iter().enumerate() {
let layer_state = match &next_decoder_cache {
Some(values) => values[layer_idx].to_owned(),
None => (None, None),
};
let temp = layer.forward_t(
&hidden_state,
&encoder_hidden_states,
encoder_padding_mask.as_ref(),
decoder_causal_mask,
decoder_padding_mask,
layer_state,
train,
);
hidden_state = temp.0;
attention_weights = temp.1;
if let Some(hidden_states) = all_hidden_states.borrow_mut() {
hidden_states.push(hidden_state.as_ref().copy().transpose(0, 1));
};
if let Some(attentions) = all_attentions.borrow_mut() {
attentions.push(attention_weights.as_ref().unwrap().copy());
};
if let Some(value) = &mut next_decoder_cache {
value[layer_idx] = temp.2
};
}
(
hidden_state.transpose(0, 1),
(encoder_padding_mask, next_decoder_cache),
BartDecoderOutput {
hidden_state: hidden_state.transpose(0, 1),
encoder_padding_mask,
next_decoder_cache,
all_hidden_states,
all_attentions,
)
}
}
}
pub struct BartDecoderOutput {
pub hidden_state: Tensor,
pub encoder_padding_mask: Option<Tensor>,
pub next_decoder_cache: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
pub all_hidden_states: Option<Vec<Tensor>>,
pub all_attentions: Option<Vec<Tensor>>,
}

View File

@ -232,7 +232,7 @@ impl BartEncoder {
attention_mask: Option<&Tensor>,
embeddings: &nn::Embedding,
train: bool,
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
) -> BartEncoderOutput {
let attention_mask = match attention_mask {
Some(mask) => Some(mask.eq(0).to_kind(Bool)),
None => None,
@ -260,33 +260,34 @@ impl BartEncoder {
let mut hidden_state = x.copy();
let mut attention_weights: Option<Tensor>;
let mut layers = self.layers.iter();
loop {
match layers.next() {
Some(layer) => {
if let Some(hidden_states) = all_hidden_states.borrow_mut() {
hidden_states.push(hidden_state.as_ref().copy().transpose(0, 1));
};
for layer in &self.layers {
if let Some(hidden_states) = all_hidden_states.borrow_mut() {
hidden_states.push(hidden_state.as_ref().copy().transpose(0, 1));
};
let temp = layer.forward_t(&hidden_state, attention_mask.as_ref(), train);
hidden_state = temp.0;
attention_weights = temp.1;
if let Some(attentions) = all_attentions.borrow_mut() {
attentions.push(attention_weights.as_ref().unwrap().copy());
};
}
None => break,
let temp = layer.forward_t(&hidden_state, attention_mask.as_ref(), train);
hidden_state = temp.0;
attention_weights = temp.1;
if let Some(attentions) = all_attentions.borrow_mut() {
attentions.push(attention_weights.as_ref().unwrap().copy());
};
}
if let Some(hidden_states) = all_hidden_states.borrow_mut() {
hidden_states.push(hidden_state.as_ref().copy().transpose(0, 1));
};
(
hidden_state.transpose(0, 1),
BartEncoderOutput {
hidden_state: hidden_state.transpose(0, 1),
all_hidden_states,
all_attentions,
)
}
}
}
pub struct BartEncoderOutput {
pub hidden_state: Tensor,
pub all_hidden_states: Option<Vec<Tensor>>,
pub all_attentions: Option<Vec<Tensor>>,
}

View File

@ -69,3 +69,5 @@ pub use bart_model::{
BartForSequenceClassification, BartMergesResources, BartModel, BartModelResources,
BartVocabResources,
};
pub(crate) use encoder::BartEncoderOutput;

View File

@ -86,7 +86,7 @@ impl BertSelfAttention {
fn flatten(&self, x: Tensor, bs: i64, dim_per_head: i64) -> Tensor {
x.transpose(1, 2)
.contiguous()
.view((bs, -1, &self.num_attention_heads * dim_per_head))
.view((bs, -1, self.num_attention_heads * dim_per_head))
}
pub fn forward_t(

View File

@ -15,7 +15,7 @@
use crate::common::dropout::Dropout;
use crate::common::linear::{linear_no_bias, LinearNoBias};
use crate::gpt2::transformer::Block;
use crate::pipelines::generation::{Cache, LMHeadModel};
use crate::pipelines::generation::{Cache, LMHeadModel, LMModelOutput};
use crate::Config;
use serde::{Deserialize, Serialize};
use std::borrow::{Borrow, BorrowMut};
@ -608,7 +608,7 @@ impl LMHeadModel for GPT2LMHeadModel {
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
/// .expand(&[batch_size, sequence_length], true);
///
/// let (output, _, past, hidden_states, attentions) = no_grad(|| {
/// let model_output = no_grad(|| {
/// gpt2_model
/// .forward_t(
/// &Some(input_tensor),
@ -635,16 +635,7 @@ impl LMHeadModel for GPT2LMHeadModel {
_encoder_outputs: Option<&Tensor>,
_decoder_input_ids: &Option<Tensor>,
train: bool,
) -> Result<
(
Tensor,
Option<Tensor>,
Cache,
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
),
&'static str,
> {
) -> Result<LMModelOutput, &'static str> {
let (output, past, all_hidden_states, all_attentions) = match layer_past {
Cache::GPT2Cache(layer_past) => Ok(self.transformer.forward_t(
input_ids,
@ -668,12 +659,12 @@ impl LMHeadModel for GPT2LMHeadModel {
}?;
let lm_logits = output.apply(&self.lm_head);
Ok((
Ok(LMModelOutput {
lm_logits,
None,
Cache::GPT2Cache(past),
encoder_hidden_state: None,
cache: Cache::GPT2Cache(past),
all_hidden_states,
all_attentions,
))
})
}
}

View File

@ -11,8 +11,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use crate::bart::{BartConfig, BartModel, LayerState};
use crate::pipelines::generation::{Cache, LMHeadModel};
use crate::bart::{BartConfig, BartEncoderOutput, BartModel, LayerState};
use crate::pipelines::generation::{Cache, LMHeadModel, LMModelOutput};
use std::borrow::Borrow;
use tch::nn::Init;
use tch::{nn, Tensor};
@ -349,7 +349,7 @@ impl MarianForConditionalGeneration {
&self,
input_ids: Option<&Tensor>,
attention_mask: Option<&Tensor>,
encoder_outputs: Option<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>)>,
encoder_outputs: Option<BartEncoderOutput>,
decoder_input_ids: Option<&Tensor>,
decoder_attention_mask: Option<&Tensor>,
old_layer_states: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
@ -363,15 +363,7 @@ impl MarianForConditionalGeneration {
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
) {
let (
decoder_outputs,
encoder_hidden_states,
decoder_cache,
all_decoder_hidden_states,
all_decoder_attentions,
all_encoder_hidden_states,
all_encoder_attentions,
) = self.base_model.forward_t(
let base_model_output = self.base_model.forward_t(
input_ids,
attention_mask,
decoder_input_ids,
@ -381,25 +373,31 @@ impl MarianForConditionalGeneration {
train,
);
let lm_logits = decoder_outputs.linear::<Tensor>(&self.base_model.embeddings.ws, None);
let lm_logits = base_model_output
.decoder_hidden_state
.linear::<Tensor>(&self.base_model.embeddings.ws, None);
(
lm_logits,
encoder_hidden_states,
decoder_cache,
all_decoder_hidden_states,
all_decoder_attentions,
all_encoder_hidden_states,
all_encoder_attentions,
base_model_output.encoder_hidden_state,
base_model_output.cache,
base_model_output.all_decoder_hidden_states,
base_model_output.all_decoder_attentions,
base_model_output.all_encoder_hidden_states,
base_model_output.all_encoder_attentions,
)
}
pub fn encode(&self, input_ids: &Tensor, attention_mask: Option<&Tensor>) -> Tensor {
let (encoder_hidden_states, _, _) = self.base_model.encoder.forward_t(
input_ids,
attention_mask,
&self.base_model.embeddings,
false,
);
let encoder_hidden_states = self
.base_model
.encoder
.forward_t(
input_ids,
attention_mask,
&self.base_model.embeddings,
false,
)
.hidden_state;
encoder_hidden_states
}
}
@ -483,22 +481,17 @@ impl LMHeadModel for MarianForConditionalGeneration {
encoder_outputs: Option<&Tensor>,
decoder_input_ids: &Option<Tensor>,
train: bool,
) -> Result<
(
Tensor,
Option<Tensor>,
Cache,
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
),
&'static str,
> {
let (decoder_output, encoder_hidden_states, new_cache, _, _, _, _) = match cache {
) -> Result<LMModelOutput, &'static str> {
let base_model_output = match cache {
Cache::BARTCache(cached_layer_states) => self.base_model.forward_t(
input_ids.as_ref(),
attention_mask.as_ref(),
decoder_input_ids.as_ref(),
Some((encoder_outputs.as_ref().unwrap().copy(), None, None)),
Some(BartEncoderOutput {
hidden_state: encoder_outputs.as_ref().unwrap().copy(),
all_hidden_states: None,
all_attentions: None,
}),
None,
cached_layer_states,
train,
@ -507,7 +500,11 @@ impl LMHeadModel for MarianForConditionalGeneration {
input_ids.as_ref(),
attention_mask.as_ref(),
decoder_input_ids.as_ref(),
Some((encoder_outputs.as_ref().unwrap().copy(), None, None)),
Some(BartEncoderOutput {
hidden_state: encoder_outputs.as_ref().unwrap().copy(),
all_hidden_states: None,
all_attentions: None,
}),
None,
None,
train,
@ -515,14 +512,16 @@ impl LMHeadModel for MarianForConditionalGeneration {
_ => Err("Cache not compatible with Marian Model")?,
};
let lm_logits = decoder_output.linear::<Tensor>(&self.base_model.embeddings.ws, None)
let lm_logits = base_model_output
.decoder_hidden_state
.linear::<Tensor>(&self.base_model.embeddings.ws, None)
+ &self.final_logits_bias;
Ok((
Ok(LMModelOutput {
lm_logits,
Some(encoder_hidden_states),
Cache::BARTCache(new_cache),
None,
None,
))
encoder_hidden_state: Some(base_model_output.encoder_hidden_state),
cache: Cache::BARTCache(base_model_output.cache),
all_hidden_states: None,
all_attentions: None,
})
}
}

View File

@ -16,7 +16,7 @@ use crate::common::dropout::Dropout;
use crate::common::linear::{linear_no_bias, LinearNoBias};
use crate::gpt2::Gpt2Config;
use crate::openai_gpt::transformer::Block;
use crate::pipelines::generation::{Cache, LMHeadModel};
use crate::pipelines::generation::{Cache, LMHeadModel, LMModelOutput};
use std::borrow::{Borrow, BorrowMut};
use tch::kind::Kind::Int64;
use tch::nn::embedding;
@ -388,7 +388,7 @@ impl LMHeadModel for OpenAIGPTLMHeadModel {
/// let token_type_ids = Tensor::ones(&[batch_size, sequence_length], (Int64, device));
/// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true);
///
/// let (output, _, _, hidden_states, attentions) = no_grad(|| {
/// let model_output = no_grad(|| {
/// gpt_model
/// .forward_t(&Some(input_tensor),
/// Cache::None,
@ -412,16 +412,7 @@ impl LMHeadModel for OpenAIGPTLMHeadModel {
_encoder_outputs: Option<&Tensor>,
_decoder_input_ids: &Option<Tensor>,
train: bool,
) -> Result<
(
Tensor,
Option<Tensor>,
Cache,
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
),
&'static str,
> {
) -> Result<LMModelOutput, &'static str> {
let (output, all_hidden_states, all_attentions) = self.transformer.forward_t(
input_ids,
attention_mask,
@ -432,12 +423,12 @@ impl LMHeadModel for OpenAIGPTLMHeadModel {
)?;
let lm_logits = output.apply(&self.lm_head);
Ok((
Ok(LMModelOutput {
lm_logits,
None,
Cache::None,
encoder_hidden_state: None,
cache: Cache::None,
all_hidden_states,
all_attentions,
))
})
}
}

View File

@ -1648,8 +1648,8 @@ pub(crate) mod private_generation_utils {
false,
)
.unwrap();
outputs = temp.0;
past = temp.2;
outputs = temp.lm_logits;
past = temp.cache;
let mut next_token_logits = outputs.select(1, -1);
// Reduce probability for repeated inputs
@ -1854,8 +1854,8 @@ pub(crate) mod private_generation_utils {
false,
)
.unwrap();
outputs = temp.0;
past = temp.2;
outputs = temp.lm_logits;
past = temp.cache;
let mut next_token_logits = outputs.select(1, -1);
@ -2575,7 +2575,7 @@ pub trait LMHeadModel {
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
/// .expand(&[batch_size, sequence_length], true);
///
/// let (output, encoder_output, past, hidden_states, attentions) = no_grad(|| {
/// let model_output = no_grad(|| {
/// gpt2_model
/// .forward_t(
/// &Some(input_tensor),
@ -2602,14 +2602,13 @@ pub trait LMHeadModel {
encoder_outputs: Option<&Tensor>,
decoder_input_ids: &Option<Tensor>,
train: bool,
) -> Result<
(
Tensor,
Option<Tensor>,
Cache,
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
),
&'static str,
>;
) -> Result<LMModelOutput, &'static str>;
}
pub struct LMModelOutput {
pub lm_logits: Tensor,
pub encoder_hidden_state: Option<Tensor>,
pub cache: Cache,
pub all_hidden_states: Option<Vec<Tensor>>,
pub all_attentions: Option<Vec<Tensor>>,
}
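A simplified sketch (stand-in types, no tch) of the pattern this commit moves LMHeadModel to: returning a named output struct instead of a positional tuple, so callers such as the generation loop read out.lm_logits / out.cache rather than temp.0 / temp.2:

struct LmOutput {
    lm_logits: Vec<f32>,
    cache: Option<Vec<f32>>,
}

trait LmHead {
    fn forward(&self, input_ids: &[i64]) -> Result<LmOutput, &'static str>;
}

struct DummyModel;

impl LmHead for DummyModel {
    fn forward(&self, input_ids: &[i64]) -> Result<LmOutput, &'static str> {
        // A real model would run a forward pass; here the ids are echoed as logits.
        Ok(LmOutput {
            lm_logits: input_ids.iter().map(|&id| id as f32).collect(),
            cache: None,
        })
    }
}

fn main() {
    let out = DummyModel.forward(&[1, 2, 3]).unwrap();
    println!("logits: {:?}, cache: {:?}", out.lm_logits, out.cache);
}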

View File

@ -301,7 +301,7 @@ impl SequenceClassificationOption {
None,
train,
)
.0
.decoder_hidden_state
}
Self::Bert(ref model) => {
model

View File

@ -332,7 +332,7 @@ impl ZeroShotClassificationOption {
None,
train,
)
.0
.decoder_hidden_state
}
Self::Bert(ref model) => {
model

View File

@ -9,7 +9,7 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use crate::pipelines::generation::{Cache, LMHeadModel};
use crate::pipelines::generation::{Cache, LMHeadModel, LMModelOutput};
use crate::t5::attention::LayerState;
use crate::t5::encoder::T5Stack;
use crate::Config;
@ -684,16 +684,7 @@ impl LMHeadModel for T5ForConditionalGeneration {
encoder_outputs: Option<&Tensor>,
decoder_input_ids: &Option<Tensor>,
train: bool,
) -> Result<
(
Tensor,
Option<Tensor>,
Cache,
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
),
&'static str,
> {
) -> Result<LMModelOutput, &'static str> {
let (decoder_output, encoder_hidden_states, new_cache, _, _, _, _) = match cache {
Cache::T5Cache(cached_layer_states) => self.base_model.forward_t(
input_ids.as_ref(),
@ -723,12 +714,12 @@ impl LMHeadModel for T5ForConditionalGeneration {
let lm_logits = decoder_output.linear::<Tensor>(&self.base_model.embeddings.ws, None)
* (self.model_dim.powf(-0.5));
Ok((
Ok(LMModelOutput {
lm_logits,
Some(encoder_hidden_states),
Cache::T5Cache(new_cache),
None,
None,
))
encoder_hidden_state: Some(encoder_hidden_states),
cache: Cache::T5Cache(new_cache),
all_hidden_states: None,
all_attentions: None,
})
}
}

View File

@ -62,12 +62,12 @@ fn bart_lm_model() -> anyhow::Result<()> {
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let (output, encoder_outputs, _, _, _, _, _) =
let model_output =
bart_model.forward_t(Some(&input_tensor), None, None, None, None, None, false);
assert_eq!(output.size(), vec!(1, 6, 1024));
assert_eq!(encoder_outputs.size(), vec!(1, 6, 1024));
assert!((output.double_value(&[0, 0, 0]) - 0.7877).abs() < 1e-4);
assert_eq!(model_output.decoder_hidden_state.size(), vec!(1, 6, 1024));
assert_eq!(model_output.encoder_hidden_state.size(), vec!(1, 6, 1024));
assert!((model_output.decoder_hidden_state.double_value(&[0, 0, 0]) - 0.7877).abs() < 1e-4);
Ok(())
}

View File

@ -61,7 +61,7 @@ fn distilgpt2_lm_model() -> anyhow::Result<()> {
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let (output, _, past, _, _) = gpt2_model
let model_output = gpt2_model
.forward_t(
&Some(input_tensor),
Cache::None,
@ -75,11 +75,16 @@ fn distilgpt2_lm_model() -> anyhow::Result<()> {
)
.unwrap();
let next_word_id = output.get(0).get(-1).argmax(-1, true).int64_value(&[0]);
let next_word_id = model_output
.lm_logits
.get(0)
.get(-1)
.argmax(-1, true)
.int64_value(&[0]);
let next_word = tokenizer.decode(vec![next_word_id], true, true);
assert_eq!(output.size(), vec!(1, 11, 50257));
match past {
assert_eq!(model_output.lm_logits.size(), vec!(1, 11, 50257));
match model_output.cache {
Cache::GPT2Cache(past) => {
assert!(past.is_some());
assert_eq!(past.as_ref().unwrap().len(), config.n_layer as usize);
@ -91,7 +96,13 @@ fn distilgpt2_lm_model() -> anyhow::Result<()> {
_ => panic!("Wrong cache returned for GPT2"),
}
assert!(
(output.double_value(&[0, output.size()[1] - 1, next_word_id]) - (-48.7065)).abs() < 1e-4
(model_output.lm_logits.double_value(&[
0,
model_output.lm_logits.size()[1] - 1,
next_word_id
]) - (-48.7065))
.abs()
< 1e-4
);
assert_eq!(next_word_id, 14104i64);
assert_eq!(next_word, String::from(" twelve"));

View File

@ -62,7 +62,7 @@ fn gpt2_lm_model() -> anyhow::Result<()> {
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let (output, _, past, _, _) = gpt2_model
let model_output = gpt2_model
.forward_t(
&Some(input_tensor),
Cache::None,
@ -76,11 +76,16 @@ fn gpt2_lm_model() -> anyhow::Result<()> {
)
.unwrap();
let next_word_id = output.get(0).get(-1).argmax(-1, true).int64_value(&[0]);
let next_word_id = model_output
.lm_logits
.get(0)
.get(-1)
.argmax(-1, true)
.int64_value(&[0]);
let next_word = tokenizer.decode(vec![next_word_id], true, true);
assert_eq!(output.size(), vec!(1, 4, 50257));
match past {
assert_eq!(model_output.lm_logits.size(), vec!(1, 4, 50257));
match model_output.cache {
Cache::GPT2Cache(past) => {
assert!(past.is_some());
assert_eq!(past.as_ref().unwrap().len(), config.n_layer as usize);
@ -92,7 +97,13 @@ fn gpt2_lm_model() -> anyhow::Result<()> {
_ => panic!("Wrong cache returned for GPT2"),
}
assert!(
(output.double_value(&[0, output.size()[1] - 1, next_word_id]) - (-69.4948)).abs() < 1e-4
(model_output.lm_logits.double_value(&[
0,
model_output.lm_logits.size()[1] - 1,
next_word_id
]) - (-69.4948))
.abs()
< 1e-4
);
assert_eq!(next_word_id, 1936i64);
assert_eq!(next_word, String::from(" five"));

View File

@ -64,7 +64,7 @@ fn openai_gpt_lm_model() -> anyhow::Result<()> {
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let (output, _, _, _, _) = openai_gpt
let model_output = openai_gpt
.forward_t(
&Some(input_tensor),
Cache::None,
@ -78,12 +78,23 @@ fn openai_gpt_lm_model() -> anyhow::Result<()> {
)
.unwrap();
let next_word_id = output.get(0).get(-1).argmax(-1, true).int64_value(&[0]);
let next_word_id = model_output
.lm_logits
.get(0)
.get(-1)
.argmax(-1, true)
.int64_value(&[0]);
let next_word = tokenizer.decode(vec![next_word_id], true, true);
assert_eq!(output.size(), vec!(1, 6, 40478));
assert_eq!(model_output.lm_logits.size(), vec!(1, 6, 40478));
assert!(
(output.double_value(&[0, output.size()[1] - 1, next_word_id]) - (9.1056)).abs() < 1e-4
(model_output.lm_logits.double_value(&[
0,
model_output.lm_logits.size()[1] - 1,
next_word_id
]) - (9.1056))
.abs()
< 1e-4
);
assert_eq!(next_word_id, 580i64);
assert_eq!(next_word, String::from("be"));