commit e0772f42ae (parent 40a6a148a0)
Author: Guillaume B
Date:   2020-06-05 09:45:12 +02:00

    Updated version

8 changed files with 62 additions and 50 deletions

View File

@@ -26,7 +26,7 @@ fn main() -> failure::Fallible<()> {
         num_return_sequences: 3,
         ..Default::default()
     };
-    let mut model = GPT2Generator::new(generate_config)?;
+    let model = GPT2Generator::new(generate_config)?;
     let input_context = "The dog";
     let second_input_context = "The cat was";
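Note: because the beam cache is now reordered in place through a mutable borrow (see the reorder_cache changes below), the generator handle itself no longer needs a mutable binding. A minimal usage sketch under that assumption; the import paths and the trailing generate call follow the surrounding examples rather than this diff:

    use rust_bert::pipelines::generation::{GenerateConfig, GPT2Generator, LanguageGenerator};

    fn main() -> failure::Fallible<()> {
        let generate_config = GenerateConfig {
            num_return_sequences: 3,
            ..Default::default()
        };
        // An immutable binding is now sufficient.
        let model = GPT2Generator::new(generate_config)?;
        let output = model.generate(Some(vec!("The dog", "The cat was")), None);
        for sentence in output {
            println!("{}", sentence);
        }
        Ok(())
    }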

View File

@@ -18,7 +18,7 @@ use std::time::Instant;
 fn main() -> failure::Fallible<()> {
     unsafe{ torch_sys::dummy_cuda_dependency(); }
-    let mut summarization_model = SummarizationModel::new(Default::default())?;
+    let summarization_model = SummarizationModel::new(Default::default())?;
     let input = ["In findings published Tuesday in Cornell University's arXiv by a team of scientists \
 from the University of Montreal and a separate report published Wednesday in Nature Astronomy by a team \

View File

@@ -46,23 +46,37 @@ impl Clone for LayerState {
 }
 impl LayerState {
-    pub(crate) fn reorder_cache(&self, new_indices: &Tensor) -> LayerState {
-        let new_key = match &self.prev_key {
-            Some(value) => Some(value.index_select(0, new_indices)),
-            None => None
-        };
-        let new_value = match &self.prev_value {
-            Some(value) => Some(value.index_select(0, new_indices)),
-            None => None
-        };
-        let new_key_padding_mask = match &self.prev_key_padding_mask {
-            Some(value) => Some(value.index_select(0, new_indices)),
-            None => None
-        };
-        LayerState { prev_key: new_key, prev_value: new_value, prev_key_padding_mask: new_key_padding_mask }
+    pub(crate) fn reorder_cache(&mut self, new_indices: &Tensor) {
+        if self.prev_key.is_some() {
+            self.prev_key = Some(self.prev_key.as_ref().unwrap().index_select(0, new_indices));
+        }
+        if self.prev_value.is_some() {
+            self.prev_value = Some(self.prev_value.as_ref().unwrap().index_select(0, new_indices));
+        }
+        if self.prev_key_padding_mask.is_some() {
+            self.prev_key_padding_mask = Some(self.prev_key_padding_mask.as_ref().unwrap().index_select(0, new_indices));
+        }
     }
 }
+// impl LayerState {
+//     pub(crate) fn reorder_cache(&self, new_indices: &Tensor) -> LayerState {
+//         let new_key = match &self.prev_key {
+//             Some(value) => Some(value.index_select(0, new_indices)),
+//             None => None
+//         };
+//         let new_value = match &self.prev_value {
+//             Some(value) => Some(value.index_select(0, new_indices)),
+//             None => None
+//         };
+//         let new_key_padding_mask = match &self.prev_key_padding_mask {
+//             Some(value) => Some(value.index_select(0, new_indices)),
+//             None => None
+//         };
+//         LayerState { prev_key: new_key, prev_value: new_value, prev_key_padding_mask: new_key_padding_mask }
+//     }
+// }
 #[derive(Debug)]
 pub struct SelfAttention {
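The in-place rewrite repeats an is_some() check followed by as_ref().unwrap(). A sketch of an equivalent, slightly tighter formulation with Option::as_mut (not part of the commit; the same Option<Tensor> fields are assumed):

    pub(crate) fn reorder_cache(&mut self, new_indices: &Tensor) {
        // Reorder each cached tensor along the batch/beam dimension, when present.
        if let Some(key) = self.prev_key.as_mut() {
            *key = key.index_select(0, new_indices);
        }
        if let Some(value) = self.prev_value.as_mut() {
            *value = value.index_select(0, new_indices);
        }
        if let Some(mask) = self.prev_key_padding_mask.as_mut() {
            *mask = mask.index_select(0, new_indices);
        }
    }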

View File

@@ -548,12 +548,12 @@ impl PrivateLanguageGenerator<BartForConditionalGeneration, RobertaVocab, Robert
         Tensor::stack(&token_ids, 0)
     }
-    fn reorder_cache(&self, past: Cache, encoder_outputs: Option<Tensor>, beam_indices: &Tensor) -> (Cache, Option<Tensor>) {
+    fn reorder_cache(&self, past: &mut Cache, encoder_outputs: Option<Tensor>, beam_indices: &Tensor) -> Option<Tensor> {
         let encoder_outputs = match encoder_outputs {
             Some(value) => Some(value.index_select(0, beam_indices)),
             None => None
         };
-        let new_past = match past {
+        match past {
             Cache::BARTCache(old_cache_option) => {
                 match old_cache_option {
                     Some(old_cache) => {
@@ -569,15 +569,14 @@ impl PrivateLanguageGenerator<BartForConditionalGeneration, RobertaVocab, Robert
                             };
                             new_past.push((new_self_layer_state, new_encoder_layer_state));
                         };
-                        Cache::BARTCache(Some(new_past))
                     }
-                    None => Cache::BARTCache(None)
+                    None => { }
                 }
             }
-            Cache::None => Cache::None,
+            Cache::None => {},
             _ => { panic!("Invalid cache for BART model"); }
         };
-        (new_past, encoder_outputs)
+        encoder_outputs
     }
 }
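Taken together, the two hunks above change reorder_cache from consuming the Cache and rebuilding it to mutating the cached layer states through &mut Cache, returning only the (reordered) encoder outputs. A compressed sketch of the new shape, assuming the BART cache holds pairs of optional LayerState values and using the in-place LayerState::reorder_cache shown earlier; the elided middle of the hunk may differ in detail:

    fn reorder_cache(&self, past: &mut Cache, encoder_outputs: Option<Tensor>,
                     beam_indices: &Tensor) -> Option<Tensor> {
        // Reorder the cached decoder states in place, one layer at a time.
        if let Cache::BARTCache(Some(cached_states)) = past {
            for (self_state, encoder_state) in cached_states.iter_mut() {
                if let Some(state) = self_state.as_mut() {
                    state.reorder_cache(beam_indices);
                }
                if let Some(state) = encoder_state.as_mut() {
                    state.reorder_cache(beam_indices);
                }
            }
        }
        // Only the encoder outputs still travel by value.
        encoder_outputs.map(|value| value.index_select(0, beam_indices))
    }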
@@ -741,12 +740,12 @@ impl PrivateLanguageGenerator<MarianForConditionalGeneration, MarianVocab, Maria
         Tensor::stack(&token_ids, 0)
     }
-    fn reorder_cache(&self, past: Cache, encoder_outputs: Option<Tensor>, beam_indices: &Tensor) -> (Cache, Option<Tensor>) {
+    fn reorder_cache(&self, past: &mut Cache, encoder_outputs: Option<Tensor>, beam_indices: &Tensor) -> Option<Tensor> {
         let encoder_outputs = match encoder_outputs {
             Some(value) => Some(value.index_select(0, beam_indices)),
             None => None
         };
-        let new_past = match past {
+        match past {
             Cache::BARTCache(old_cache_option) => {
                 match old_cache_option {
                     Some(old_cache) => {
@@ -762,15 +761,14 @@ impl PrivateLanguageGenerator<MarianForConditionalGeneration, MarianVocab, Maria
                             };
                             new_past.push((new_self_layer_state, new_encoder_layer_state));
                         };
-                        Cache::BARTCache(Some(new_past))
                     }
-                    None => Cache::BARTCache(None)
+                    None => { }
                 }
             }
-            Cache::None => Cache::None,
+            Cache::None => {},
             _ => { panic!("Invalid cache for BART model"); }
         };
-        (new_past, encoder_outputs)
+        encoder_outputs
     }
 }
@@ -1216,9 +1214,9 @@ mod private_generation_utils {
                 input_ids = input_ids.index_select(0, &beam_indices);
                 input_ids = Tensor::cat(&[input_ids, beam_tokens.unsqueeze(1)], -1);
-                let temp_past = self.reorder_cache(past, encoder_outputs, &beam_indices);
-                past = temp_past.0;
-                encoder_outputs = temp_past.1;
+                encoder_outputs = self.reorder_cache(&mut past, encoder_outputs, &beam_indices);
+                // past = temp_past.0;
+                // encoder_outputs = temp_past.1;
                 if !self.is_encoder_decoder() {
                     attention_mask = Tensor::cat(&[attention_mask.as_ref(), Tensor::ones(&[*attention_mask.size().first().unwrap(), 1],
                                                                                          (Int64, attention_mask.device())).as_ref()], -1);
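At the call site the tuple round-trip disappears: past stays owned by the generation loop and only encoder_outputs is reassigned. The general move from "consume and return" to "mutate through a mutable borrow", in a self-contained toy example (illustrative only, not rust-bert types):

    struct Cache(Vec<i64>);

    // Before: consume and return; the caller must thread the value back.
    fn reorder_owned(cache: Cache, order: &[usize]) -> Cache {
        Cache(order.iter().map(|&i| cache.0[i]).collect())
    }

    // After: mutate in place; the caller keeps ownership throughout.
    fn reorder_in_place(cache: &mut Cache, order: &[usize]) {
        cache.0 = order.iter().map(|&i| cache.0[i]).collect();
    }

    fn main() {
        let mut cache = Cache(vec![10, 20, 30]);
        cache = reorder_owned(cache, &[2, 0, 1]); // old style: reassign
        reorder_in_place(&mut cache, &[1, 2, 0]); // new style: in place
        assert_eq!(cache.0, vec![10, 20, 30]);
    }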
@@ -1293,20 +1291,20 @@ mod private_generation_utils {
             decoded
         }
-        fn reorder_cache(&self, past: Cache, _encoder_outputs: Option<Tensor>, beam_indices: &Tensor)
-                         -> (Cache, Option<Tensor>) {
+        fn reorder_cache(&self, past: &mut Cache, _encoder_outputs: Option<Tensor>, beam_indices: &Tensor)
+                         -> Option<Tensor> {
             match past {
-                Cache::None => { (Cache::None, None) }
+                Cache::None => { None }
                 Cache::GPT2Cache(cached_decoder_state) => {
                     match cached_decoder_state {
                         Some(value) => {
-                            let mut reordered_past = vec!();
-                            for layer_past in value.iter() {
-                                reordered_past.push(layer_past.index_select(1, beam_indices));
+                            // let mut reordered_past = vec!();
+                            for layer_past in value.iter_mut() {
+                                *layer_past = layer_past.index_select(1, beam_indices);
                             }
-                            (Cache::GPT2Cache(Some(reordered_past)), None)
+                            None
                         }
-                        None => (Cache::GPT2Cache(None), None)
+                        None => None
                     }
                 }
                 Cache::BARTCache(_) => { panic!("Not implemented"); }
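For the GPT-2 cache each layer's past is a single Tensor whose batch-of-beams dimension sits at index 1, hence index_select(1, beam_indices); the rewrite reorders each tensor in place via iter_mut instead of collecting a fresh Vec. A toy check of the index_select semantics using the tch crate (hypothetical values, for illustration only):

    use tch::Tensor;

    fn main() {
        // Cached values for three beams laid out along dimension 1.
        let past = Tensor::of_slice(&[10i64, 20, 30]).unsqueeze(0); // shape [1, 3]
        // Beam 2 survived twice, beam 0 once.
        let beam_indices = Tensor::of_slice(&[2i64, 2, 0]);
        let reordered = past.index_select(1, &beam_indices);
        assert_eq!(reordered.int64_value(&[0, 0]), 30);
        assert_eq!(reordered.int64_value(&[0, 2]), 10);
    }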

View File

@@ -69,7 +69,7 @@ fn bart_summarization_greedy() -> failure::Fallible<()> {
         device: Device::Cpu,
         ..Default::default()
     };
-    let mut model = SummarizationModel::new(summarization_config)?;
+    let model = SummarizationModel::new(summarization_config)?;
     let input = ["In findings published Tuesday in Cornell University's arXiv by a team of scientists \
 from the University of Montreal and a separate report published Wednesday in Nature Astronomy by a team \
@@ -114,7 +114,7 @@ fn bart_summarization_beam_search() -> failure::Fallible<()> {
         device: Device::Cpu,
         ..Default::default()
     };
-    let mut model = SummarizationModel::new(summarization_config)?;
+    let model = SummarizationModel::new(summarization_config)?;
     let input = ["In findings published Tuesday in Cornell University's arXiv by a team of scientists \
 from the University of Montreal and a separate report published Wednesday in Nature Astronomy by a team \

View File

@@ -94,7 +94,7 @@ fn gpt2_generation_greedy() -> failure::Fallible<()> {
         repetition_penalty: 1.1,
         ..Default::default()
     };
-    let mut model = GPT2Generator::new(generate_config)?;
+    let model = GPT2Generator::new(generate_config)?;
     let input_context = "The cat";
     let output = model.generate(Some(vec!(input_context)), None);
@@ -126,7 +126,7 @@ fn gpt2_generation_beam_search() -> failure::Fallible<()> {
         num_return_sequences: 3,
         ..Default::default()
     };
-    let mut model = GPT2Generator::new(generate_config)?;
+    let model = GPT2Generator::new(generate_config)?;
     let input_context = "The dog";
     let output = model.generate(Some(vec!(input_context)), None);
@@ -160,7 +160,7 @@ fn gpt2_generation_beam_search_multiple_prompts_without_padding() -> failure::Fa
         num_return_sequences: 3,
         ..Default::default()
     };
-    let mut model = GPT2Generator::new(generate_config)?;
+    let model = GPT2Generator::new(generate_config)?;
     let input_context_1 = "The dog";
     let input_context_2 = "The cat";
@@ -198,7 +198,7 @@ fn gpt2_generation_beam_search_multiple_prompts_with_padding() -> failure::Falli
         num_return_sequences: 3,
         ..Default::default()
     };
-    let mut model = GPT2Generator::new(generate_config)?;
+    let model = GPT2Generator::new(generate_config)?;
     let input_context_1 = "The dog";
     let input_context_2 = "The cat was";

View File

@@ -7,7 +7,7 @@ fn test_translation() -> failure::Fallible<()> {
     // Set-up translation model
     let translation_config = TranslationConfig::new(Language::EnglishToFrench, Device::Cpu);
-    let mut model = TranslationModel::new(translation_config)?;
+    let model = TranslationModel::new(translation_config)?;
     let input_context_1 = "The quick brown fox jumps over the lazy dog";
     let input_context_2 = "The dog did not wake up";

View File

@@ -87,7 +87,7 @@ fn openai_gpt_generation_greedy() -> failure::Fallible<()> {
         temperature: 1.1,
         ..Default::default()
     };
-    let mut model = OpenAIGenerator::new(generate_config)?;
+    let model = OpenAIGenerator::new(generate_config)?;
     let input_context = "It was an intense machine dialogue. ";
     let output = model.generate(Some(vec!(input_context)), None);
@@ -119,7 +119,7 @@ fn openai_gpt_generation_beam_search() -> failure::Fallible<()> {
         num_return_sequences: 3,
         ..Default::default()
     };
-    let mut model = OpenAIGenerator::new(generate_config)?;
+    let model = OpenAIGenerator::new(generate_config)?;
     let input_context = "The dog is";
     let output = model.generate(Some(vec!(input_context)), None);
@@ -153,7 +153,7 @@ fn openai_gpt_generation_beam_search_multiple_prompts_without_padding() -> failu
         num_return_sequences: 3,
         ..Default::default()
     };
-    let mut model = OpenAIGenerator::new(generate_config)?;
+    let model = OpenAIGenerator::new(generate_config)?;
     let input_context_1 = "The dog is";
     let input_context_2 = "The cat";
@@ -194,7 +194,7 @@ fn openai_gpt_generation_beam_search_multiple_prompts_with_padding() -> failure:
         num_return_sequences: 3,
         ..Default::default()
    };
-    let mut model = OpenAIGenerator::new(generate_config)?;
+    let model = OpenAIGenerator::new(generate_config)?;
     let input_context_1 = "The dog is";
     let input_context_2 = "The cat was in";