Updated T5 (clippy)

This commit is contained in:
Guillaume B 2020-09-13 12:25:22 +02:00
parent b42cc60409
commit 47d9a1017d
9 changed files with 154 additions and 175 deletions

View File

@ -33,7 +33,7 @@ fn main() -> anyhow::Result<()> {
}; };
// Get answer // Get answer
let answers = qa_model.predict(&vec![qa_input_1, qa_input_2], 1, 32); let answers = qa_model.predict(&[qa_input_1, qa_input_2], 1, 32);
println!("{:?}", answers); println!("{:?}", answers);
Ok(()) Ok(())
} }

View File

@ -51,7 +51,7 @@ fn main() -> anyhow::Result<()> {
}; };
// Get answer // Get answer
let answers = qa_model.predict(&vec![qa_input_1, qa_input_2], 1, 32); let answers = qa_model.predict(&[qa_input_1, qa_input_2], 1, 32);
println!("{:?}", answers); println!("{:?}", answers);
Ok(()) Ok(())
} }

View File

@ -27,8 +27,10 @@ fn main() -> anyhow::Result<()> {
// Run model // Run model
let output = sequence_classification_model.predict_multilabel(&input, 0.05); let output = sequence_classification_model.predict_multilabel(&input, 0.05);
for label in output { if let Ok(labels) = output {
println!("{:?}", label); for label in labels {
println!("{:?}", label);
}
} }
Ok(()) Ok(())

View File

@ -199,9 +199,10 @@ impl T5Attention {
temp_value = temp_value.slice(2, length - 1, length, 1); temp_value = temp_value.slice(2, length - 1, length, 1);
}; };
if let Some(attention_mask) = attention_mask { if let Some(attention_mask) = attention_mask {
temp_value += attention_mask; Some(temp_value + attention_mask)
}; } else {
Some(temp_value) Some(temp_value)
}
} else { } else {
None None
}; };

View File

