Updated T5 (clippy)

This commit is contained in:
Guillaume B 2020-09-13 12:25:22 +02:00
parent b42cc60409
commit 47d9a1017d
9 changed files with 154 additions and 175 deletions

View File

@ -33,7 +33,7 @@ fn main() -> anyhow::Result<()> {
};
// Get answer
let answers = qa_model.predict(&vec![qa_input_1, qa_input_2], 1, 32);
let answers = qa_model.predict(&[qa_input_1, qa_input_2], 1, 32);
println!("{:?}", answers);
Ok(())
}

View File

@ -51,7 +51,7 @@ fn main() -> anyhow::Result<()> {
};
// Get answer
let answers = qa_model.predict(&vec![qa_input_1, qa_input_2], 1, 32);
let answers = qa_model.predict(&[qa_input_1, qa_input_2], 1, 32);
println!("{:?}", answers);
Ok(())
}

View File

@ -27,8 +27,10 @@ fn main() -> anyhow::Result<()> {
// Run model
let output = sequence_classification_model.predict_multilabel(&input, 0.05);
for label in output {
println!("{:?}", label);
if let Ok(labels) = output {
for label in labels {
println!("{:?}", label);
}
}
Ok(())

View File

@ -199,9 +199,10 @@ impl T5Attention {
temp_value = temp_value.slice(2, length - 1, length, 1);
};
if let Some(attention_mask) = attention_mask {
temp_value += attention_mask;
};
Some(temp_value)
Some(temp_value + attention_mask)
} else {
Some(temp_value)
}
} else {
None
};

View File

