non optional LayerState fields

Guillaume B 2020-06-05 20:49:26 +02:00
parent e0772f42ae
commit db3d126a2b
6 changed files with 34 additions and 75 deletions

View File

@@ -44,13 +44,13 @@ about exoplanets like K2-18b."];
     // Credits: WikiNews, CC BY 2.5 license (https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b)
     let now = Instant::now();
-    for _ in 0..3 {
-        let output = summarization_model.summarize(&input);
-        for sentence in output {
+    for _ in 0..5 {
+        let _output = summarization_model.summarize(&input);
+        for sentence in _output {
             println!("{:?}", sentence);
         }
     }
-    println!("{:?}", now.elapsed().as_millis() / 3);
+    println!("{:?}", now.elapsed().as_millis() / 5);
     Ok(())
 }
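Note: the example above now averages wall-clock time over five summarization runs instead of three. The same measurement pattern, stripped of the model itself, is sketched below using only the standard library; `expensive_call` is a hypothetical stand-in for `summarization_model.summarize(&input)`.

use std::time::Instant;

fn main() {
    // Hypothetical stand-in for the summarization call being benchmarked.
    let expensive_call = || (0..1_000_000u64).sum::<u64>();

    let runs: u128 = 5; // the diff bumps the number of timed runs from 3 to 5
    let now = Instant::now();
    for _ in 0..runs {
        let _output = expensive_call();
    }
    // Total elapsed time divided by the run count gives the average per run,
    // matching `now.elapsed().as_millis() / 5` in the example above.
    println!("average: {} ms", now.elapsed().as_millis() / runs);
}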

View File

@@ -20,63 +20,37 @@ use tch::kind::Kind::Float;
 
 /// Stores the cached value of key, value and key padding mask to avoid recalculation (e.g. at each generation step)
 pub struct LayerState {
     /// Cached keys
-    pub prev_key: Option<Tensor>,
+    pub prev_key: Tensor,
     /// Cached values
-    pub prev_value: Option<Tensor>,
+    pub prev_value: Tensor,
     /// Cached keys padding mask
     pub prev_key_padding_mask: Option<Tensor>,
 }
 
 impl Clone for LayerState {
     fn clone(&self) -> Self {
-        let prev_key = match &self.prev_key {
-            Some(key) => Some(key.copy()),
-            None => None
-        };
-        let prev_value = match &self.prev_value {
-            Some(value) => Some(value.copy()),
-            None => None
-        };
         let prev_key_padding_mask = match &self.prev_key_padding_mask {
             Some(key_padding_mask) => Some(key_padding_mask.copy()),
            None => None
         };
-        LayerState { prev_key, prev_value, prev_key_padding_mask }
+        LayerState {
+            prev_key: self.prev_key.copy(),
+            prev_value: self.prev_value.copy(),
+            prev_key_padding_mask,
+        }
     }
 }
 
 impl LayerState {
     pub(crate) fn reorder_cache(&mut self, new_indices: &Tensor) {
-        if self.prev_key.is_some() {
-            self.prev_key = Some(self.prev_key.as_ref().unwrap().index_select(0, new_indices));
-        }
-        if self.prev_value.is_some() {
-            self.prev_value = Some(self.prev_value.as_ref().unwrap().index_select(0, new_indices));
-        }
+        self.prev_key = self.prev_key.index_select(0, new_indices);
+        self.prev_value = self.prev_value.index_select(0, new_indices);
         if self.prev_key_padding_mask.is_some() {
             self.prev_key_padding_mask = Some(self.prev_key_padding_mask.as_ref().unwrap().index_select(0, new_indices));
         }
     }
 }
-
-// impl LayerState {
-//     pub(crate) fn reorder_cache(&self, new_indices: &Tensor) -> LayerState {
-//         let new_key = match &self.prev_key {
-//             Some(value) => Some(value.index_select(0, new_indices)),
-//             None => None
-//         };
-//         let new_value = match &self.prev_value {
-//             Some(value) => Some(value.index_select(0, new_indices)),
-//             None => None
-//         };
-//         let new_key_padding_mask = match &self.prev_key_padding_mask {
-//             Some(value) => Some(value.index_select(0, new_indices)),
-//             None => None
-//         };
-//         LayerState { prev_key: new_key, prev_value: new_value, prev_key_padding_mask: new_key_padding_mask }
-//     }
-// }
 
 #[derive(Debug)]
 pub struct SelfAttention {
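After this hunk, the shape of the cache can be summarized as below. This is a condensed sketch (assuming the `tch` crate), not the exact file contents: it uses `map`/`take` where the source keeps explicit `match` and `is_some()`/`unwrap()` forms.

use tch::Tensor;

/// Cached key/value tensors; only the padding mask remains optional.
pub struct LayerState {
    pub prev_key: Tensor,
    pub prev_value: Tensor,
    pub prev_key_padding_mask: Option<Tensor>,
}

impl Clone for LayerState {
    fn clone(&self) -> Self {
        // tch Tensors do not implement Clone, so cloning the cache copies the data explicitly.
        LayerState {
            prev_key: self.prev_key.copy(),
            prev_value: self.prev_value.copy(),
            prev_key_padding_mask: self.prev_key_padding_mask.as_ref().map(|m| m.copy()),
        }
    }
}

impl LayerState {
    /// Reorders the cached entries along the batch dimension, e.g. when beams are reshuffled.
    pub fn reorder_cache(&mut self, new_indices: &Tensor) {
        self.prev_key = self.prev_key.index_select(0, new_indices);
        self.prev_value = self.prev_value.index_select(0, new_indices);
        if let Some(mask) = self.prev_key_padding_mask.take() {
            self.prev_key_padding_mask = Some(mask.index_select(0, new_indices));
        }
    }
}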
@@ -134,13 +108,7 @@ impl SelfAttention {
         let (target_sequence_length, bs) = (query_size[0], query_size[1]);
         let q: Tensor = self.flatten(query.as_ref().apply(&self.q_proj) * self.scaling, target_sequence_length, bs);
         let key = match &old_layer_state {
-            Some(prev_state) => {
-                if prev_state.prev_key.is_some() & self.encoder_decoder_attention {
-                    None
-                } else {
-                    key
-                }
-            }
+            Some(_) => { if self.encoder_decoder_attention { None } else { key } }
             None => key
         };
@@ -163,8 +131,8 @@ impl SelfAttention {
         let new_layer_state = if self.store_cache {
             Some(LayerState {
-                prev_key: Some(k.view((bs, self.num_heads, -1, self.head_dim))),
-                prev_value: Some(v.view((bs, self.num_heads, -1, self.head_dim))),
+                prev_key: k.view((bs, self.num_heads, -1, self.head_dim)),
+                prev_value: v.view((bs, self.num_heads, -1, self.head_dim)),
                 prev_key_padding_mask: match key_padding_mask.as_ref() {
                     Some(tensor) => Some(tensor.copy()),
                     None => None
@@ -219,28 +187,19 @@ impl SelfAttention {
         -> (Tensor, Tensor, Option<Tensor>) {
         match &layer_state {
             Some(prev_state) => {
-                let k = match &prev_state.prev_key {
-                    Some(prev_key) => {
-                        let prev_key = prev_key.view((bs * self.num_heads, -1, self.head_dim));
-                        if self.encoder_decoder_attention {
-                            prev_key
-                        } else {
-                            Tensor::cat(&[prev_key, k.unwrap()], 1)
-                        }
-                    }
-                    None => k.unwrap()
+                let prev_key = prev_state.prev_key.view((bs * self.num_heads, -1, self.head_dim));
+                let prev_value = prev_state.prev_value.view((bs * self.num_heads, -1, self.head_dim));
+                let k = if self.encoder_decoder_attention {
+                    prev_key
+                } else {
+                    Tensor::cat(&[prev_key, k.unwrap()], 1)
                 };
-                let v = match &prev_state.prev_value {
-                    Some(prev_value) => {
-                        let prev_value = prev_value.view((bs * self.num_heads, -1, self.head_dim));
-                        if self.encoder_decoder_attention {
-                            prev_value
-                        } else {
-                            Tensor::cat(&[prev_value, v.unwrap()], 1)
-                        }
-                    }
-                    None => v.unwrap()
+                let v = if self.encoder_decoder_attention {
+                    prev_value
+                } else {
+                    Tensor::cat(&[prev_value, v.unwrap()], 1)
                 };
                 let key_padding_mask = self.use_saved_key_padding_mask(key_padding_mask,
                                                                        &prev_state.prev_key_padding_mask,
                                                                        bs,
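The branch kept in the last hunk is the heart of the cache reuse: for encoder-decoder (cross) attention the cached keys and values are reused as-is, while for self-attention the current step is appended along the sequence dimension. A standalone illustration of that decision, with hypothetical shapes and assuming the `tch` crate:

use tch::{Device, Kind, Tensor};

fn main() {
    // Hypothetical tensors shaped (batch * num_heads, seq_len, head_dim).
    let cached_k = Tensor::zeros(&[2, 4, 8], (Kind::Float, Device::Cpu));
    let new_k = Tensor::zeros(&[2, 1, 8], (Kind::Float, Device::Cpu));

    let encoder_decoder_attention = false;
    let k = if encoder_decoder_attention {
        // Cross-attention: the encoder keys never change between generation steps.
        cached_k
    } else {
        // Self-attention: grow the cache by one step along dimension 1.
        Tensor::cat(&[cached_k, new_k], 1)
    };
    println!("{:?}", k.size()); // [2, 5, 8] in the self-attention branch
}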

View File

@@ -407,7 +407,7 @@ impl BartForConditionalGeneration {
 /// let encoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
 /// let decoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
 ///
-/// let (decoder_output, encoder_hidden_states,
+/// let (decoder_output, encoder_hidden_states, cache,
 ///      all_encoder_hidden_states, all_encoder_attentions,
 ///      all_decoder_hidden_states, all_decoder_attentions) = no_grad(|| {
 ///     bart_model
@@ -565,7 +565,7 @@ impl BartForSequenceClassification {
 /// let encoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
 /// let decoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
 ///
-/// let (decoder_output, encoder_hidden_states,
+/// let (decoder_output, encoder_hidden_states, cache,
 ///      all_encoder_hidden_states, all_encoder_attentions,
 ///      all_decoder_hidden_states, all_decoder_attentions) = no_grad(|| {
 ///     bart_model
@@ -655,7 +655,7 @@ impl LMHeadModel for BartForConditionalGeneration {
 /// let encoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
 /// let decoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
 ///
-/// let (decoder_output, encoder_hidden_states,
+/// let (decoder_output, encoder_hidden_states, cache,
 ///      all_encoder_hidden_states, all_encoder_attentions,
 ///      all_decoder_hidden_states, all_decoder_attentions) = no_grad(|| {
 ///     bart_model

View File

@@ -433,7 +433,7 @@ impl LMHeadModel for GPT2LMHeadModel {
 /// let (output, _, past, hidden_states, attentions) = no_grad(|| {
 ///     gpt2_model
 ///         .forward_t(&Some(input_tensor),
-///                    &Cache::GPT2Cache(Some(past)),
+///                    Cache::GPT2Cache(Some(past)),
 ///                    &Some(attention_mask),
 ///                    &Some(token_type_ids),
 ///                    &Some(position_ids),
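The only change documented here is that the `forward_t` example now passes the cache by value (`Cache::GPT2Cache(Some(past))`) rather than by reference. A toy sketch of that by-value pattern, using hypothetical types rather than the rust-bert `Cache` enum:

#[derive(Debug)]
enum Cache {
    None,
    Values(Vec<f32>),
}

// Hypothetical forward pass: consumes the previous cache and returns an updated one.
fn forward(cache: Cache) -> Cache {
    match cache {
        Cache::None => Cache::Values(vec![1.0]),
        Cache::Values(mut v) => {
            v.push(1.0);
            Cache::Values(v)
        }
    }
}

fn main() {
    let mut cache = Cache::None;
    for _ in 0..3 {
        cache = forward(cache); // moved into the call each step, as in the doc example above
    }
    println!("{:?}", cache); // Values([1.0, 1.0, 1.0])
}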

View File

@@ -211,7 +211,7 @@ impl MarianForConditionalGeneration {
 /// let encoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
 /// let decoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
 ///
-/// let (decoder_output, encoder_hidden_states,
+/// let (decoder_output, encoder_hidden_states, cache,
 ///      all_encoder_hidden_states, all_encoder_attentions,
 ///      all_decoder_hidden_states, all_decoder_attentions) = no_grad(|| {
 ///     marian_model
@@ -301,7 +301,7 @@ impl LMHeadModel for MarianForConditionalGeneration {
 /// let encoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
 /// let decoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
 ///
-/// let (decoder_output, encoder_hidden_states,
+/// let (decoder_output, encoder_hidden_states, cache,
 ///      all_encoder_hidden_states, all_encoder_attentions,
 ///      all_decoder_hidden_states, all_decoder_attentions) = no_grad(|| {
 ///     marian_model

View File

@@ -327,7 +327,7 @@ impl LMHeadModel for OpenAIGPTLMHeadModel {
 /// let (output, _, _, hidden_states, attentions) = no_grad(|| {
 ///     gpt_model
 ///         .forward_t(&Some(input_tensor),
-///                    &Cache::None,
+///                    Cache::None,
 ///                    &Some(attention_mask),
 ///                    &Some(token_type_ids),
 ///                    &Some(position_ids),