@ -148,12 +148,7 @@ impl T5Block {
encoder_decoder_position_bias: Option<&Tensor>, encoder_decoder_position_bias: Option<&Tensor>,
mut layer_states: (Option<LayerState>, Option<LayerState>), mut layer_states: (Option<LayerState>, Option<LayerState>),
train: bool, train: bool,
) -> ( ) -> T5BlockOutput {
Tensor,
(Option<Tensor>, Option<Tensor>),
(Option<Tensor>, Option<Tensor>),
(Option<LayerState>, Option<LayerState>),
) {
let ( let (
hidden_states, hidden_states,
self_attention_weights, self_attention_weights,
@ -190,17 +185,17 @@ impl T5Block {
(hidden_states, None, None, None) (hidden_states, None, None, None)
}; };
let attention_weights = (self_attention_weights, cross_attention_weights);
let position_bias = (self_attention_position_bias, cross_attention_position_bias);
layer_states = (self_attention_layer_past, cross_attention_layer_past); layer_states = (self_attention_layer_past, cross_attention_layer_past);
let hidden_states = self.ff_layer.forward_t(&hidden_states, train); let hidden_states = self.ff_layer.forward_t(&hidden_states, train);
( T5BlockOutput {
hidden_states, hidden_states,
attention_weights, self_attention_weights,
position_bias, cross_attention_weights,
layer_states, self_attention_position_bias,
) cross_attention_position_bias,
cache: layer_states,
}
} }
} }
@ -269,15 +264,7 @@ impl T5Stack {
embeddings: &nn::Embedding, embeddings: &nn::Embedding,
old_layer_states: Option<Vec<(Option<LayerState>, Option<LayerState>)>>, old_layer_states: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
train: bool, train: bool,
) -> Result< ) -> Result<T5StackOutput, &'static str> {
(
Tensor,
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
),
&'static str,
> {
let (input_embeddings, input_shape) = match input_ids { let (input_embeddings, input_shape) = match input_ids {
Some(input_ids_value) => match input_embeds { Some(input_ids_value) => match input_embeds {
Some(_) => { Some(_) => {
@ -332,7 +319,7 @@ impl T5Stack {
if self.is_decoder { if self.is_decoder {
let seq_ids = let seq_ids =
Tensor::arange(input_shape[1], (Kind::Float, input_embeddings.device())); Tensor::arange(input_shape[1], (Kind::Float, input_embeddings.device()));
let causal_mask = seq_ids.unsqueeze(0).unsqueeze(0).repeat(&vec![ let causal_mask = seq_ids.unsqueeze(0).unsqueeze(0).repeat(&[
input_shape[0], input_shape[0],
input_shape[1], input_shape[1],
1, 1,
@ -386,7 +373,7 @@ impl T5Stack {
} else { } else {
None None
}; };
let mut next_decoder_cache: Option<Vec<(Option<LayerState>, Option<LayerState>)>> = let mut next_cache: Option<Vec<(Option<LayerState>, Option<LayerState>)>> =
if self.store_cache { if self.store_cache {
if old_layer_states.is_some() { if old_layer_states.is_some() {
old_layer_states old_layer_states
@ -400,42 +387,36 @@ impl T5Stack {
let mut encoder_decoder_position_bias = None; let mut encoder_decoder_position_bias = None;
let mut attention_weights: Option<Tensor>; let mut attention_weights: Option<Tensor>;
let mut hidden_state = input_embeddings.apply_t(&self.dropout, train); let mut hidden_state = input_embeddings.apply_t(&self.dropout, train);
let mut blocks = self.blocks.iter().enumerate();
loop { for (layer_idx, layer) in self.blocks.iter().enumerate() {
match blocks.next() { let layer_state = match &next_cache {
Some((layer_idx, layer)) => { Some(values) => values[layer_idx].to_owned(),
let layer_state = match &next_decoder_cache { None => (None, None),
Some(values) => values[layer_idx].to_owned(), };
None => (None, None), let block_output = layer.forward_t(
}; &hidden_state,
let temp = layer.forward_t( position_bias.as_ref(),
&hidden_state, extended_attention_mask.as_ref(),
position_bias.as_ref(), encoder_hidden_states,
extended_attention_mask.as_ref(), extended_encoder_attention_mask.as_ref(),
encoder_hidden_states, encoder_decoder_position_bias.as_ref(),
extended_encoder_attention_mask.as_ref(), layer_state,
encoder_decoder_position_bias.as_ref(), train,
layer_state, );
train, if layer_idx == 0 {
); position_bias = block_output.self_attention_position_bias;
if layer_idx == 0 { encoder_decoder_position_bias = block_output.cross_attention_position_bias;
position_bias = (temp.2).0; }
encoder_decoder_position_bias = (temp.2).1; hidden_state = block_output.hidden_states;
} attention_weights = block_output.cross_attention_weights;
hidden_state = temp.0; if let Some(hidden_states) = all_hidden_states.borrow_mut() {
attention_weights = (temp.1).1; hidden_states.push(hidden_state.as_ref().copy().transpose(0, 1));
if let Some(hidden_states) = all_hidden_states.borrow_mut() { };
hidden_states.push(hidden_state.as_ref().copy().transpose(0, 1)); if let Some(attentions) = all_attentions.borrow_mut() {
}; attentions.push(attention_weights.as_ref().unwrap().copy());
if let Some(attentions) = all_attentions.borrow_mut() { };
attentions.push(attention_weights.as_ref().unwrap().copy()); if let Some(value) = &mut next_cache {
}; value[layer_idx] = block_output.cache
if let Some(value) = &mut next_decoder_cache {
value[layer_idx] = temp.3
};
}
None => break,
}; };
} }
@ -443,11 +424,27 @@ impl T5Stack {
.apply(&self.final_layer_norm) .apply(&self.final_layer_norm)
.apply_t(&self.dropout, train); .apply_t(&self.dropout, train);
Ok(( Ok(T5StackOutput {
hidden_state, hidden_state,
all_hidden_states, all_hidden_states,
all_attentions, all_attentions,
next_decoder_cache, next_cache,
)) })
} }
} }
/// Named result of a single `T5Block::forward_t` call.
///
/// Replaces the nested tuple
/// `(Tensor, (Option<Tensor>, Option<Tensor>), (Option<Tensor>, Option<Tensor>), (Option<LayerState>, Option<LayerState>))`
/// that the block used to return, so callers can access each value by name
/// instead of by tuple position.
pub struct T5BlockOutput {
/// Hidden states produced by the block, i.e. the output of its feed-forward layer.
pub hidden_states: Tensor,
/// Attention weights from the self-attention sub-layer, when available.
pub self_attention_weights: Option<Tensor>,
/// Attention weights from the cross-attention sub-layer, when available.
/// `T5Stack` collects these (not the self-attention weights) into its `all_attentions` output.
pub cross_attention_weights: Option<Tensor>,
/// Position bias from the self-attention sub-layer. `T5Stack` keeps the value
/// returned by the first block and passes it to the following blocks.
pub self_attention_position_bias: Option<Tensor>,
/// Position bias from the cross-attention sub-layer. As with the self-attention
/// bias, `T5Stack` keeps the first block's value and passes it to later blocks.
pub cross_attention_position_bias: Option<Tensor>,
/// Updated layer states as a `(self-attention, cross-attention)` pair.
/// `T5Stack` writes this back into its per-layer cache when caching is active.
pub cache: (Option<LayerState>, Option<LayerState>),
}
/// Named result of `T5Stack::forward_t` (returned inside a `Result`).
///
/// Replaces the 4-tuple the stack used to return. It is also accepted by
/// `T5Model::forward_t` / `T5ForConditionalGeneration::forward_t` as
/// pre-computed `encoder_outputs`, in which case the encoder is not run again.
pub struct T5StackOutput {
/// Final hidden state of the stack.
// NOTE(review): the visible code applies `final_layer_norm` and dropout right
// before building this struct — presumably to this tensor; confirm in the full file.
pub hidden_state: Tensor,
/// Per-layer hidden states, when collected. Each entry is a copy of the
/// layer's hidden state with dimensions 0 and 1 transposed.
pub all_hidden_states: Option<Vec<Tensor>>,
/// Per-layer attention weights, when collected. The stack pushes each
/// block's cross-attention weights here.
pub all_attentions: Option<Vec<Tensor>>,
/// Per-layer `(self-attention, cross-attention)` layer states, indexed by layer.
// NOTE(review): appears to be populated only when the stack's `store_cache`
// flag is set — the `else` branch is outside the visible diff; confirm.
pub next_cache: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
}

View File

@ -11,7 +11,7 @@
// limitations under the License. // limitations under the License.
use crate::pipelines::generation::{Cache, LMHeadModel, LMModelOutput}; use crate::pipelines::generation::{Cache, LMHeadModel, LMModelOutput};
use crate::t5::attention::LayerState; use crate::t5::attention::LayerState;
use crate::t5::encoder::T5Stack; use crate::t5::encoder::{T5Stack, T5StackOutput};
use crate::Config; use crate::Config;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::borrow::Borrow; use std::borrow::Borrow;
@ -308,52 +308,30 @@ impl T5Model {
&self, &self,
input_ids: Option<&Tensor>, input_ids: Option<&Tensor>,
attention_mask: Option<&Tensor>, attention_mask: Option<&Tensor>,
encoder_outputs: Option<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>)>, encoder_outputs: Option<T5StackOutput>,
decoder_input_ids: Option<&Tensor>, decoder_input_ids: Option<&Tensor>,
decoder_attention_mask: Option<&Tensor>, decoder_attention_mask: Option<&Tensor>,
input_embeds: Option<Tensor>, input_embeds: Option<Tensor>,
decoder_input_embeds: Option<Tensor>, decoder_input_embeds: Option<Tensor>,
old_layer_states: Option<Vec<(Option<LayerState>, Option<LayerState>)>>, old_layer_states: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
train: bool, train: bool,
) -> ( ) -> T5ModelOutput {
Tensor, let encoder_output = match encoder_outputs {
Tensor, Some(value) => value,
Option<Vec<(Option<LayerState>, Option<LayerState>)>>, None => self
Option<Vec<Tensor>>, .encoder
Option<Vec<Tensor>>, .forward_t(
Option<Vec<Tensor>>, input_ids,
Option<Vec<Tensor>>, attention_mask,
) { None,
let (encoder_hidden_states, all_encoder_hidden_states, all_encoder_attentions) = None,
match encoder_outputs { input_embeds,
Some(value) => value, &self.embeddings,
None => { None,
let ( train,
encoder_hidden_states, )
all_encoder_hidden_states, .unwrap(),
all_encoder_attentions, };
_,
) = self
.encoder
.forward_t(
input_ids,
attention_mask,
None,
None,
input_embeds,
&self.embeddings,
None,
train,
)
.unwrap();
(
encoder_hidden_states,
all_encoder_hidden_states,
all_encoder_attentions,
)
}
};
let (calculated_decoder_input_ids, calculated_decoder_input_embeds) = let (calculated_decoder_input_ids, calculated_decoder_input_embeds) =
if old_layer_states.is_some() { if old_layer_states.is_some() {
let decoder_input_ids = match decoder_input_ids { let decoder_input_ids = match decoder_input_ids {
@ -377,29 +355,28 @@ impl T5Model {
(decoder_input_ids, decoder_input_embeds) (decoder_input_ids, decoder_input_embeds)
}; };
let (decoder_outputs, all_decoder_hidden_states, all_decoder_attentions, decoder_cache) = let decoder_output = self
self.decoder .decoder
.forward_t( .forward_t(
decoder_input_ids, decoder_input_ids,
decoder_attention_mask, decoder_attention_mask,
Some(&encoder_hidden_states), Some(&encoder_output.hidden_state),
attention_mask, attention_mask,
decoder_input_embeds, decoder_input_embeds,
&self.embeddings, &self.embeddings,
old_layer_states, old_layer_states,
train, train,
) )
.unwrap(); .unwrap();
T5ModelOutput {
( decoder_hidden_state: decoder_output.hidden_state,
decoder_outputs, encoder_hidden_state: encoder_output.hidden_state,
encoder_hidden_states, next_cache: decoder_output.next_cache,
decoder_cache, decoder_all_hidden_states: decoder_output.all_hidden_states,
all_decoder_hidden_states, decoder_all_attentions: decoder_output.all_attentions,
all_decoder_attentions, encoder_all_hidden_states: encoder_output.all_hidden_states,
all_encoder_hidden_states, encoder_all_attentions: encoder_output.all_attentions,
all_encoder_attentions, }
)
} }
} }
@ -537,31 +514,15 @@ impl T5ForConditionalGeneration {
&self, &self,
input_ids: Option<&Tensor>, input_ids: Option<&Tensor>,
attention_mask: Option<&Tensor>, attention_mask: Option<&Tensor>,
encoder_outputs: Option<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>)>, encoder_outputs: Option<T5StackOutput>,
decoder_input_ids: Option<&Tensor>, decoder_input_ids: Option<&Tensor>,
decoder_attention_mask: Option<&Tensor>, decoder_attention_mask: Option<&Tensor>,
input_embeds: Option<Tensor>, input_embeds: Option<Tensor>,
decoder_input_embeds: Option<Tensor>, decoder_input_embeds: Option<Tensor>,
old_layer_states: Option<Vec<(Option<LayerState>, Option<LayerState>)>>, old_layer_states: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
train: bool, train: bool,
) -> ( ) -> T5ModelOutput {
Tensor, let model_output = self.base_model.forward_t(
Tensor,
Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
) {
let (
decoder_outputs,
encoder_hidden_states,
decoder_cache,
all_decoder_hidden_states,
all_decoder_attentions,
all_encoder_hidden_states,
all_encoder_attentions,
) = self.base_model.forward_t(
input_ids, input_ids,
attention_mask, attention_mask,
encoder_outputs, encoder_outputs,
@ -572,23 +533,19 @@ impl T5ForConditionalGeneration {
old_layer_states, old_layer_states,
train, train,
); );
let lm_logits = decoder_outputs.linear::<Tensor>(&self.base_model.embeddings.ws, None) let lm_logits = model_output
.decoder_hidden_state
.linear::<Tensor>(&self.base_model.embeddings.ws, None)
* (self.model_dim.powf(-0.5)); * (self.model_dim.powf(-0.5));
( T5ModelOutput {
lm_logits, decoder_hidden_state: lm_logits,
encoder_hidden_states, ..model_output
decoder_cache, }
all_decoder_hidden_states,
all_decoder_attentions,
all_encoder_hidden_states,
all_encoder_attentions,
)
} }
pub fn encode(&self, input_ids: &Tensor, attention_mask: Option<&Tensor>) -> Tensor { pub fn encode(&self, input_ids: &Tensor, attention_mask: Option<&Tensor>) -> Tensor {
let (encoder_hidden_states, _, _, _) = self self.base_model
.base_model
.encoder .encoder
.forward_t( .forward_t(
Some(input_ids), Some(input_ids),
@ -600,8 +557,8 @@ impl T5ForConditionalGeneration {
None, None,
false, false,
) )
.unwrap(); .unwrap()
encoder_hidden_states .hidden_state
} }
} }
@ -685,11 +642,16 @@ impl LMHeadModel for T5ForConditionalGeneration {
decoder_input_ids: &Option<Tensor>, decoder_input_ids: &Option<Tensor>,
train: bool, train: bool,
) -> Result<LMModelOutput, &'static str> { ) -> Result<LMModelOutput, &'static str> {
let (decoder_output, encoder_hidden_states, new_cache, _, _, _, _) = match cache { let model_output = match cache {
Cache::T5Cache(cached_layer_states) => self.base_model.forward_t( Cache::T5Cache(cached_layer_states) => self.base_model.forward_t(
input_ids.as_ref(), input_ids.as_ref(),
attention_mask.as_ref(), attention_mask.as_ref(),
Some((encoder_outputs.as_ref().unwrap().copy(), None, None)), Some(T5StackOutput {
hidden_state: encoder_outputs.as_ref().unwrap().copy(),
all_hidden_states: None,
all_attentions: None,
next_cache: None,
}),
Option::from(decoder_input_ids), Option::from(decoder_input_ids),
None, None,
None, None,
@ -700,7 +662,12 @@ impl LMHeadModel for T5ForConditionalGeneration {
Cache::None => self.base_model.forward_t( Cache::None => self.base_model.forward_t(
input_ids.as_ref(), input_ids.as_ref(),
attention_mask.as_ref(), attention_mask.as_ref(),
Some((encoder_outputs.as_ref().unwrap().copy(), None, None)), Some(T5StackOutput {
hidden_state: encoder_outputs.as_ref().unwrap().copy(),
all_hidden_states: None,
all_attentions: None,
next_cache: None,
}),
Option::from(decoder_input_ids), Option::from(decoder_input_ids),
None, None,
None, None,
@ -711,15 +678,27 @@ impl LMHeadModel for T5ForConditionalGeneration {
_ => return Err("Cache not compatible with T5 Model"), _ => return Err("Cache not compatible with T5 Model"),
}; };
let lm_logits = decoder_output.linear::<Tensor>(&self.base_model.embeddings.ws, None) let lm_logits = model_output
.decoder_hidden_state
.linear::<Tensor>(&self.base_model.embeddings.ws, None)
* (self.model_dim.powf(-0.5)); * (self.model_dim.powf(-0.5));
Ok(LMModelOutput { Ok(LMModelOutput {
lm_logits, lm_logits,
encoder_hidden_state: Some(encoder_hidden_states), encoder_hidden_state: Some(model_output.encoder_hidden_state),
cache: Cache::T5Cache(new_cache), cache: Cache::T5Cache(model_output.next_cache),
all_hidden_states: None, all_hidden_states: None,
all_attentions: None, all_attentions: None,
}) })
} }
} }
/// Named result of `T5Model::forward_t` and `T5ForConditionalGeneration::forward_t`.
///
/// Replaces the 7-element tuple these methods used to return.
pub struct T5ModelOutput {
/// Hidden state returned by the decoder stack.
/// When returned from `T5ForConditionalGeneration::forward_t`, this field
/// instead holds the language-model logits (the decoder hidden state projected
/// through the shared embedding weights and scaled by `model_dim^-0.5`).
pub decoder_hidden_state: Tensor,
/// Hidden state of the encoder: either freshly computed, or taken from the
/// `encoder_outputs` argument when the caller supplied one.
pub encoder_hidden_state: Tensor,
/// Decoder layer states to feed back as `old_layer_states` on the next call,
/// one `(self-attention, cross-attention)` pair per layer.
pub next_cache: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
/// Hidden states of all decoder layers, as reported by the decoder stack.
pub decoder_all_hidden_states: Option<Vec<Tensor>>,
/// Attention weights of all decoder layers, as reported by the decoder stack.
pub decoder_all_attentions: Option<Vec<Tensor>>,
/// Hidden states of all encoder layers, as reported by the encoder stack
/// (or as carried by the caller-supplied `encoder_outputs`).
pub encoder_all_hidden_states: Option<Vec<Tensor>>,
/// Attention weights of all encoder layers, as reported by the encoder stack
/// (or as carried by the caller-supplied `encoder_outputs`).
pub encoder_all_attentions: Option<Vec<Tensor>>,
}

View File

@ -414,7 +414,7 @@ fn bert_question_answering() -> anyhow::Result<()> {
let context = String::from("Amy lives in Amsterdam"); let context = String::from("Amy lives in Amsterdam");
let qa_input = QaInput { question, context }; let qa_input = QaInput { question, context };
let answers = qa_model.predict(&vec![qa_input], 1, 32); let answers = qa_model.predict(&[qa_input], 1, 32);
assert_eq!(answers.len(), 1 as usize); assert_eq!(answers.len(), 1 as usize);
assert_eq!(answers[0].len(), 1 as usize); assert_eq!(answers[0].len(), 1 as usize);

View File

@ -269,7 +269,7 @@ fn distilbert_question_answering() -> anyhow::Result<()> {
let context = String::from("Amy lives in Amsterdam"); let context = String::from("Amy lives in Amsterdam");
let qa_input = QaInput { question, context }; let qa_input = QaInput { question, context };
let answers = qa_model.predict(&vec![qa_input], 1, 32); let answers = qa_model.predict(&[qa_input], 1, 32);
assert_eq!(answers.len(), 1 as usize); assert_eq!(answers.len(), 1 as usize);
assert_eq!(answers[0].len(), 1 as usize); assert_eq!(answers[0].len(), 1 as usize);

View File

@ -357,7 +357,7 @@ fn roberta_question_answering() -> anyhow::Result<()> {
let context = String::from("Amy lives in Amsterdam"); let context = String::from("Amy lives in Amsterdam");
let qa_input = QaInput { question, context }; let qa_input = QaInput { question, context };
let answers = qa_model.predict(&vec![qa_input], 1, 32); let answers = qa_model.predict(&[qa_input], 1, 32);
assert_eq!(answers.len(), 1 as usize); assert_eq!(answers.len(), 1 as usize);
assert_eq!(answers[0].len(), 1 as usize); assert_eq!(answers[0].len(), 1 as usize);