@ -148,12 +148,7 @@ impl T5Block {
encoder_decoder_position_bias: Option<&Tensor>,
mut layer_states: (Option<LayerState>, Option<LayerState>),
train: bool,
) -> (
Tensor,
(Option<Tensor>, Option<Tensor>),
(Option<Tensor>, Option<Tensor>),
(Option<LayerState>, Option<LayerState>),
) {
) -> T5BlockOutput {
let (
hidden_states,
self_attention_weights,
@ -190,17 +185,17 @@ impl T5Block {
(hidden_states, None, None, None)
};
let attention_weights = (self_attention_weights, cross_attention_weights);
let position_bias = (self_attention_position_bias, cross_attention_position_bias);
layer_states = (self_attention_layer_past, cross_attention_layer_past);
let hidden_states = self.ff_layer.forward_t(&hidden_states, train);
(
T5BlockOutput {
hidden_states,
attention_weights,
position_bias,
layer_states,
)
self_attention_weights,
cross_attention_weights,
self_attention_position_bias,
cross_attention_position_bias,
cache: layer_states,
}
}
}
@ -269,15 +264,7 @@ impl T5Stack {
embeddings: &nn::Embedding,
old_layer_states: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
train: bool,
) -> Result<
(
Tensor,
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
),
&'static str,
> {
) -> Result<T5StackOutput, &'static str> {
let (input_embeddings, input_shape) = match input_ids {
Some(input_ids_value) => match input_embeds {
Some(_) => {
@ -332,7 +319,7 @@ impl T5Stack {
if self.is_decoder {
let seq_ids =
Tensor::arange(input_shape[1], (Kind::Float, input_embeddings.device()));
let causal_mask = seq_ids.unsqueeze(0).unsqueeze(0).repeat(&vec![
let causal_mask = seq_ids.unsqueeze(0).unsqueeze(0).repeat(&[
input_shape[0],
input_shape[1],
1,
@ -386,7 +373,7 @@ impl T5Stack {
} else {
None
};
let mut next_decoder_cache: Option<Vec<(Option<LayerState>, Option<LayerState>)>> =
let mut next_cache: Option<Vec<(Option<LayerState>, Option<LayerState>)>> =
if self.store_cache {
if old_layer_states.is_some() {
old_layer_states
@ -400,42 +387,36 @@ impl T5Stack {
let mut encoder_decoder_position_bias = None;
let mut attention_weights: Option<Tensor>;
let mut hidden_state = input_embeddings.apply_t(&self.dropout, train);
let mut blocks = self.blocks.iter().enumerate();
loop {
match blocks.next() {
Some((layer_idx, layer)) => {
let layer_state = match &next_decoder_cache {
Some(values) => values[layer_idx].to_owned(),
None => (None, None),
};
let temp = layer.forward_t(
&hidden_state,
position_bias.as_ref(),
extended_attention_mask.as_ref(),
encoder_hidden_states,
extended_encoder_attention_mask.as_ref(),
encoder_decoder_position_bias.as_ref(),
layer_state,
train,
);
if layer_idx == 0 {
position_bias = (temp.2).0;
encoder_decoder_position_bias = (temp.2).1;
}
hidden_state = temp.0;
attention_weights = (temp.1).1;
if let Some(hidden_states) = all_hidden_states.borrow_mut() {
hidden_states.push(hidden_state.as_ref().copy().transpose(0, 1));
};
if let Some(attentions) = all_attentions.borrow_mut() {
attentions.push(attention_weights.as_ref().unwrap().copy());
};
if let Some(value) = &mut next_decoder_cache {
value[layer_idx] = temp.3
};
}
None => break,
for (layer_idx, layer) in self.blocks.iter().enumerate() {
let layer_state = match &next_cache {
Some(values) => values[layer_idx].to_owned(),
None => (None, None),
};
let block_output = layer.forward_t(
&hidden_state,
position_bias.as_ref(),
extended_attention_mask.as_ref(),
encoder_hidden_states,
extended_encoder_attention_mask.as_ref(),
encoder_decoder_position_bias.as_ref(),
layer_state,
train,
);
if layer_idx == 0 {
position_bias = block_output.self_attention_position_bias;
encoder_decoder_position_bias = block_output.cross_attention_position_bias;
}
hidden_state = block_output.hidden_states;
attention_weights = block_output.cross_attention_weights;
if let Some(hidden_states) = all_hidden_states.borrow_mut() {
hidden_states.push(hidden_state.as_ref().copy().transpose(0, 1));
};
if let Some(attentions) = all_attentions.borrow_mut() {
attentions.push(attention_weights.as_ref().unwrap().copy());
};
if let Some(value) = &mut next_cache {
value[layer_idx] = block_output.cache
};
}
@ -443,11 +424,27 @@ impl T5Stack {
.apply(&self.final_layer_norm)
.apply_t(&self.dropout, train);
Ok((
Ok(T5StackOutput {
hidden_state,
all_hidden_states,
all_attentions,
next_decoder_cache,
))
next_cache,
})
}
}
/// Output of a single T5 block's `forward_t` pass.
///
/// Bundles the values previously returned as a nested tuple, so callers can
/// access fields by name instead of position.
pub struct T5BlockOutput {
    /// Hidden states produced by the block (output of the feed-forward layer).
    pub hidden_states: Tensor,
    /// Self-attention weights, when attention outputs were requested.
    pub self_attention_weights: Option<Tensor>,
    /// Cross-attention weights (decoder blocks attending to the encoder), when requested.
    pub cross_attention_weights: Option<Tensor>,
    /// Position bias computed by the self-attention layer; the stack captures it
    /// from layer 0 and reuses it for subsequent layers.
    pub self_attention_position_bias: Option<Tensor>,
    /// Position bias computed by the cross-attention layer; likewise captured
    /// from layer 0 by the stack.
    pub cross_attention_position_bias: Option<Tensor>,
    /// Updated (self-attention, cross-attention) layer caches for incremental decoding.
    pub cache: (Option<LayerState>, Option<LayerState>),
}
/// Output of a T5 encoder or decoder stack (`T5Stack::forward_t`).
pub struct T5StackOutput {
    /// Final hidden state, after the stack's final layer norm and dropout.
    pub hidden_state: Tensor,
    /// Per-layer hidden states (copied and transposed on dims 0/1), present only
    /// when hidden-state output was enabled.
    pub all_hidden_states: Option<Vec<Tensor>>,
    /// Per-layer attention weights, present only when attention output was enabled.
    pub all_attentions: Option<Vec<Tensor>>,
    /// Per-layer (self-attention, cross-attention) caches, present only when the
    /// stack stores its cache (decoder with caching enabled).
    pub next_cache: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
}

View File

@ -11,7 +11,7 @@
// limitations under the License.
use crate::pipelines::generation::{Cache, LMHeadModel, LMModelOutput};
use crate::t5::attention::LayerState;
use crate::t5::encoder::T5Stack;
use crate::t5::encoder::{T5Stack, T5StackOutput};
use crate::Config;
use serde::{Deserialize, Serialize};
use std::borrow::Borrow;
@ -308,52 +308,30 @@ impl T5Model {
&self,
input_ids: Option<&Tensor>,
attention_mask: Option<&Tensor>,
encoder_outputs: Option<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>)>,
encoder_outputs: Option<T5StackOutput>,
decoder_input_ids: Option<&Tensor>,
decoder_attention_mask: Option<&Tensor>,
input_embeds: Option<Tensor>,
decoder_input_embeds: Option<Tensor>,
old_layer_states: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
train: bool,
) -> (
Tensor,
Tensor,
Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
) {
let (encoder_hidden_states, all_encoder_hidden_states, all_encoder_attentions) =
match encoder_outputs {
Some(value) => value,
None => {
let (
encoder_hidden_states,
all_encoder_hidden_states,
all_encoder_attentions,
_,
) = self
.encoder
.forward_t(
input_ids,
attention_mask,
None,
None,
input_embeds,
&self.embeddings,
None,
train,
)
.unwrap();
(
encoder_hidden_states,
all_encoder_hidden_states,
all_encoder_attentions,
)
}
};
) -> T5ModelOutput {
let encoder_output = match encoder_outputs {
Some(value) => value,
None => self
.encoder
.forward_t(
input_ids,
attention_mask,
None,
None,
input_embeds,
&self.embeddings,
None,
train,
)
.unwrap(),
};
let (calculated_decoder_input_ids, calculated_decoder_input_embeds) =
if old_layer_states.is_some() {
let decoder_input_ids = match decoder_input_ids {
@ -377,29 +355,28 @@ impl T5Model {
(decoder_input_ids, decoder_input_embeds)
};
let (decoder_outputs, all_decoder_hidden_states, all_decoder_attentions, decoder_cache) =
self.decoder
.forward_t(
decoder_input_ids,
decoder_attention_mask,
Some(&encoder_hidden_states),
attention_mask,
decoder_input_embeds,
&self.embeddings,
old_layer_states,
train,
)
.unwrap();
(
decoder_outputs,
encoder_hidden_states,
decoder_cache,
all_decoder_hidden_states,
all_decoder_attentions,
all_encoder_hidden_states,
all_encoder_attentions,
)
let decoder_output = self
.decoder
.forward_t(
decoder_input_ids,
decoder_attention_mask,
Some(&encoder_output.hidden_state),
attention_mask,
decoder_input_embeds,
&self.embeddings,
old_layer_states,
train,
)
.unwrap();
T5ModelOutput {
decoder_hidden_state: decoder_output.hidden_state,
encoder_hidden_state: encoder_output.hidden_state,
next_cache: decoder_output.next_cache,
decoder_all_hidden_states: decoder_output.all_hidden_states,
decoder_all_attentions: decoder_output.all_attentions,
encoder_all_hidden_states: encoder_output.all_hidden_states,
encoder_all_attentions: encoder_output.all_attentions,
}
}
}
@ -537,31 +514,15 @@ impl T5ForConditionalGeneration {
&self,
input_ids: Option<&Tensor>,
attention_mask: Option<&Tensor>,
encoder_outputs: Option<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>)>,
encoder_outputs: Option<T5StackOutput>,
decoder_input_ids: Option<&Tensor>,
decoder_attention_mask: Option<&Tensor>,
input_embeds: Option<Tensor>,
decoder_input_embeds: Option<Tensor>,
old_layer_states: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
train: bool,
) -> (
Tensor,
Tensor,
Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
) {
let (
decoder_outputs,
encoder_hidden_states,
decoder_cache,
all_decoder_hidden_states,
all_decoder_attentions,
all_encoder_hidden_states,
all_encoder_attentions,
) = self.base_model.forward_t(
) -> T5ModelOutput {
let model_output = self.base_model.forward_t(
input_ids,
attention_mask,
encoder_outputs,
@ -572,23 +533,19 @@ impl T5ForConditionalGeneration {
old_layer_states,
train,
);
let lm_logits = decoder_outputs.linear::<Tensor>(&self.base_model.embeddings.ws, None)
let lm_logits = model_output
.decoder_hidden_state
.linear::<Tensor>(&self.base_model.embeddings.ws, None)
* (self.model_dim.powf(-0.5));
(
lm_logits,
encoder_hidden_states,
decoder_cache,
all_decoder_hidden_states,
all_decoder_attentions,
all_encoder_hidden_states,
all_encoder_attentions,
)
T5ModelOutput {
decoder_hidden_state: lm_logits,
..model_output
}
}
pub fn encode(&self, input_ids: &Tensor, attention_mask: Option<&Tensor>) -> Tensor {
let (encoder_hidden_states, _, _, _) = self
.base_model
self.base_model
.encoder
.forward_t(
Some(input_ids),
@ -600,8 +557,8 @@ impl T5ForConditionalGeneration {
None,
false,
)
.unwrap();
encoder_hidden_states
.unwrap()
.hidden_state
}
}
@ -685,11 +642,16 @@ impl LMHeadModel for T5ForConditionalGeneration {
decoder_input_ids: &Option<Tensor>,
train: bool,
) -> Result<LMModelOutput, &'static str> {
let (decoder_output, encoder_hidden_states, new_cache, _, _, _, _) = match cache {
let model_output = match cache {
Cache::T5Cache(cached_layer_states) => self.base_model.forward_t(
input_ids.as_ref(),
attention_mask.as_ref(),
Some((encoder_outputs.as_ref().unwrap().copy(), None, None)),
Some(T5StackOutput {
hidden_state: encoder_outputs.as_ref().unwrap().copy(),
all_hidden_states: None,
all_attentions: None,
next_cache: None,
}),
Option::from(decoder_input_ids),
None,
None,
@ -700,7 +662,12 @@ impl LMHeadModel for T5ForConditionalGeneration {
Cache::None => self.base_model.forward_t(
input_ids.as_ref(),
attention_mask.as_ref(),
Some((encoder_outputs.as_ref().unwrap().copy(), None, None)),
Some(T5StackOutput {
hidden_state: encoder_outputs.as_ref().unwrap().copy(),
all_hidden_states: None,
all_attentions: None,
next_cache: None,
}),
Option::from(decoder_input_ids),
None,
None,
@ -711,15 +678,27 @@ impl LMHeadModel for T5ForConditionalGeneration {
_ => return Err("Cache not compatible with T5 Model"),
};
let lm_logits = decoder_output.linear::<Tensor>(&self.base_model.embeddings.ws, None)
let lm_logits = model_output
.decoder_hidden_state
.linear::<Tensor>(&self.base_model.embeddings.ws, None)
* (self.model_dim.powf(-0.5));
Ok(LMModelOutput {
lm_logits,
encoder_hidden_state: Some(encoder_hidden_states),
cache: Cache::T5Cache(new_cache),
encoder_hidden_state: Some(model_output.encoder_hidden_state),
cache: Cache::T5Cache(model_output.next_cache),
all_hidden_states: None,
all_attentions: None,
})
}
}
/// Combined encoder/decoder output of `T5Model::forward_t`.
///
/// Also returned by `T5ForConditionalGeneration::forward_t`, where
/// `decoder_hidden_state` holds the scaled LM logits instead of the raw
/// decoder hidden state (the remaining fields are forwarded unchanged).
pub struct T5ModelOutput {
    /// Decoder output hidden state (or LM logits, see note above).
    pub decoder_hidden_state: Tensor,
    /// Encoder output hidden state fed to the decoder's cross-attention.
    pub encoder_hidden_state: Tensor,
    /// Decoder layer caches for incremental generation, when caching is enabled.
    pub next_cache: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
    /// Decoder per-layer hidden states, when hidden-state output was enabled.
    pub decoder_all_hidden_states: Option<Vec<Tensor>>,
    /// Decoder per-layer attention weights, when attention output was enabled.
    pub decoder_all_attentions: Option<Vec<Tensor>>,
    /// Encoder per-layer hidden states, when hidden-state output was enabled.
    pub encoder_all_hidden_states: Option<Vec<Tensor>>,
    /// Encoder per-layer attention weights, when attention output was enabled.
    pub encoder_all_attentions: Option<Vec<Tensor>>,
}

View File

@ -414,7 +414,7 @@ fn bert_question_answering() -> anyhow::Result<()> {
let context = String::from("Amy lives in Amsterdam");
let qa_input = QaInput { question, context };
let answers = qa_model.predict(&vec![qa_input], 1, 32);
let answers = qa_model.predict(&[qa_input], 1, 32);
assert_eq!(answers.len(), 1 as usize);
assert_eq!(answers[0].len(), 1 as usize);

View File

@ -269,7 +269,7 @@ fn distilbert_question_answering() -> anyhow::Result<()> {
let context = String::from("Amy lives in Amsterdam");
let qa_input = QaInput { question, context };
let answers = qa_model.predict(&vec![qa_input], 1, 32);
let answers = qa_model.predict(&[qa_input], 1, 32);
assert_eq!(answers.len(), 1 as usize);
assert_eq!(answers[0].len(), 1 as usize);

View File

@ -357,7 +357,7 @@ fn roberta_question_answering() -> anyhow::Result<()> {
let context = String::from("Amy lives in Amsterdam");
let qa_input = QaInput { question, context };
let answers = qa_model.predict(&vec![qa_input], 1, 32);
let answers = qa_model.predict(&[qa_input], 1, 32);
assert_eq!(answers.len(), 1 as usize);
assert_eq!(answers[0].len(), 1 as usize);