diff --git a/clippy.toml b/clippy.toml
new file mode 100644
index 0000000..1e5837c
--- /dev/null
+++ b/clippy.toml
@@ -0,0 +1 @@
+too-many-arguments-threshold = 10
\ No newline at end of file
diff --git a/examples/bart.rs b/examples/bart.rs
index 5831934..ac18972 100644
--- a/examples/bart.rs
+++ b/examples/bart.rs
@@ -73,12 +73,15 @@ fn main() -> anyhow::Result<()> {
     let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);

     // Forward pass
-    let (decoder_output, encoder_output, _, _, _, _, _) =
+    let model_output =
         no_grad(|| bart_model.forward_t(Some(&input_tensor), None, None, None, None, None, false));

     // Print masked tokens
-    println!("{:?}", encoder_output);
-    println!("{:?}", decoder_output);
-    println!("{:?}", decoder_output.double_value(&[0, 0, 0]));
+    println!("{:?}", model_output.encoder_hidden_state);
+    println!("{:?}", model_output.decoder_hidden_state);
+    println!(
+        "{:?}",
+        model_output.decoder_hidden_state.double_value(&[0, 0, 0])
+    );

     Ok(())
 }
diff --git a/examples/gpt2.rs b/examples/gpt2.rs
index 171179c..fa281fb 100644
--- a/examples/gpt2.rs
+++ b/examples/gpt2.rs
@@ -70,7 +70,7 @@ fn main() -> anyhow::Result<()> {
     let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);

     // Forward pass
-    let (output, _, _, _, _) = gpt2_model
+    let model_output = gpt2_model
         .forward_t(
             &Some(input_tensor),
             Cache::None,
@@ -84,7 +84,12 @@
         )
         .unwrap();

-    let next_word_id = output.get(0).get(-1).argmax(-1, true).int64_value(&[0]);
+    let next_word_id = model_output
+        .lm_logits
+        .get(0)
+        .get(-1)
+        .argmax(-1, true)
+        .int64_value(&[0]);
     let next_word = tokenizer.decode(vec![next_word_id], true, true);
     println!("Provided input: {}", input[0]);
     println!("Next word: {}", next_word);
diff --git a/examples/openai_gpt.rs b/examples/openai_gpt.rs
index a51ae32..493f222 100644
--- a/examples/openai_gpt.rs
+++ b/examples/openai_gpt.rs
@@ -75,7 +75,7 @@ fn main() -> anyhow::Result<()> {
     let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);

     // Forward pass
-    let (output, _, _, _, _) = openai_gpt
+    let model_output = openai_gpt
         .forward_t(
             &Some(input_tensor),
             Cache::None,
@@ -89,7 +89,12 @@
         )
         .unwrap();

-    let next_word_id = output.get(0).get(-1).argmax(-1, true).int64_value(&[0]);
+    let next_word_id = model_output
+        .lm_logits
+        .get(0)
+        .get(-1)
+        .argmax(-1, true)
+        .int64_value(&[0]);
     let next_word = tokenizer.decode(vec![next_word_id], true, true);
     println!("Provided input: {}", input[0]);
     println!("Next word: {}", next_word);
diff --git a/src/bart/bart_model.rs b/src/bart/bart_model.rs
index fbab14f..06fa030 100644
--- a/src/bart/bart_model.rs
+++ b/src/bart/bart_model.rs
@@ -13,9 +13,9 @@
 use crate::bart::attention::LayerState;
 use crate::bart::decoder::BartDecoder;
-use crate::bart::encoder::BartEncoder;
+use crate::bart::encoder::{BartEncoder, BartEncoderOutput};
 use crate::common::dropout::Dropout;
-use crate::pipelines::generation::{Cache, LMHeadModel};
+use crate::pipelines::generation::{Cache, LMHeadModel, LMModelOutput};
 use crate::Config;
 use serde::{Deserialize, Serialize};
 use std::borrow::Borrow;
@@ -346,15 +346,7 @@ impl BartModel {
     /// let decoder_attention_mask =
     ///     Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
     ///
-    /// let (
-    ///     decoder_output,
-    ///     encoder_hidden_states,
-    ///     decoder_cache,
-    ///     all_encoder_hidden_states,
-    ///     all_encoder_attentions,
-    ///     all_decoder_hidden_states,
-    ///     all_decoder_attentions,
-    ///
) = no_grad(|| { + /// let model_output = no_grad(|| { /// bart_model.forward_t( /// Some(&input_tensor), /// Some(&encoder_attention_mask), @@ -371,19 +363,11 @@ impl BartModel { input_ids: Option<&Tensor>, attention_mask: Option<&Tensor>, decoder_input_ids: Option<&Tensor>, - encoder_outputs: Option<(Tensor, Option>, Option>)>, + encoder_output: Option, decoder_attention_mask: Option<&Tensor>, layer_states: Option, Option)>>, train: bool, - ) -> ( - Tensor, - Tensor, - Option, Option)>>, - Option>, - Option>, - Option>, - Option>, - ) { + ) -> BartModelOutput { let (decoder_input_ids, decoder_padding_mask, causal_mask) = if self.generation_mode { (decoder_input_ids.unwrap().copy(), None, None) } else { @@ -398,43 +382,37 @@ impl BartModel { decoder_attention_mask, ) }; - let (encoder_hidden_states, all_encoder_hidden_states, all_encoder_attentions) = - match encoder_outputs { - Some(value) => value, - None => { - assert!( - input_ids.is_some(), - "input_ids must be provided when encoder output is not pre-computed" - ); - self.encoder.forward_t( - input_ids.unwrap(), - attention_mask, - &self.embeddings, - train, - ) - } - }; + let encoder_output = match encoder_output { + Some(value) => value, + None => { + assert!( + input_ids.is_some(), + "input_ids must be provided when encoder output is not pre-computed" + ); + self.encoder + .forward_t(input_ids.unwrap(), attention_mask, &self.embeddings, train) + } + }; - let (decoder_outputs, decoder_cache, all_decoder_hidden_states, all_decoder_attentions) = - self.decoder.forward_t( - &decoder_input_ids, - &encoder_hidden_states, - attention_mask, - decoder_padding_mask.as_ref(), - causal_mask.as_ref(), - &self.embeddings, - layer_states, - train, - ); - ( - decoder_outputs, - encoder_hidden_states, - decoder_cache.1, - all_decoder_hidden_states, - all_decoder_attentions, - all_encoder_hidden_states, - all_encoder_attentions, - ) + let decoder_output = self.decoder.forward_t( + &decoder_input_ids, + &encoder_output.hidden_state, + attention_mask, + decoder_padding_mask.as_ref(), + causal_mask.as_ref(), + &self.embeddings, + layer_states, + train, + ); + BartModelOutput { + decoder_hidden_state: decoder_output.hidden_state, + encoder_hidden_state: encoder_output.hidden_state, + cache: decoder_output.next_decoder_cache, + all_decoder_hidden_states: decoder_output.all_hidden_states, + all_decoder_attentions: decoder_output.all_attentions, + all_encoder_hidden_states: encoder_output.all_hidden_states, + all_encoder_attentions: encoder_output.all_attentions, + } } } @@ -525,9 +503,7 @@ impl BartForConditionalGeneration { /// let encoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device)); /// let decoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device)); /// - /// let (decoder_output, encoder_hidden_states, cache, - /// all_encoder_hidden_states, all_encoder_attentions, - /// all_decoder_hidden_states, all_decoder_attentions) = no_grad(|| { + /// let model_output = no_grad(|| { /// bart_model /// .forward_t(Some(&input_tensor), /// Some(&encoder_attention_mask), @@ -542,58 +518,41 @@ impl BartForConditionalGeneration { &self, input_ids: Option<&Tensor>, attention_mask: Option<&Tensor>, - encoder_outputs: Option<(Tensor, Option>, Option>)>, + encoder_output: Option, decoder_input_ids: Option<&Tensor>, decoder_attention_mask: Option<&Tensor>, old_layer_states: Option, Option)>>, train: bool, - ) -> ( - Tensor, - Tensor, - Option, Option)>>, - Option>, - Option>, - Option>, - Option>, 
- ) { - let ( - decoder_outputs, - encoder_hidden_states, - decoder_cache, - all_decoder_hidden_states, - all_decoder_attentions, - all_encoder_hidden_states, - all_encoder_attentions, - ) = self.base_model.forward_t( + ) -> BartModelOutput { + let base_model_output = self.base_model.forward_t( input_ids, attention_mask, decoder_input_ids, - encoder_outputs, + encoder_output, decoder_attention_mask, old_layer_states, train, ); - let lm_logits = decoder_outputs.linear::(&self.base_model.embeddings.ws, None); - ( - lm_logits, - encoder_hidden_states, - decoder_cache, - all_decoder_hidden_states, - all_decoder_attentions, - all_encoder_hidden_states, - all_encoder_attentions, - ) + let lm_logits = base_model_output + .decoder_hidden_state + .linear::(&self.base_model.embeddings.ws, None); + BartModelOutput { + decoder_hidden_state: lm_logits, + ..base_model_output + } } pub fn encode(&self, input_ids: &Tensor, attention_mask: Option<&Tensor>) -> Tensor { - let (encoder_hidden_states, _, _) = self.base_model.encoder.forward_t( - input_ids, - attention_mask, - &self.base_model.embeddings, - false, - ); - encoder_hidden_states + self.base_model + .encoder + .forward_t( + input_ids, + attention_mask, + &self.base_model.embeddings, + false, + ) + .hidden_state } } @@ -740,9 +699,7 @@ impl BartForSequenceClassification { /// let encoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device)); /// let decoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device)); /// - /// let (decoder_output, encoder_hidden_states, cache, - /// all_encoder_hidden_states, all_encoder_attentions, - /// all_decoder_hidden_states, all_decoder_attentions) = no_grad(|| { + /// let model_output = no_grad(|| { /// bart_model /// .forward_t(Some(&input_tensor), /// Some(&encoder_attention_mask), @@ -757,60 +714,51 @@ impl BartForSequenceClassification { &self, input_ids: &Tensor, attention_mask: Option<&Tensor>, - encoder_outputs: Option<(Tensor, Option>, Option>)>, + encoder_output: Option, decoder_input_ids: Option<&Tensor>, decoder_attention_mask: Option<&Tensor>, train: bool, - ) -> ( - Tensor, - Tensor, - Option>, - Option>, - Option>, - Option>, - ) { - let ( - decoder_outputs, - encoder_hidden_states, - _, - all_decoder_hidden_states, - all_decoder_attentions, - all_encoder_hidden_states, - all_encoder_attentions, - ) = self.base_model.forward_t( + ) -> BartModelOutput { + let base_model_output = self.base_model.forward_t( Some(input_ids), attention_mask, decoder_input_ids, - encoder_outputs, + encoder_output, decoder_attention_mask, None, train, ); let eos_mask = input_ids.eq(self.eos_token_id); let reshape = eos_mask.sum1(&[1], true, Int64); - let sentence_representation = decoder_outputs + let sentence_representation = base_model_output + .decoder_hidden_state .permute(&[2, 0, 1]) .masked_select(&eos_mask) .view((-1, reshape.size()[0] * reshape.int64_value(&[0, 0]))) .transpose(0, 1) .view(( - decoder_outputs.size()[0], + base_model_output.decoder_hidden_state.size()[0], -1, - *decoder_outputs.size().last().unwrap(), + *base_model_output + .decoder_hidden_state + .size() + .last() + .unwrap(), )) .select(1, -1); let logits = self .classification_head .forward_t(&sentence_representation, train); - ( - logits, - encoder_hidden_states, - all_decoder_hidden_states, - all_decoder_attentions, - all_encoder_hidden_states, - all_encoder_attentions, - ) + BartModelOutput { + decoder_hidden_state: logits, + encoder_hidden_state: 
base_model_output.encoder_hidden_state, + cache: None, + all_decoder_hidden_states: base_model_output.all_decoder_hidden_states, + all_decoder_attentions: base_model_output.all_decoder_attentions, + all_encoder_hidden_states: base_model_output.all_encoder_hidden_states, + all_encoder_attentions: base_model_output.all_encoder_attentions, + } } } @@ -860,9 +808,7 @@ impl LMHeadModel for BartForConditionalGeneration { /// let encoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device)); /// let decoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device)); /// - /// let (decoder_output, encoder_hidden_states, cache, - /// all_encoder_hidden_states, all_encoder_attentions, - /// all_decoder_hidden_states, all_decoder_attentions) = no_grad(|| { + /// let model_output = no_grad(|| { /// bart_model /// .forward_t(Some(&input_tensor), /// Some(&encoder_attention_mask), @@ -884,22 +830,17 @@ impl LMHeadModel for BartForConditionalGeneration { encoder_outputs: Option<&Tensor>, decoder_input_ids: &Option, train: bool, - ) -> Result< - ( - Tensor, - Option, - Cache, - Option>, - Option>, - ), - &'static str, - > { - let (decoder_output, encoder_hidden_states, new_cache, _, _, _, _) = match cache { + ) -> Result { + let base_model_output = match cache { Cache::BARTCache(cached_layer_states) => self.base_model.forward_t( input_ids.as_ref(), attention_mask.as_ref(), decoder_input_ids.as_ref(), - Some((encoder_outputs.as_ref().unwrap().copy(), None, None)), + Some(BartEncoderOutput { + hidden_state: encoder_outputs.as_ref().unwrap().copy(), + all_hidden_states: None, + all_attentions: None, + }), None, cached_layer_states, train, @@ -909,21 +850,37 @@ impl LMHeadModel for BartForConditionalGeneration { input_ids.as_ref(), attention_mask.as_ref(), decoder_input_ids.as_ref(), - Some((encoder_outputs.as_ref().unwrap().copy(), None, None)), + Some(BartEncoderOutput { + hidden_state: encoder_outputs.as_ref().unwrap().copy(), + all_hidden_states: None, + all_attentions: None, + }), None, None, train, ), - _ => Err("Cache not compatible with BART Model")?, + _ => return Err("Cache not compatible with BART Model"), }; - let lm_logits = decoder_output.linear::(&self.base_model.embeddings.ws, None); - Ok(( + let lm_logits = base_model_output + .decoder_hidden_state + .linear::(&self.base_model.embeddings.ws, None); + Ok(LMModelOutput { lm_logits, - Some(encoder_hidden_states), - Cache::BARTCache(new_cache), - None, - None, - )) + encoder_hidden_state: Some(base_model_output.encoder_hidden_state), + cache: Cache::BARTCache(base_model_output.cache), + all_hidden_states: None, + all_attentions: None, + }) } } + +pub struct BartModelOutput { + pub decoder_hidden_state: Tensor, + pub encoder_hidden_state: Tensor, + pub cache: Option, Option)>>, + pub all_decoder_hidden_states: Option>, + pub all_decoder_attentions: Option>, + pub all_encoder_hidden_states: Option>, + pub all_encoder_attentions: Option>, +} diff --git a/src/bart/decoder.rs b/src/bart/decoder.rs index fe99aa9..aa8e4e7 100644 --- a/src/bart/decoder.rs +++ b/src/bart/decoder.rs @@ -288,15 +288,7 @@ impl BartDecoder { embeddings: &nn::Embedding, old_layer_states: Option, Option)>>, train: bool, - ) -> ( - Tensor, - ( - Option, - Option, Option)>>, - ), - Option>, - Option>, - ) { + ) -> BartDecoderOutput { let encoder_padding_mask = match encoder_padding_mask { Some(mask) => Some(mask.eq(0).to_kind(Bool)), None => None, @@ -342,45 +334,48 @@ impl BartDecoder { }; let encoder_hidden_states = 
encoder_hidden_states.transpose(0, 1); let mut attention_weights: Option; - let mut layers = self.layers.iter().enumerate(); - loop { - match layers.next() { - Some((layer_idx, layer)) => { - let layer_state = match &next_decoder_cache { - Some(values) => values[layer_idx].to_owned(), - None => (None, None), - }; - let temp = layer.forward_t( - &hidden_state, - &encoder_hidden_states, - encoder_padding_mask.as_ref(), - decoder_causal_mask, - decoder_padding_mask, - layer_state, - train, - ); - hidden_state = temp.0; - attention_weights = temp.1; - if let Some(hidden_states) = all_hidden_states.borrow_mut() { - hidden_states.push(hidden_state.as_ref().copy().transpose(0, 1)); - }; - if let Some(attentions) = all_attentions.borrow_mut() { - attentions.push(attention_weights.as_ref().unwrap().copy()); - }; - if let Some(value) = &mut next_decoder_cache { - value[layer_idx] = temp.2 - }; - } - None => break, + for (layer_idx, layer) in self.layers.iter().enumerate() { + let layer_state = match &next_decoder_cache { + Some(values) => values[layer_idx].to_owned(), + None => (None, None), + }; + let temp = layer.forward_t( + &hidden_state, + &encoder_hidden_states, + encoder_padding_mask.as_ref(), + decoder_causal_mask, + decoder_padding_mask, + layer_state, + train, + ); + hidden_state = temp.0; + attention_weights = temp.1; + if let Some(hidden_states) = all_hidden_states.borrow_mut() { + hidden_states.push(hidden_state.as_ref().copy().transpose(0, 1)); + }; + if let Some(attentions) = all_attentions.borrow_mut() { + attentions.push(attention_weights.as_ref().unwrap().copy()); + }; + if let Some(value) = &mut next_decoder_cache { + value[layer_idx] = temp.2 }; } - ( - hidden_state.transpose(0, 1), - (encoder_padding_mask, next_decoder_cache), + BartDecoderOutput { + hidden_state: hidden_state.transpose(0, 1), + encoder_padding_mask, + next_decoder_cache, all_hidden_states, all_attentions, - ) + } } } + +pub struct BartDecoderOutput { + pub hidden_state: Tensor, + pub encoder_padding_mask: Option, + pub next_decoder_cache: Option, Option)>>, + pub all_hidden_states: Option>, + pub all_attentions: Option>, +} diff --git a/src/bart/encoder.rs b/src/bart/encoder.rs index 4eab4fe..ad615e7 100644 --- a/src/bart/encoder.rs +++ b/src/bart/encoder.rs @@ -232,7 +232,7 @@ impl BartEncoder { attention_mask: Option<&Tensor>, embeddings: &nn::Embedding, train: bool, - ) -> (Tensor, Option>, Option>) { + ) -> BartEncoderOutput { let attention_mask = match attention_mask { Some(mask) => Some(mask.eq(0).to_kind(Bool)), None => None, @@ -260,33 +260,34 @@ impl BartEncoder { let mut hidden_state = x.copy(); let mut attention_weights: Option; - let mut layers = self.layers.iter(); - loop { - match layers.next() { - Some(layer) => { - if let Some(hidden_states) = all_hidden_states.borrow_mut() { - hidden_states.push(hidden_state.as_ref().copy().transpose(0, 1)); - }; + for layer in &self.layers { + if let Some(hidden_states) = all_hidden_states.borrow_mut() { + hidden_states.push(hidden_state.as_ref().copy().transpose(0, 1)); + }; - let temp = layer.forward_t(&hidden_state, attention_mask.as_ref(), train); - hidden_state = temp.0; - attention_weights = temp.1; - if let Some(attentions) = all_attentions.borrow_mut() { - attentions.push(attention_weights.as_ref().unwrap().copy()); - }; - } - None => break, + let temp = layer.forward_t(&hidden_state, attention_mask.as_ref(), train); + hidden_state = temp.0; + attention_weights = temp.1; + if let Some(attentions) = all_attentions.borrow_mut() { + 
attentions.push(attention_weights.as_ref().unwrap().copy()); }; } + if let Some(hidden_states) = all_hidden_states.borrow_mut() { hidden_states.push(hidden_state.as_ref().copy().transpose(0, 1)); }; - ( - hidden_state.transpose(0, 1), + BartEncoderOutput { + hidden_state: hidden_state.transpose(0, 1), all_hidden_states, all_attentions, - ) + } } } + +pub struct BartEncoderOutput { + pub hidden_state: Tensor, + pub all_hidden_states: Option>, + pub all_attentions: Option>, +} diff --git a/src/bart/mod.rs b/src/bart/mod.rs index 1f82c68..782ff2c 100644 --- a/src/bart/mod.rs +++ b/src/bart/mod.rs @@ -69,3 +69,5 @@ pub use bart_model::{ BartForSequenceClassification, BartMergesResources, BartModel, BartModelResources, BartVocabResources, }; + +pub(crate) use encoder::BartEncoderOutput; diff --git a/src/bert/attention.rs b/src/bert/attention.rs index 8be886e..6010e89 100644 --- a/src/bert/attention.rs +++ b/src/bert/attention.rs @@ -86,7 +86,7 @@ impl BertSelfAttention { fn flatten(&self, x: Tensor, bs: i64, dim_per_head: i64) -> Tensor { x.transpose(1, 2) .contiguous() - .view((bs, -1, &self.num_attention_heads * dim_per_head)) + .view((bs, -1, self.num_attention_heads * dim_per_head)) } pub fn forward_t( diff --git a/src/gpt2/gpt2_model.rs b/src/gpt2/gpt2_model.rs index c008800..dceb7f9 100644 --- a/src/gpt2/gpt2_model.rs +++ b/src/gpt2/gpt2_model.rs @@ -15,7 +15,7 @@ use crate::common::dropout::Dropout; use crate::common::linear::{linear_no_bias, LinearNoBias}; use crate::gpt2::transformer::Block; -use crate::pipelines::generation::{Cache, LMHeadModel}; +use crate::pipelines::generation::{Cache, LMHeadModel, LMModelOutput}; use crate::Config; use serde::{Deserialize, Serialize}; use std::borrow::{Borrow, BorrowMut}; @@ -608,7 +608,7 @@ impl LMHeadModel for GPT2LMHeadModel { /// let position_ids = Tensor::arange(sequence_length, (Int64, device)) /// .expand(&[batch_size, sequence_length], true); /// - /// let (output, _, past, hidden_states, attentions) = no_grad(|| { + /// let model_output = no_grad(|| { /// gpt2_model /// .forward_t( /// &Some(input_tensor), @@ -635,16 +635,7 @@ impl LMHeadModel for GPT2LMHeadModel { _encoder_outputs: Option<&Tensor>, _decoder_input_ids: &Option, train: bool, - ) -> Result< - ( - Tensor, - Option, - Cache, - Option>, - Option>, - ), - &'static str, - > { + ) -> Result { let (output, past, all_hidden_states, all_attentions) = match layer_past { Cache::GPT2Cache(layer_past) => Ok(self.transformer.forward_t( input_ids, @@ -668,12 +659,12 @@ impl LMHeadModel for GPT2LMHeadModel { }?; let lm_logits = output.apply(&self.lm_head); - Ok(( + Ok(LMModelOutput { lm_logits, - None, - Cache::GPT2Cache(past), + encoder_hidden_state: None, + cache: Cache::GPT2Cache(past), all_hidden_states, all_attentions, - )) + }) } } diff --git a/src/marian/marian_model.rs b/src/marian/marian_model.rs index b421df9..c7b679a 100644 --- a/src/marian/marian_model.rs +++ b/src/marian/marian_model.rs @@ -11,8 +11,8 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-use crate::bart::{BartConfig, BartModel, LayerState}; -use crate::pipelines::generation::{Cache, LMHeadModel}; +use crate::bart::{BartConfig, BartEncoderOutput, BartModel, LayerState}; +use crate::pipelines::generation::{Cache, LMHeadModel, LMModelOutput}; use std::borrow::Borrow; use tch::nn::Init; use tch::{nn, Tensor}; @@ -349,7 +349,7 @@ impl MarianForConditionalGeneration { &self, input_ids: Option<&Tensor>, attention_mask: Option<&Tensor>, - encoder_outputs: Option<(Tensor, Option>, Option>)>, + encoder_outputs: Option, decoder_input_ids: Option<&Tensor>, decoder_attention_mask: Option<&Tensor>, old_layer_states: Option, Option)>>, @@ -363,15 +363,7 @@ impl MarianForConditionalGeneration { Option>, Option>, ) { - let ( - decoder_outputs, - encoder_hidden_states, - decoder_cache, - all_decoder_hidden_states, - all_decoder_attentions, - all_encoder_hidden_states, - all_encoder_attentions, - ) = self.base_model.forward_t( + let base_model_output = self.base_model.forward_t( input_ids, attention_mask, decoder_input_ids, @@ -381,25 +373,31 @@ impl MarianForConditionalGeneration { train, ); - let lm_logits = decoder_outputs.linear::(&self.base_model.embeddings.ws, None); + let lm_logits = base_model_output + .decoder_hidden_state + .linear::(&self.base_model.embeddings.ws, None); ( lm_logits, - encoder_hidden_states, - decoder_cache, - all_decoder_hidden_states, - all_decoder_attentions, - all_encoder_hidden_states, - all_encoder_attentions, + base_model_output.encoder_hidden_state, + base_model_output.cache, + base_model_output.all_decoder_hidden_states, + base_model_output.all_decoder_attentions, + base_model_output.all_encoder_hidden_states, + base_model_output.all_encoder_attentions, ) } pub fn encode(&self, input_ids: &Tensor, attention_mask: Option<&Tensor>) -> Tensor { - let (encoder_hidden_states, _, _) = self.base_model.encoder.forward_t( - input_ids, - attention_mask, - &self.base_model.embeddings, - false, - ); + let encoder_hidden_states = self + .base_model + .encoder + .forward_t( + input_ids, + attention_mask, + &self.base_model.embeddings, + false, + ) + .hidden_state; encoder_hidden_states } } @@ -483,22 +481,17 @@ impl LMHeadModel for MarianForConditionalGeneration { encoder_outputs: Option<&Tensor>, decoder_input_ids: &Option, train: bool, - ) -> Result< - ( - Tensor, - Option, - Cache, - Option>, - Option>, - ), - &'static str, - > { - let (decoder_output, encoder_hidden_states, new_cache, _, _, _, _) = match cache { + ) -> Result { + let base_model_output = match cache { Cache::BARTCache(cached_layer_states) => self.base_model.forward_t( input_ids.as_ref(), attention_mask.as_ref(), decoder_input_ids.as_ref(), - Some((encoder_outputs.as_ref().unwrap().copy(), None, None)), + Some(BartEncoderOutput { + hidden_state: encoder_outputs.as_ref().unwrap().copy(), + all_hidden_states: None, + all_attentions: None, + }), None, cached_layer_states, train, @@ -507,7 +500,11 @@ impl LMHeadModel for MarianForConditionalGeneration { input_ids.as_ref(), attention_mask.as_ref(), decoder_input_ids.as_ref(), - Some((encoder_outputs.as_ref().unwrap().copy(), None, None)), + Some(BartEncoderOutput { + hidden_state: encoder_outputs.as_ref().unwrap().copy(), + all_hidden_states: None, + all_attentions: None, + }), None, None, train, @@ -515,14 +512,16 @@ impl LMHeadModel for MarianForConditionalGeneration { _ => Err("Cache not compatible with Marian Model")?, }; - let lm_logits = decoder_output.linear::(&self.base_model.embeddings.ws, None) + let lm_logits = base_model_output + 
.decoder_hidden_state + .linear::(&self.base_model.embeddings.ws, None) + &self.final_logits_bias; - Ok(( + Ok(LMModelOutput { lm_logits, - Some(encoder_hidden_states), - Cache::BARTCache(new_cache), - None, - None, - )) + encoder_hidden_state: Some(base_model_output.encoder_hidden_state), + cache: Cache::BARTCache(base_model_output.cache), + all_hidden_states: None, + all_attentions: None, + }) } } diff --git a/src/openai_gpt/openai_gpt_model.rs b/src/openai_gpt/openai_gpt_model.rs index ffdb2be..b4ea7db 100644 --- a/src/openai_gpt/openai_gpt_model.rs +++ b/src/openai_gpt/openai_gpt_model.rs @@ -16,7 +16,7 @@ use crate::common::dropout::Dropout; use crate::common::linear::{linear_no_bias, LinearNoBias}; use crate::gpt2::Gpt2Config; use crate::openai_gpt::transformer::Block; -use crate::pipelines::generation::{Cache, LMHeadModel}; +use crate::pipelines::generation::{Cache, LMHeadModel, LMModelOutput}; use std::borrow::{Borrow, BorrowMut}; use tch::kind::Kind::Int64; use tch::nn::embedding; @@ -388,7 +388,7 @@ impl LMHeadModel for OpenAIGPTLMHeadModel { /// let token_type_ids = Tensor::ones(&[batch_size, sequence_length], (Int64, device)); /// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true); /// - /// let (output, _, _, hidden_states, attentions) = no_grad(|| { + /// let model_output = no_grad(|| { /// gpt_model /// .forward_t(&Some(input_tensor), /// Cache::None, @@ -412,16 +412,7 @@ impl LMHeadModel for OpenAIGPTLMHeadModel { _encoder_outputs: Option<&Tensor>, _decoder_input_ids: &Option, train: bool, - ) -> Result< - ( - Tensor, - Option, - Cache, - Option>, - Option>, - ), - &'static str, - > { + ) -> Result { let (output, all_hidden_states, all_attentions) = self.transformer.forward_t( input_ids, attention_mask, @@ -432,12 +423,12 @@ impl LMHeadModel for OpenAIGPTLMHeadModel { )?; let lm_logits = output.apply(&self.lm_head); - Ok(( + Ok(LMModelOutput { lm_logits, - None, - Cache::None, + encoder_hidden_state: None, + cache: Cache::None, all_hidden_states, all_attentions, - )) + }) } } diff --git a/src/pipelines/generation.rs b/src/pipelines/generation.rs index ce89fb0..287275c 100644 --- a/src/pipelines/generation.rs +++ b/src/pipelines/generation.rs @@ -1648,8 +1648,8 @@ pub(crate) mod private_generation_utils { false, ) .unwrap(); - outputs = temp.0; - past = temp.2; + outputs = temp.lm_logits; + past = temp.cache; let mut next_token_logits = outputs.select(1, -1); // Reduce probability for repeated inputs @@ -1854,8 +1854,8 @@ pub(crate) mod private_generation_utils { false, ) .unwrap(); - outputs = temp.0; - past = temp.2; + outputs = temp.lm_logits; + past = temp.cache; let mut next_token_logits = outputs.select(1, -1); @@ -2575,7 +2575,7 @@ pub trait LMHeadModel { /// let position_ids = Tensor::arange(sequence_length, (Int64, device)) /// .expand(&[batch_size, sequence_length], true); /// - /// let (output, encoder_output, past, hidden_states, attentions) = no_grad(|| { + /// let model_output = no_grad(|| { /// gpt2_model /// .forward_t( /// &Some(input_tensor), @@ -2602,14 +2602,13 @@ pub trait LMHeadModel { encoder_outputs: Option<&Tensor>, decoder_input_ids: &Option, train: bool, - ) -> Result< - ( - Tensor, - Option, - Cache, - Option>, - Option>, - ), - &'static str, - >; + ) -> Result; +} + +pub struct LMModelOutput { + pub lm_logits: Tensor, + pub encoder_hidden_state: Option, + pub cache: Cache, + pub all_hidden_states: Option>, + pub all_attentions: Option>, } diff --git 
a/src/pipelines/sequence_classification.rs b/src/pipelines/sequence_classification.rs index 5aa6e61..02370e2 100644 --- a/src/pipelines/sequence_classification.rs +++ b/src/pipelines/sequence_classification.rs @@ -301,7 +301,7 @@ impl SequenceClassificationOption { None, train, ) - .0 + .decoder_hidden_state } Self::Bert(ref model) => { model diff --git a/src/pipelines/zero_shot_classification.rs b/src/pipelines/zero_shot_classification.rs index ae1c660..c5a2592 100644 --- a/src/pipelines/zero_shot_classification.rs +++ b/src/pipelines/zero_shot_classification.rs @@ -332,7 +332,7 @@ impl ZeroShotClassificationOption { None, train, ) - .0 + .decoder_hidden_state } Self::Bert(ref model) => { model diff --git a/src/t5/t5_model.rs b/src/t5/t5_model.rs index 025ed6c..8445731 100644 --- a/src/t5/t5_model.rs +++ b/src/t5/t5_model.rs @@ -9,7 +9,7 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. -use crate::pipelines::generation::{Cache, LMHeadModel}; +use crate::pipelines::generation::{Cache, LMHeadModel, LMModelOutput}; use crate::t5::attention::LayerState; use crate::t5::encoder::T5Stack; use crate::Config; @@ -684,16 +684,7 @@ impl LMHeadModel for T5ForConditionalGeneration { encoder_outputs: Option<&Tensor>, decoder_input_ids: &Option, train: bool, - ) -> Result< - ( - Tensor, - Option, - Cache, - Option>, - Option>, - ), - &'static str, - > { + ) -> Result { let (decoder_output, encoder_hidden_states, new_cache, _, _, _, _) = match cache { Cache::T5Cache(cached_layer_states) => self.base_model.forward_t( input_ids.as_ref(), @@ -723,12 +714,12 @@ impl LMHeadModel for T5ForConditionalGeneration { let lm_logits = decoder_output.linear::(&self.base_model.embeddings.ws, None) * (self.model_dim.powf(-0.5)); - Ok(( + Ok(LMModelOutput { lm_logits, - Some(encoder_hidden_states), - Cache::T5Cache(new_cache), - None, - None, - )) + encoder_hidden_state: Some(encoder_hidden_states), + cache: Cache::T5Cache(new_cache), + all_hidden_states: None, + all_attentions: None, + }) } } diff --git a/tests/bart.rs b/tests/bart.rs index b4531a0..7999d6b 100644 --- a/tests/bart.rs +++ b/tests/bart.rs @@ -62,12 +62,12 @@ fn bart_lm_model() -> anyhow::Result<()> { let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device); // Forward pass - let (output, encoder_outputs, _, _, _, _, _) = + let model_output = bart_model.forward_t(Some(&input_tensor), None, None, None, None, None, false); - assert_eq!(output.size(), vec!(1, 6, 1024)); - assert_eq!(encoder_outputs.size(), vec!(1, 6, 1024)); - assert!((output.double_value(&[0, 0, 0]) - 0.7877).abs() < 1e-4); + assert_eq!(model_output.decoder_hidden_state.size(), vec!(1, 6, 1024)); + assert_eq!(model_output.encoder_hidden_state.size(), vec!(1, 6, 1024)); + assert!((model_output.decoder_hidden_state.double_value(&[0, 0, 0]) - 0.7877).abs() < 1e-4); Ok(()) } diff --git a/tests/distilgpt2.rs b/tests/distilgpt2.rs index c2d9f75..a47939c 100644 --- a/tests/distilgpt2.rs +++ b/tests/distilgpt2.rs @@ -61,7 +61,7 @@ fn distilgpt2_lm_model() -> anyhow::Result<()> { let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device); // Forward pass - let (output, _, past, _, _) = gpt2_model + let model_output = gpt2_model .forward_t( &Some(input_tensor), Cache::None, @@ -75,11 +75,16 @@ fn distilgpt2_lm_model() -> anyhow::Result<()> { ) .unwrap(); - let next_word_id = output.get(0).get(-1).argmax(-1, true).int64_value(&[0]); + 
let next_word_id = model_output + .lm_logits + .get(0) + .get(-1) + .argmax(-1, true) + .int64_value(&[0]); let next_word = tokenizer.decode(vec![next_word_id], true, true); - assert_eq!(output.size(), vec!(1, 11, 50257)); - match past { + assert_eq!(model_output.lm_logits.size(), vec!(1, 11, 50257)); + match model_output.cache { Cache::GPT2Cache(past) => { assert!(past.is_some()); assert_eq!(past.as_ref().unwrap().len(), config.n_layer as usize); @@ -91,7 +96,13 @@ fn distilgpt2_lm_model() -> anyhow::Result<()> { _ => panic!("Wrong cache returned for GPT2"), } assert!( - (output.double_value(&[0, output.size()[1] - 1, next_word_id]) - (-48.7065)).abs() < 1e-4 + (model_output.lm_logits.double_value(&[ + 0, + model_output.lm_logits.size()[1] - 1, + next_word_id + ]) - (-48.7065)) + .abs() + < 1e-4 ); assert_eq!(next_word_id, 14104i64); assert_eq!(next_word, String::from(" twelve")); diff --git a/tests/gpt2.rs b/tests/gpt2.rs index f5d7063..e2c8360 100644 --- a/tests/gpt2.rs +++ b/tests/gpt2.rs @@ -62,7 +62,7 @@ fn gpt2_lm_model() -> anyhow::Result<()> { let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device); // Forward pass - let (output, _, past, _, _) = gpt2_model + let model_output = gpt2_model .forward_t( &Some(input_tensor), Cache::None, @@ -76,11 +76,16 @@ fn gpt2_lm_model() -> anyhow::Result<()> { ) .unwrap(); - let next_word_id = output.get(0).get(-1).argmax(-1, true).int64_value(&[0]); + let next_word_id = model_output + .lm_logits + .get(0) + .get(-1) + .argmax(-1, true) + .int64_value(&[0]); let next_word = tokenizer.decode(vec![next_word_id], true, true); - assert_eq!(output.size(), vec!(1, 4, 50257)); - match past { + assert_eq!(model_output.lm_logits.size(), vec!(1, 4, 50257)); + match model_output.cache { Cache::GPT2Cache(past) => { assert!(past.is_some()); assert_eq!(past.as_ref().unwrap().len(), config.n_layer as usize); @@ -92,7 +97,13 @@ fn gpt2_lm_model() -> anyhow::Result<()> { _ => panic!("Wrong cache returned for GPT2"), } assert!( - (output.double_value(&[0, output.size()[1] - 1, next_word_id]) - (-69.4948)).abs() < 1e-4 + (model_output.lm_logits.double_value(&[ + 0, + model_output.lm_logits.size()[1] - 1, + next_word_id + ]) - (-69.4948)) + .abs() + < 1e-4 ); assert_eq!(next_word_id, 1936i64); assert_eq!(next_word, String::from(" five")); diff --git a/tests/openai_gpt.rs b/tests/openai_gpt.rs index 5db6cd2..02179e1 100644 --- a/tests/openai_gpt.rs +++ b/tests/openai_gpt.rs @@ -64,7 +64,7 @@ fn openai_gpt_lm_model() -> anyhow::Result<()> { let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device); // Forward pass - let (output, _, _, _, _) = openai_gpt + let model_output = openai_gpt .forward_t( &Some(input_tensor), Cache::None, @@ -78,12 +78,23 @@ fn openai_gpt_lm_model() -> anyhow::Result<()> { ) .unwrap(); - let next_word_id = output.get(0).get(-1).argmax(-1, true).int64_value(&[0]); + let next_word_id = model_output + .lm_logits + .get(0) + .get(-1) + .argmax(-1, true) + .int64_value(&[0]); let next_word = tokenizer.decode(vec![next_word_id], true, true); - assert_eq!(output.size(), vec!(1, 6, 40478)); + assert_eq!(model_output.lm_logits.size(), vec!(1, 6, 40478)); assert!( - (output.double_value(&[0, output.size()[1] - 1, next_word_id]) - (9.1056)).abs() < 1e-4 + (model_output.lm_logits.double_value(&[ + 0, + model_output.lm_logits.size()[1] - 1, + next_word_id + ]) - (9.1056)) + .abs() + < 1e-4 ); assert_eq!(next_word_id, 580i64); assert_eq!(next_word, String::from("be"));
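Note on call-site migration (illustrative, not part of the patch): for downstream code the change is mechanical, replacing positional tuple destructuring with named field access on the new output structs (BartModelOutput, BartDecoderOutput, BartEncoderOutput, LMModelOutput). The sketch below shows the before/after shape of the BART forward pass touched in examples/bart.rs; it assumes the same `bart_model` and `input_tensor` setup as that example (resource download, tokenizer, and model construction omitted), and the helper name `inspect_forward_pass` is invented for illustration.

use rust_bert::bart::BartModel;
use tch::{no_grad, Tensor};

// Hypothetical helper: `bart_model` and `input_tensor` are assumed to be built
// exactly as in examples/bart.rs; only the handling of the result changes.
fn inspect_forward_pass(bart_model: &BartModel, input_tensor: &Tensor) {
    // Before this patch: a seven-element tuple with `_` placeholders for unused slots.
    // let (decoder_output, encoder_output, _, _, _, _, _) =
    //     no_grad(|| bart_model.forward_t(Some(input_tensor), None, None, None, None, None, false));

    // After this patch: a BartModelOutput with named fields.
    let model_output =
        no_grad(|| bart_model.forward_t(Some(input_tensor), None, None, None, None, None, false));

    println!("{:?}", model_output.decoder_hidden_state.size());
    println!("{:?}", model_output.encoder_hidden_state.size());

    // Optional diagnostics are reached by name rather than by tuple position.
    if let Some(decoder_attentions) = &model_output.all_decoder_attentions {
        println!("collected {} decoder attention maps", decoder_attentions.len());
    }
}

The clippy.toml added at the top of the patch (too-many-arguments-threshold = 10) presumably accompanies this refactor because the forward_t argument lists remain long; only the return side of these signatures moves into structs.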