mirror of
https://github.com/guillaume-be/rust-bert.git
synced 2024-10-26 14:07:25 +03:00
End-to-end summarization
This commit is contained in:
parent
fa87fce96e
commit
0a8832a012
@ -40,12 +40,12 @@ fn main() -> failure::Fallible<()> {
|
||||
let device = Device::cuda_if_available();
|
||||
let generate_config = GenerateConfig {
|
||||
max_length: 142,
|
||||
do_sample: true,
|
||||
do_sample: false,
|
||||
num_beams: 3,
|
||||
temperature: 1.0,
|
||||
top_k: 50,
|
||||
top_p: 1.0,
|
||||
length_penalty: 2.0,
|
||||
length_penalty: 1.0,
|
||||
min_length: 56,
|
||||
num_return_sequences: 1,
|
||||
..Default::default()
|
||||
|
@ -17,9 +17,23 @@ use tch::kind::Kind::Float;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct LayerState {
|
||||
prev_key: Option<Tensor>,
|
||||
prev_value: Option<Tensor>,
|
||||
prev_key_padding_mask: Option<Tensor>,
|
||||
pub prev_key: Option<Tensor>,
|
||||
pub prev_value: Option<Tensor>,
|
||||
pub prev_key_padding_mask: Option<Tensor>,
|
||||
}
|
||||
|
||||
impl LayerState {
|
||||
pub fn reorder_cache(&mut self, new_indices: &Tensor) {
|
||||
if self.prev_key.is_some() {
|
||||
self.prev_key = Some(self.prev_key.as_ref().unwrap().index_select(0, new_indices));
|
||||
}
|
||||
if self.prev_value.is_some() {
|
||||
self.prev_value = Some(self.prev_value.as_ref().unwrap().index_select(0, new_indices));
|
||||
}
|
||||
if self.prev_key_padding_mask.is_some() {
|
||||
self.prev_key_padding_mask = Some(self.prev_key_padding_mask.as_ref().unwrap().index_select(0, new_indices));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -112,8 +126,8 @@ impl SelfAttention {
|
||||
|
||||
self.prev_state = match &self.prev_state {
|
||||
Some(_) => Some(LayerState {
|
||||
prev_key: Some(k.copy()),
|
||||
prev_value: Some(v.copy()),
|
||||
prev_key: Some(k.copy().view((bs, self.num_heads, -1, self.head_dim))),
|
||||
prev_value: Some(v.copy().view((bs, self.num_heads, -1, self.head_dim))),
|
||||
prev_key_padding_mask: match key_padding_mask.as_ref() {
|
||||
Some(tensor) => Some(tensor.copy()),
|
||||
None => None
|
||||
|
@ -162,6 +162,8 @@ impl BartModel {
|
||||
BartModel { encoder, decoder, generation_mode, pad_token_id }
|
||||
}
|
||||
|
||||
/// Returns a mutable reference to the `decoder` field of this model.
pub fn get_decoder(&mut self) -> &mut BartDecoder {
    &mut self.decoder
}
|
||||
|
||||
pub fn forward_t(&mut self,
|
||||
input_ids: Option<&Tensor>,
|
||||
attention_mask: Option<&Tensor>,
|
||||
@ -236,6 +238,8 @@ impl BartForConditionalGeneration {
|
||||
all_encoder_hidden_states, all_encoder_attentions)
|
||||
}
|
||||
|
||||
/// Returns a mutable reference to the wrapped `base_model`.
pub fn get_base_model(&mut self) -> &mut BartModel {
    &mut self.base_model
}
|
||||
|
||||
pub fn encode(&mut self, input_ids: &Tensor, attention_mask: Option<&Tensor>) -> Tensor {
|
||||
let (encoder_hidden_states, _, _) = self.base_model.encoder.forward_t(input_ids, attention_mask, false);
|
||||
encoder_hidden_states
|
||||
@ -332,6 +336,6 @@ impl LMHeadModel for BartForConditionalGeneration {
|
||||
train);
|
||||
|
||||
let lm_logits = decoder_output.linear::<Tensor>(&self.base_model.encoder.embed_tokens.ws, None);
|
||||
Ok((lm_logits, Some(encoder_hidden_states), None, None, None))
|
||||
Ok((lm_logits, Some(encoder_hidden_states), None, None, None))
|
||||
}
|
||||
}
|
@ -96,6 +96,9 @@ impl DecoderLayer {
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a mutable reference to this layer's `self_attention` module.
pub fn get_self_attention(&mut self) -> &mut SelfAttention {
    &mut self.self_attention
}
|
||||
/// Returns a mutable reference to this layer's `encoder_attention` module.
pub fn get_encoder_attention(&mut self) -> &mut SelfAttention {
    &mut self.encoder_attention
}
|
||||
|
||||
pub fn forward_t(&mut self,
|
||||
x: &Tensor,
|
||||
encoder_hidden_states: &Tensor,
|
||||
@ -186,17 +189,19 @@ impl BartDecoder {
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a mutable reference to the vector of decoder layers.
pub fn get_layers(&mut self) -> &mut Vec<DecoderLayer> {
    &mut self.layers
}
|
||||
|
||||
pub fn forward_t(&mut self,
|
||||
input_ids: &Tensor,
|
||||
encoder_hidden_states: &Tensor,
|
||||
encoder_padding_mask: Option<&Tensor>,
|
||||
decoder_padding_mask: Option<&Tensor>,
|
||||
decoder_causal_mask: Option<&Tensor>,
|
||||
train: bool)
|
||||
-> (Tensor,
|
||||
(Option<Tensor>, Option<Vec<(&LayerState, &LayerState)>>),
|
||||
Option<Vec<Tensor>>,
|
||||
Option<Vec<Tensor>>) {
|
||||
input_ids: &Tensor,
|
||||
encoder_hidden_states: &Tensor,
|
||||
encoder_padding_mask: Option<&Tensor>,
|
||||
decoder_padding_mask: Option<&Tensor>,
|
||||
decoder_causal_mask: Option<&Tensor>,
|
||||
train: bool)
|
||||
-> (Tensor,
|
||||
(Option<Tensor>, Option<Vec<(&LayerState, &LayerState)>>),
|
||||
Option<Vec<Tensor>>,
|
||||
Option<Vec<Tensor>>) {
|
||||
let encoder_padding_mask = match encoder_padding_mask {
|
||||
Some(mask) => Some(mask.eq(0).to_kind(Int64)),
|
||||
None => None
|
||||
|
@ -159,6 +159,8 @@ pub struct OpenAIGenerator {
|
||||
eos_token_ids: Option<Vec<i64>>,
|
||||
pad_token_id: Option<i64>,
|
||||
is_encoder_decoder: bool,
|
||||
vocab_size: i64,
|
||||
decoder_start_id: Option<i64>,
|
||||
}
|
||||
|
||||
impl OpenAIGenerator {
|
||||
@ -215,8 +217,10 @@ impl OpenAIGenerator {
|
||||
let eos_token_ids = None;
|
||||
let pad_token_id = None;
|
||||
let is_encoder_decoder = false;
|
||||
let vocab_size = config.vocab_size;
|
||||
let decoder_start_id = None;
|
||||
|
||||
Ok(OpenAIGenerator { model, tokenizer, var_store, generate_config, bos_token_id, eos_token_ids, pad_token_id, is_encoder_decoder })
|
||||
Ok(OpenAIGenerator { model, tokenizer, var_store, generate_config, bos_token_id, eos_token_ids, pad_token_id, is_encoder_decoder, vocab_size, decoder_start_id })
|
||||
}
|
||||
}
|
||||
|
||||
@ -229,6 +233,8 @@ impl PrivateLanguageGenerator<OpenAIGPTLMHeadModel, OpenAiGptVocab, OpenAiGptTok
|
||||
fn get_eos_ids(&self) -> &Option<Vec<i64>> { &self.eos_token_ids }
|
||||
fn get_pad_id(&self) -> &Option<i64> { &self.pad_token_id }
|
||||
fn is_encoder_decoder(&self) -> bool { self.is_encoder_decoder }
|
||||
fn get_vocab_size(&self) -> i64 { self.vocab_size }
|
||||
fn get_decoder_start_id(&self) -> Option<i64> { self.decoder_start_id }
|
||||
}
|
||||
|
||||
impl LanguageGenerator<OpenAIGPTLMHeadModel, OpenAiGptVocab, OpenAiGptTokenizer> for OpenAIGenerator {}
|
||||
@ -243,6 +249,8 @@ pub struct GPT2Generator {
|
||||
eos_token_ids: Option<Vec<i64>>,
|
||||
pad_token_id: Option<i64>,
|
||||
is_encoder_decoder: bool,
|
||||
vocab_size: i64,
|
||||
decoder_start_id: Option<i64>,
|
||||
}
|
||||
|
||||
impl GPT2Generator {
|
||||
@ -299,8 +307,10 @@ impl GPT2Generator {
|
||||
let eos_token_ids = Some(vec!(tokenizer.vocab().token_to_id(Gpt2Vocab::eos_value())));
|
||||
let pad_token_id = None;
|
||||
let is_encoder_decoder = false;
|
||||
let vocab_size = config.vocab_size;
|
||||
let decoder_start_id = None;
|
||||
|
||||
Ok(GPT2Generator { model, tokenizer, var_store, generate_config, bos_token_id, eos_token_ids, pad_token_id, is_encoder_decoder })
|
||||
Ok(GPT2Generator { model, tokenizer, var_store, generate_config, bos_token_id, eos_token_ids, pad_token_id, is_encoder_decoder, vocab_size, decoder_start_id })
|
||||
}
|
||||
}
|
||||
|
||||
@ -313,6 +323,8 @@ impl PrivateLanguageGenerator<GPT2LMHeadModel, Gpt2Vocab, Gpt2Tokenizer> for GPT
|
||||
fn get_eos_ids(&self) -> &Option<Vec<i64>> { &self.eos_token_ids }
|
||||
fn get_pad_id(&self) -> &Option<i64> { &self.pad_token_id }
|
||||
fn is_encoder_decoder(&self) -> bool { self.is_encoder_decoder }
|
||||
fn get_vocab_size(&self) -> i64 { self.vocab_size }
|
||||
fn get_decoder_start_id(&self) -> Option<i64> { self.decoder_start_id }
|
||||
|
||||
fn prepare_inputs_for_generation<'a>(&self,
|
||||
input_ids: Tensor,
|
||||
@ -340,6 +352,8 @@ pub struct BartGenerator {
|
||||
eos_token_ids: Option<Vec<i64>>,
|
||||
pad_token_id: Option<i64>,
|
||||
is_encoder_decoder: bool,
|
||||
vocab_size: i64,
|
||||
decoder_start_id: Option<i64>,
|
||||
}
|
||||
|
||||
impl BartGenerator {
|
||||
@ -392,7 +406,7 @@ impl BartGenerator {
|
||||
let model = BartForConditionalGeneration::new(&var_store.root(), &config, true);
|
||||
var_store.load(weight_path)?;
|
||||
|
||||
let bos_token_id = Some(2);
|
||||
let bos_token_id = Some(0);
|
||||
let eos_token_ids = Some(match config.eos_token_id {
|
||||
Some(value) => vec!(value),
|
||||
None => vec!(2)
|
||||
@ -401,9 +415,19 @@ impl BartGenerator {
|
||||
Some(value) => value,
|
||||
None => 1
|
||||
});
|
||||
let vocab_size = config.vocab_size;
|
||||
let is_encoder_decoder = true;
|
||||
let decoder_start_id = Some(2);
|
||||
|
||||
Ok(BartGenerator { model, tokenizer, var_store, generate_config, bos_token_id, eos_token_ids, pad_token_id, is_encoder_decoder })
|
||||
Ok(BartGenerator { model, tokenizer, var_store, generate_config, bos_token_id, eos_token_ids, pad_token_id, is_encoder_decoder, vocab_size, decoder_start_id })
|
||||
}
|
||||
|
||||
fn force_token_id_generation(&self, scores: &mut Tensor, token_ids: &[i64]) {
|
||||
let impossible_tokens: Vec<i64> = (0..self.get_vocab_size() as i64)
|
||||
.filter(|pos| !token_ids.contains(pos))
|
||||
.collect();
|
||||
let impossible_tokens = Tensor::of_slice(&impossible_tokens).to_device(scores.device());
|
||||
let _ = scores.index_fill_(1, &impossible_tokens, std::f64::NEG_INFINITY);
|
||||
}
|
||||
}
|
||||
|
||||
@ -416,6 +440,17 @@ impl PrivateLanguageGenerator<BartForConditionalGeneration, RobertaVocab, Robert
|
||||
fn get_eos_ids(&self) -> &Option<Vec<i64>> { &self.eos_token_ids }
|
||||
fn get_pad_id(&self) -> &Option<i64> { &self.pad_token_id }
|
||||
fn is_encoder_decoder(&self) -> bool { self.is_encoder_decoder }
|
||||
fn get_vocab_size(&self) -> i64 { self.vocab_size }
|
||||
fn get_decoder_start_id(&self) -> Option<i64> { self.decoder_start_id }
|
||||
|
||||
/// Constrains the token scores at the two boundary steps of generation.
///
/// When `current_length` is 1 only the BOS id stays selectable; otherwise, when it
/// equals `max_length - 1`, only the EOS ids do. Any other step leaves `scores`
/// untouched. Panics if the BOS / EOS id needed for the step is not set.
fn prepare_scores_for_generation(&self, scores: &mut Tensor, current_length: i64, max_length: i64) {
    if current_length == 1 {
        let bos_id = self.get_bos_id().unwrap();
        self.force_token_id_generation(scores, &[bos_id]);
        return;
    }
    if current_length == max_length - 1 {
        let eos_ids = self.get_eos_ids().as_ref().unwrap();
        self.force_token_id_generation(scores, eos_ids);
    }
}
|
||||
|
||||
/// Calls the model's `encode` on the inputs and wraps the resulting tensor in `Some`.
fn encode(&mut self, input_ids: &Tensor, attention_mask: Option<&Tensor>) -> Option<Tensor> {
    let encoder_output = self.get_model().encode(input_ids, attention_mask);
    Some(encoder_output)
}
|
||||
@ -459,6 +494,18 @@ impl PrivateLanguageGenerator<BartForConditionalGeneration, RobertaVocab, Robert
|
||||
|
||||
Tensor::stack(&token_ids, 0)
|
||||
}
|
||||
|
||||
/// Re-gathers the encoder outputs and each decoder layer's cached attention state
/// with `beam_indices`.
///
/// `encoder_outputs`, when present, is index-selected along dimension 0. The
/// `prev_state` of every decoder layer's self-attention and encoder-attention is
/// reordered in place and must be populated (this panics otherwise). `_past` is
/// ignored, and the first element of the returned pair is always `None`.
fn reorder_cache(&mut self, _past: Option<Vec<Tensor>>, encoder_outputs: Option<Tensor>, beam_indices: &Tensor) -> (Option<Vec<Tensor>>, Option<Tensor>) {
    let reordered_encoder_outputs = encoder_outputs.map(|hidden| hidden.index_select(0, beam_indices));
    let decoder_layers = self.get_model().get_base_model().get_decoder().get_layers();
    for decoder_layer in decoder_layers.iter_mut() {
        let self_attention_state = decoder_layer.get_self_attention().prev_state.as_mut().unwrap();
        self_attention_state.reorder_cache(beam_indices);
        let encoder_attention_state = decoder_layer.get_encoder_attention().prev_state.as_mut().unwrap();
        encoder_attention_state.reorder_cache(beam_indices);
    }
    (None, reordered_encoder_outputs)
}
|
||||
}
|
||||
|
||||
impl LanguageGenerator<BartForConditionalGeneration, RobertaVocab, RobertaTokenizer> for BartGenerator {}
|
||||
@ -483,6 +530,11 @@ mod private_generation_utils {
|
||||
fn get_eos_ids(&self) -> &Option<Vec<i64>>;
|
||||
fn get_pad_id(&self) -> &Option<i64>;
|
||||
fn is_encoder_decoder(&self) -> bool;
|
||||
fn get_vocab_size(&self) -> i64;
|
||||
fn get_decoder_start_id(&self) -> Option<i64>;
|
||||
|
||||
fn prepare_scores_for_generation(&self, _scores: &mut Tensor, _current_length: i64, _max_length: i64) {}
|
||||
|
||||
fn encode(&mut self, _input_ids: &Tensor, _attention_mask: Option<&Tensor>) -> Option<Tensor> { None }
|
||||
|
||||
fn prepare_inputs_for_generation<'a>(&self,
|
||||
@ -656,13 +708,11 @@ mod private_generation_utils {
|
||||
if repetition_penalty > 1f64 {
|
||||
self.enforce_repetition_penalty(&mut next_token_logits, batch_size, 1, &input_ids, repetition_penalty)
|
||||
}
|
||||
|
||||
// Get banned tokens and set their probability to 0
|
||||
let banned_tokens = self.get_banned_tokens(&input_ids, no_repeat_ngram_size as i64, current_length as i64);
|
||||
for (batch_index, index_banned_token) in (0..banned_tokens.len() as i64).zip(banned_tokens) {
|
||||
&next_token_logits.get(batch_index).index_fill_(0, &Tensor::of_slice(&index_banned_token).to_device(next_token_logits.device()), std::f64::NEG_INFINITY);
|
||||
}
|
||||
|
||||
// Do not allow eos token if min length is not reached
|
||||
if (&eos_token_ids.is_some()) & (current_length < min_length) {
|
||||
&next_token_logits.index_fill_(1, &Tensor::of_slice(eos_token_ids.as_ref().unwrap()).to(next_token_logits.device()), std::f64::NEG_INFINITY);
|
||||
@ -698,9 +748,10 @@ mod private_generation_utils {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
attention_mask = Tensor::cat(&[attention_mask.as_ref(), Tensor::ones(&[*attention_mask.size().first().unwrap(), 1],
|
||||
(Int64, attention_mask.device())).as_ref()], -1);
|
||||
if !self.is_encoder_decoder() {
|
||||
attention_mask = Tensor::cat(&[attention_mask.as_ref(), Tensor::ones(&[*attention_mask.size().first().unwrap(), 1],
|
||||
(Int64, attention_mask.device())).as_ref()], -1);
|
||||
}
|
||||
current_length += 1;
|
||||
}
|
||||
|
||||
@ -734,10 +785,10 @@ mod private_generation_utils {
|
||||
.map(|_| BeamHypotheses::new(num_beams, max_length, length_penalty, early_stopping))
|
||||
.collect::<Vec<BeamHypotheses>>();
|
||||
|
||||
let vocab_size = self.get_tokenizer().vocab().values().len() as i64;
|
||||
let vocab_size = self.get_vocab_size();
|
||||
let beam_scores = Tensor::zeros(&[batch_size, num_beams], (Float, self.get_var_store().device()));
|
||||
if !do_sample {
|
||||
let _ = beam_scores.slice(1, 1, *beam_scores.size().last().unwrap(), 1).fill_(std::f64::NEG_INFINITY);
|
||||
let _ = beam_scores.slice(1, 1, *beam_scores.size().last().unwrap(), 1).fill_(-1e9);
|
||||
}
|
||||
|
||||
let mut beam_scores = beam_scores.view_(&[-1]);
|
||||
@ -749,6 +800,7 @@ mod private_generation_utils {
|
||||
let mut attention_mask = attention_mask.copy();
|
||||
let mut input_ids = input_ids.copy();
|
||||
let mut outputs: Tensor;
|
||||
let mut encoder_outputs = encoder_outputs;
|
||||
let mut current_length = cur_len;
|
||||
|
||||
while current_length < max_length {
|
||||
@ -769,6 +821,7 @@ mod private_generation_utils {
|
||||
&prepared_decoder_input,
|
||||
false).unwrap();
|
||||
outputs = temp.0;
|
||||
encoder_outputs = temp.1;
|
||||
past = temp.2;
|
||||
let mut next_token_logits = outputs.select(1, -1);
|
||||
|
||||
@ -781,6 +834,9 @@ mod private_generation_utils {
|
||||
next_token_logits = next_token_logits / temperature;
|
||||
}
|
||||
let mut scores = next_token_logits.log_softmax(-1, Float);
|
||||
if self.is_encoder_decoder() & !do_sample {
|
||||
self.prepare_scores_for_generation(&mut scores, current_length, max_length);
|
||||
}
|
||||
// Do not allow eos token if min length is not reached
|
||||
if (&eos_token_ids.is_some()) & (current_length < min_length) {
|
||||
&scores.index_fill_(1, &Tensor::of_slice(eos_token_ids.as_ref().unwrap()).to(scores.device()), std::f64::NEG_INFINITY);
|
||||
@ -870,13 +926,13 @@ mod private_generation_utils {
|
||||
|
||||
input_ids = input_ids.index_select(0, &beam_indices);
|
||||
input_ids = Tensor::cat(&[input_ids, beam_tokens.unsqueeze(1)], -1);
|
||||
past = match past {
|
||||
Some(past_values) => Some(self.reorder_cache(past_values, &beam_indices)),
|
||||
None => None
|
||||
};
|
||||
|
||||
attention_mask = Tensor::cat(&[attention_mask.as_ref(), Tensor::ones(&[*attention_mask.size().first().unwrap(), 1],
|
||||
(Int64, attention_mask.device())).as_ref()], -1);
|
||||
let temp_past = self.reorder_cache(past, encoder_outputs, &beam_indices);
|
||||
past = temp_past.0;
|
||||
encoder_outputs = temp_past.1;
|
||||
if !self.is_encoder_decoder() {
|
||||
attention_mask = Tensor::cat(&[attention_mask.as_ref(), Tensor::ones(&[*attention_mask.size().first().unwrap(), 1],
|
||||
(Int64, attention_mask.device())).as_ref()], -1);
|
||||
}
|
||||
current_length += 1;
|
||||
}
|
||||
|
||||
@ -946,12 +1002,17 @@ mod private_generation_utils {
|
||||
decoded
|
||||
}
|
||||
|
||||
fn reorder_cache(&self, past: Vec<Tensor>, beam_indices: &Tensor) -> Vec<Tensor> {
|
||||
let mut reordered_past = vec!();
|
||||
for layer_past in past.iter() {
|
||||
reordered_past.push(layer_past.index_select(1, beam_indices));
|
||||
fn reorder_cache(&mut self, past: Option<Vec<Tensor>>, _encoder_outputs: Option<Tensor>, beam_indices: &Tensor) -> (Option<Vec<Tensor>>, Option<Tensor>) {
|
||||
match past {
|
||||
Some(value) => {
|
||||
let mut reordered_past = vec!();
|
||||
for layer_past in value.iter() {
|
||||
reordered_past.push(layer_past.index_select(1, beam_indices));
|
||||
}
|
||||
(Some(reordered_past), None)
|
||||
}
|
||||
None => (None, None)
|
||||
}
|
||||
reordered_past
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1100,7 +1161,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>: PrivateL
|
||||
(input_ids, attention_mask)
|
||||
}
|
||||
} else {
|
||||
let decoder_start_token_id = self.get_bos_id().expect("BOS token id must be specified for encoder decoders");
|
||||
let decoder_start_token_id = self.get_decoder_start_id().expect("decoder start id must be specified for encoder decoders");
|
||||
let input_ids = Tensor::full(&[effective_batch_size * num_beams as i64, 1], decoder_start_token_id, (Int64, input_ids.device()));
|
||||
(input_ids, attention_mask)
|
||||
};
|
||||
|
Loading…
Reference in New Issue
Block a user