End-to-end summarization

Guillaume B 2020-04-04 17:11:37 +02:00
parent fa87fce96e
commit 0a8832a012
5 changed files with 126 additions and 42 deletions

View File

@@ -40,12 +40,12 @@ fn main() -> failure::Fallible<()> {
let device = Device::cuda_if_available();
let generate_config = GenerateConfig {
max_length: 142,
- do_sample: true,
+ do_sample: false,
num_beams: 3,
temperature: 1.0,
top_k: 50,
top_p: 1.0,
- length_penalty: 2.0,
+ length_penalty: 1.0,
min_length: 56,
num_return_sequences: 1,
..Default::default()
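Note: with do_sample now false, the temperature, top_k and top_p fields are effectively inert; decoding reduces to deterministic beam search over 3 beams, and the length_penalty of 1.0 is neutral rather than biasing the search toward longer summaries as 2.0 did.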

View File

@@ -17,9 +17,23 @@ use tch::kind::Kind::Float;
#[derive(Debug)]
pub struct LayerState {
- prev_key: Option<Tensor>,
- prev_value: Option<Tensor>,
- prev_key_padding_mask: Option<Tensor>,
+ pub prev_key: Option<Tensor>,
+ pub prev_value: Option<Tensor>,
+ pub prev_key_padding_mask: Option<Tensor>,
}
+ impl LayerState {
+     pub fn reorder_cache(&mut self, new_indices: &Tensor) {
+         if self.prev_key.is_some() {
+             self.prev_key = Some(self.prev_key.as_ref().unwrap().index_select(0, new_indices));
+         }
+         if self.prev_value.is_some() {
+             self.prev_value = Some(self.prev_value.as_ref().unwrap().index_select(0, new_indices));
+         }
+         if self.prev_key_padding_mask.is_some() {
+             self.prev_key_padding_mask = Some(self.prev_key_padding_mask.as_ref().unwrap().index_select(0, new_indices));
+         }
+     }
+ }
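A note on the new reorder_cache helper: during beam search the rows of each cached tensor must follow the beams that survived the last step, and a surviving beam may be selected several times. A minimal standalone sketch of the index_select semantics (illustrative values; assumes the same tch crate this repo uses):

    use tch::Tensor;

    fn main() {
        // Cached keys for three beams, one row per beam.
        let prev_key = Tensor::of_slice(&[10i64, 20, 30]).view((3, 1));
        // The search kept beam 1 twice and beam 2 once.
        let new_indices = Tensor::of_slice(&[1i64, 1, 2]);
        // Rows become [20, 20, 30]: beam 1's cache is duplicated.
        prev_key.index_select(0, &new_indices).print();
    }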
@@ -112,8 +126,8 @@ impl SelfAttention {
self.prev_state = match &self.prev_state {
Some(_) => Some(LayerState {
- prev_key: Some(k.copy()),
- prev_value: Some(v.copy()),
+ prev_key: Some(k.copy().view((bs, self.num_heads, -1, self.head_dim))),
+ prev_value: Some(v.copy().view((bs, self.num_heads, -1, self.head_dim))),
prev_key_padding_mask: match key_padding_mask.as_ref() {
Some(tensor) => Some(tensor.copy()),
None => None
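The added .view((bs, self.num_heads, -1, self.head_dim)) stores the cache batch-first: the attention projections are computed with the batch and head dimensions merged, and exposing the batch axis as dim 0 is what lets LayerState::reorder_cache index_select over it. A sketch with made-up sizes (not the crate's actual shapes):

    use tch::{Device, Kind, Tensor};

    fn main() {
        let (bs, num_heads, seq_len, head_dim) = (2i64, 4, 5, 8);
        // Projections come out with batch and heads merged...
        let k = Tensor::rand(&[bs * num_heads, seq_len, head_dim], (Kind::Float, Device::Cpu));
        // ...and are cached batch-first so dim 0 can be reordered per beam.
        let cached = k.view((bs, num_heads, -1, head_dim));
        assert_eq!(cached.size(), [bs, num_heads, seq_len, head_dim]);
    }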

View File

@@ -162,6 +162,8 @@ impl BartModel {
BartModel { encoder, decoder, generation_mode, pad_token_id }
}
+ pub fn get_decoder(&mut self) -> &mut BartDecoder { &mut self.decoder }
pub fn forward_t(&mut self,
input_ids: Option<&Tensor>,
attention_mask: Option<&Tensor>,
@@ -236,6 +238,8 @@ impl BartForConditionalGeneration {
all_encoder_hidden_states, all_encoder_attentions)
}
+ pub fn get_base_model(&mut self) -> &mut BartModel { &mut self.base_model }
pub fn encode(&mut self, input_ids: &Tensor, attention_mask: Option<&Tensor>) -> Tensor {
let (encoder_hidden_states, _, _) = self.base_model.encoder.forward_t(input_ids, attention_mask, false);
encoder_hidden_states
@@ -332,6 +336,6 @@ impl LMHeadModel for BartForConditionalGeneration {
train);
let lm_logits = decoder_output.linear::<Tensor>(&self.base_model.encoder.embed_tokens.ws, None);
Ok((lm_logits, Some(encoder_hidden_states), None, None, None))
}
}
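For context, the LM head above is weight-tied: logits are a linear projection with the shared embedding matrix (embed_tokens.ws), with no separate output layer. A standalone sketch with made-up sizes:

    use tch::{Device, Kind, Tensor};

    fn main() {
        let (vocab_size, hidden) = (50i64, 16);
        let embed_ws = Tensor::rand(&[vocab_size, hidden], (Kind::Float, Device::Cpu));
        let decoder_output = Tensor::rand(&[1, 7, hidden], (Kind::Float, Device::Cpu));
        // linear(x, W, None) computes x * W^T, projecting hidden states onto the vocabulary.
        let lm_logits = decoder_output.linear::<Tensor>(&embed_ws, None);
        assert_eq!(lm_logits.size(), [1, 7, vocab_size]);
    }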

View File

@@ -96,6 +96,9 @@ impl DecoderLayer {
}
}
+ pub fn get_self_attention(&mut self) -> &mut SelfAttention { &mut self.self_attention }
+ pub fn get_encoder_attention(&mut self) -> &mut SelfAttention { &mut self.encoder_attention }
pub fn forward_t(&mut self,
x: &Tensor,
encoder_hidden_states: &Tensor,
@@ -186,17 +189,19 @@ impl BartDecoder {
}
}
+ pub fn get_layers(&mut self) -> &mut Vec<DecoderLayer> { &mut self.layers }
pub fn forward_t(&mut self,
input_ids: &Tensor,
encoder_hidden_states: &Tensor,
encoder_padding_mask: Option<&Tensor>,
decoder_padding_mask: Option<&Tensor>,
decoder_causal_mask: Option<&Tensor>,
train: bool)
-> (Tensor,
(Option<Tensor>, Option<Vec<(&LayerState, &LayerState)>>),
Option<Vec<Tensor>>,
Option<Vec<Tensor>>) {
let encoder_padding_mask = match encoder_padding_mask {
Some(mask) => Some(mask.eq(0).to_kind(Int64)),
None => None
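The eq(0) conversion flips the mask convention: the attention mask marks real tokens with 1, while the decoder expects a padding mask that marks padded positions with 1. Standalone sketch:

    use tch::{Kind, Tensor};

    fn main() {
        let attention_mask = Tensor::of_slice(&[1i64, 1, 1, 0, 0]);
        // 1 where the position is padding, 0 where it is a real token.
        let padding_mask = attention_mask.eq(0).to_kind(Kind::Int64);
        padding_mask.print(); // [0, 0, 0, 1, 1]
    }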

View File

@@ -159,6 +159,8 @@ pub struct OpenAIGenerator {
eos_token_ids: Option<Vec<i64>>,
pad_token_id: Option<i64>,
is_encoder_decoder: bool,
+ vocab_size: i64,
+ decoder_start_id: Option<i64>,
}
impl OpenAIGenerator {
@@ -215,8 +217,10 @@ impl OpenAIGenerator {
let eos_token_ids = None;
let pad_token_id = None;
let is_encoder_decoder = false;
+ let vocab_size = config.vocab_size;
+ let decoder_start_id = None;
- Ok(OpenAIGenerator { model, tokenizer, var_store, generate_config, bos_token_id, eos_token_ids, pad_token_id, is_encoder_decoder })
+ Ok(OpenAIGenerator { model, tokenizer, var_store, generate_config, bos_token_id, eos_token_ids, pad_token_id, is_encoder_decoder, vocab_size, decoder_start_id })
}
}
@@ -229,6 +233,8 @@ impl PrivateLanguageGenerator<OpenAIGPTLMHeadModel, OpenAiGptVocab, OpenAiGptTok
fn get_eos_ids(&self) -> &Option<Vec<i64>> { &self.eos_token_ids }
fn get_pad_id(&self) -> &Option<i64> { &self.pad_token_id }
fn is_encoder_decoder(&self) -> bool { self.is_encoder_decoder }
+ fn get_vocab_size(&self) -> i64 { self.vocab_size }
+ fn get_decoder_start_id(&self) -> Option<i64> { self.decoder_start_id }
}
impl LanguageGenerator<OpenAIGPTLMHeadModel, OpenAiGptVocab, OpenAiGptTokenizer> for OpenAIGenerator {}
@@ -243,6 +249,8 @@ pub struct GPT2Generator {
eos_token_ids: Option<Vec<i64>>,
pad_token_id: Option<i64>,
is_encoder_decoder: bool,
+ vocab_size: i64,
+ decoder_start_id: Option<i64>,
}
impl GPT2Generator {
@@ -299,8 +307,10 @@ impl GPT2Generator {
let eos_token_ids = Some(vec!(tokenizer.vocab().token_to_id(Gpt2Vocab::eos_value())));
let pad_token_id = None;
let is_encoder_decoder = false;
+ let vocab_size = config.vocab_size;
+ let decoder_start_id = None;
- Ok(GPT2Generator { model, tokenizer, var_store, generate_config, bos_token_id, eos_token_ids, pad_token_id, is_encoder_decoder })
+ Ok(GPT2Generator { model, tokenizer, var_store, generate_config, bos_token_id, eos_token_ids, pad_token_id, is_encoder_decoder, vocab_size, decoder_start_id })
}
}
@@ -313,6 +323,8 @@ impl PrivateLanguageGenerator<GPT2LMHeadModel, Gpt2Vocab, Gpt2Tokenizer> for GPT
fn get_eos_ids(&self) -> &Option<Vec<i64>> { &self.eos_token_ids }
fn get_pad_id(&self) -> &Option<i64> { &self.pad_token_id }
fn is_encoder_decoder(&self) -> bool { self.is_encoder_decoder }
+ fn get_vocab_size(&self) -> i64 { self.vocab_size }
+ fn get_decoder_start_id(&self) -> Option<i64> { self.decoder_start_id }
fn prepare_inputs_for_generation<'a>(&self,
input_ids: Tensor,
@@ -340,6 +352,8 @@ pub struct BartGenerator {
eos_token_ids: Option<Vec<i64>>,
pad_token_id: Option<i64>,
is_encoder_decoder: bool,
+ vocab_size: i64,
+ decoder_start_id: Option<i64>,
}
impl BartGenerator {
@@ -392,7 +406,7 @@ impl BartGenerator {
let model = BartForConditionalGeneration::new(&var_store.root(), &config, true);
var_store.load(weight_path)?;
- let bos_token_id = Some(2);
+ let bos_token_id = Some(0);
let eos_token_ids = Some(match config.eos_token_id {
Some(value) => vec!(value),
None => vec!(2)
@@ -401,9 +415,19 @@
Some(value) => value,
None => 1
});
+ let vocab_size = config.vocab_size;
let is_encoder_decoder = true;
+ let decoder_start_id = Some(2);
- Ok(BartGenerator { model, tokenizer, var_store, generate_config, bos_token_id, eos_token_ids, pad_token_id, is_encoder_decoder })
+ Ok(BartGenerator { model, tokenizer, var_store, generate_config, bos_token_id, eos_token_ids, pad_token_id, is_encoder_decoder, vocab_size, decoder_start_id })
}
+ fn force_token_id_generation(&self, scores: &mut Tensor, token_ids: &[i64]) {
+     let impossible_tokens: Vec<i64> = (0..self.get_vocab_size() as i64)
+         .filter(|pos| !token_ids.contains(pos))
+         .collect();
+     let impossible_tokens = Tensor::of_slice(&impossible_tokens).to_device(scores.device());
+     let _ = scores.index_fill_(1, &impossible_tokens, std::f64::NEG_INFINITY);
+ }
}
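What force_token_id_generation does, in isolation: every vocabulary index outside the allowed set is filled with -inf, so the next-token choice can only fall on an allowed id. A standalone sketch with a toy vocabulary of 10:

    use tch::{Device, Kind, Tensor};

    fn main() {
        let mut scores = Tensor::rand(&[1, 10], (Kind::Float, Device::Cpu)); // (batch, vocab)
        let allowed: Vec<i64> = vec![2];
        let banned: Vec<i64> = (0..10).filter(|id| !allowed.contains(id)).collect();
        let banned = Tensor::of_slice(&banned).to_device(scores.device());
        let _ = scores.index_fill_(1, &banned, std::f64::NEG_INFINITY);
        // Only id 2 survives, so it is necessarily the argmax.
        assert_eq!(scores.argmax(-1, false).int64_value(&[0]), 2);
    }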
@@ -416,6 +440,17 @@ impl PrivateLanguageGenerator<BartForConditionalGeneration, RobertaVocab, Robert
fn get_eos_ids(&self) -> &Option<Vec<i64>> { &self.eos_token_ids }
fn get_pad_id(&self) -> &Option<i64> { &self.pad_token_id }
fn is_encoder_decoder(&self) -> bool { self.is_encoder_decoder }
+ fn get_vocab_size(&self) -> i64 { self.vocab_size }
+ fn get_decoder_start_id(&self) -> Option<i64> { self.decoder_start_id }
+ fn prepare_scores_for_generation(&self, scores: &mut Tensor, current_length: i64, max_length: i64) {
+     if current_length == 1 {
+         self.force_token_id_generation(scores, &vec!(self.get_bos_id().unwrap()));
+     } else if current_length == max_length - 1 {
+         self.force_token_id_generation(scores, self.get_eos_ids().as_ref().unwrap());
+     }
+ }
fn encode(&mut self, input_ids: &Tensor, attention_mask: Option<&Tensor>) -> Option<Tensor> {
Some(self.get_model().encode(input_ids, attention_mask))
}
@@ -459,6 +494,18 @@ impl PrivateLanguageGenerator<BartForConditionalGeneration, RobertaVocab, Robert
Tensor::stack(&token_ids, 0)
}
+ fn reorder_cache(&mut self, _past: Option<Vec<Tensor>>, encoder_outputs: Option<Tensor>, beam_indices: &Tensor) -> (Option<Vec<Tensor>>, Option<Tensor>) {
+     let encoder_outputs = match encoder_outputs {
+         Some(value) => Some(value.index_select(0, beam_indices)),
+         None => None
+     };
+     for layer in self.get_model().get_base_model().get_decoder().get_layers() {
+         layer.get_self_attention().prev_state.as_mut().unwrap().reorder_cache(beam_indices);
+         layer.get_encoder_attention().prev_state.as_mut().unwrap().reorder_cache(beam_indices);
+     };
+     (None, encoder_outputs)
+ }
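Note that BART's override reorders the decoder caches in place through the layer getters and returns None in the past slot, since its cache lives in each layer's LayerState rather than in the generic past vector; only the beam-reordered encoder outputs travel back through the returned tuple. The as_mut().unwrap() calls assume a prior forward pass has already populated prev_state.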
}
impl LanguageGenerator<BartForConditionalGeneration, RobertaVocab, RobertaTokenizer> for BartGenerator {}
@@ -483,6 +530,11 @@ mod private_generation_utils {
fn get_eos_ids(&self) -> &Option<Vec<i64>>;
fn get_pad_id(&self) -> &Option<i64>;
fn is_encoder_decoder(&self) -> bool;
+ fn get_vocab_size(&self) -> i64;
+ fn get_decoder_start_id(&self) -> Option<i64>;
+ fn prepare_scores_for_generation(&self, _scores: &mut Tensor, _current_length: i64, _max_length: i64) {}
+ fn encode(&mut self, _input_ids: &Tensor, _attention_mask: Option<&Tensor>) -> Option<Tensor> { None }
fn prepare_inputs_for_generation<'a>(&self,
@@ -656,13 +708,11 @@ mod private_generation_utils {
if repetition_penalty > 1f64 {
self.enforce_repetition_penalty(&mut next_token_logits, batch_size, 1, &input_ids, repetition_penalty)
}
// Get banned tokens and set their probability to 0
let banned_tokens = self.get_banned_tokens(&input_ids, no_repeat_ngram_size as i64, current_length as i64);
for (batch_index, index_banned_token) in (0..banned_tokens.len() as i64).zip(banned_tokens) {
&next_token_logits.get(batch_index).index_fill_(0, &Tensor::of_slice(&index_banned_token).to_device(next_token_logits.device()), std::f64::NEG_INFINITY);
}
// Do not allow eos token if min length is not reached
if (&eos_token_ids.is_some()) & (current_length < min_length) {
&next_token_logits.index_fill_(1, &Tensor::of_slice(eos_token_ids.as_ref().unwrap()).to(next_token_logits.device()), std::f64::NEG_INFINITY);
@@ -698,9 +748,10 @@ mod private_generation_utils {
break;
}
}
- attention_mask = Tensor::cat(&[attention_mask.as_ref(), Tensor::ones(&[*attention_mask.size().first().unwrap(), 1],
- (Int64, attention_mask.device())).as_ref()], -1);
+ if !self.is_encoder_decoder() {
+     attention_mask = Tensor::cat(&[attention_mask.as_ref(), Tensor::ones(&[*attention_mask.size().first().unwrap(), 1],
+         (Int64, attention_mask.device())).as_ref()], -1);
+ }
current_length += 1;
}
@@ -734,10 +785,10 @@ mod private_generation_utils {
.map(|_| BeamHypotheses::new(num_beams, max_length, length_penalty, early_stopping))
.collect::<Vec<BeamHypotheses>>();
- let vocab_size = self.get_tokenizer().vocab().values().len() as i64;
+ let vocab_size = self.get_vocab_size();
let beam_scores = Tensor::zeros(&[batch_size, num_beams], (Float, self.get_var_store().device()));
if !do_sample {
- let _ = beam_scores.slice(1, 1, *beam_scores.size().last().unwrap(), 1).fill_(std::f64::NEG_INFINITY);
+ let _ = beam_scores.slice(1, 1, *beam_scores.size().last().unwrap(), 1).fill_(-1e9);
}
let mut beam_scores = beam_scores.view_(&[-1]);
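The switch from NEG_INFINITY to -1e9 for the duplicate starting beams keeps every score finite (arithmetic on -inf can surface NaNs later in the search) while still guaranteeing that the clones lose the first top-k selection. Sketch of the initialisation:

    use tch::{Device, Kind, Tensor};

    fn main() {
        let beam_scores = Tensor::zeros(&[2, 3], (Kind::Float, Device::Cpu));
        // Beam 0 keeps score 0.0; its clones start far behind, but finite.
        let _ = beam_scores.slice(1, 1, 3, 1).fill_(-1e9);
        beam_scores.print(); // [[0, -1e9, -1e9], [0, -1e9, -1e9]]
    }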
@@ -749,6 +800,7 @@ mod private_generation_utils {
let mut attention_mask = attention_mask.copy();
let mut input_ids = input_ids.copy();
let mut outputs: Tensor;
+ let mut encoder_outputs = encoder_outputs;
let mut current_length = cur_len;
while current_length < max_length {
@@ -769,6 +821,7 @@ mod private_generation_utils {
&prepared_decoder_input,
false).unwrap();
outputs = temp.0;
+ encoder_outputs = temp.1;
past = temp.2;
let mut next_token_logits = outputs.select(1, -1);
@@ -781,6 +834,9 @@ mod private_generation_utils {
next_token_logits = next_token_logits / temperature;
}
let mut scores = next_token_logits.log_softmax(-1, Float);
+ if self.is_encoder_decoder() & !do_sample {
+     self.prepare_scores_for_generation(&mut scores, current_length, max_length);
+ }
// Do not allow eos token if min length is not reached
if (&eos_token_ids.is_some()) & (current_length < min_length) {
&scores.index_fill_(1, &Tensor::of_slice(eos_token_ids.as_ref().unwrap()).to(scores.device()), std::f64::NEG_INFINITY);
@@ -870,13 +926,13 @@ mod private_generation_utils {
input_ids = input_ids.index_select(0, &beam_indices);
input_ids = Tensor::cat(&[input_ids, beam_tokens.unsqueeze(1)], -1);
- past = match past {
-     Some(past_values) => Some(self.reorder_cache(past_values, &beam_indices)),
-     None => None
- };
- attention_mask = Tensor::cat(&[attention_mask.as_ref(), Tensor::ones(&[*attention_mask.size().first().unwrap(), 1],
- (Int64, attention_mask.device())).as_ref()], -1);
+ let temp_past = self.reorder_cache(past, encoder_outputs, &beam_indices);
+ past = temp_past.0;
+ encoder_outputs = temp_past.1;
+ if !self.is_encoder_decoder() {
+     attention_mask = Tensor::cat(&[attention_mask.as_ref(), Tensor::ones(&[*attention_mask.size().first().unwrap(), 1],
+         (Int64, attention_mask.device())).as_ref()], -1);
+ }
current_length += 1;
}
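The guard mirrors the greedy-search change above: decoder-only models grow their attention mask by one column of ones per generated token, while encoder-decoder models keep attending over the fixed source sequence, so the concatenation is skipped for them. Sketch of the mask growth (illustrative sizes):

    use tch::{Device, Kind, Tensor};

    fn main() {
        let attention_mask = Tensor::ones(&[2, 4], (Kind::Int64, Device::Cpu));
        let step = Tensor::ones(&[2, 1], (Kind::Int64, Device::Cpu));
        // One extra column per decoding step for decoder-only models.
        let attention_mask = Tensor::cat(&[attention_mask, step], -1);
        assert_eq!(attention_mask.size(), [2, 5]);
    }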
@@ -946,12 +1002,17 @@ mod private_generation_utils {
decoded
}
- fn reorder_cache(&self, past: Vec<Tensor>, beam_indices: &Tensor) -> Vec<Tensor> {
-     let mut reordered_past = vec!();
-     for layer_past in past.iter() {
-         reordered_past.push(layer_past.index_select(1, beam_indices));
-     }
-     reordered_past
- }
+ fn reorder_cache(&mut self, past: Option<Vec<Tensor>>, _encoder_outputs: Option<Tensor>, beam_indices: &Tensor) -> (Option<Vec<Tensor>>, Option<Tensor>) {
+     match past {
+         Some(value) => {
+             let mut reordered_past = vec!();
+             for layer_past in value.iter() {
+                 reordered_past.push(layer_past.index_select(1, beam_indices));
+             }
+             (Some(reordered_past), None)
+         }
+         None => (None, None)
+     }
+ }
}
}
@@ -1100,7 +1161,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>: PrivateL
(input_ids, attention_mask)
}
} else {
- let decoder_start_token_id = self.get_bos_id().expect("BOS token id must be specified for encoder decoders");
+ let decoder_start_token_id = self.get_decoder_start_id().expect("decoder start id must be specified for encoder decoders");
let input_ids = Tensor::full(&[effective_batch_size * num_beams as i64, 1], decoder_start_token_id, (Int64, input_ids.device()));
(input_ids, attention_mask)
};
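The decoder is now seeded with a dedicated decoder-start token instead of BOS; for BART the two differ (decoder_start_id = Some(2) versus bos_token_id = Some(0)). Sketch of building the initial decoder input for a batch of four sequences:

    use tch::{Device, Kind, Tensor};

    fn main() {
        let decoder_start_token_id = 2i64;
        // One start token per sequence; generation extends these column by column.
        let input_ids = Tensor::full(&[4, 1], decoder_start_token_id, (Kind::Int64, Device::Cpu));
        assert_eq!(input_ids.size(), [4, 1]);
    }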