Addition of forced_bos_token_id argument for generation

Guillaume B 2021-06-06 12:03:31 +02:00
parent 1073fd06c0
commit 698e7143e8
9 changed files with 60 additions and 27 deletions
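The new argument slots into the public `generate`, `generate_indices`, and `generate_from_ids_and_past` methods between `decoder_start_token_id` and `prefix_allowed_tokens_fn`, as the hunks below show. A minimal caller-side sketch (the generator setup is assumed, and 250004 is assumed to be the mBART-50 `en_XX` language code):

// Hypothetical usage, assuming `model` is a configured MBartGenerator.
let output = model.generate(
    Some(&["UN Chief says there is no military solution in Syria"]),
    None,         // attention_mask
    None,         // min_length
    None,         // max_length
    None,         // decoder_start_token_id
    Some(250004), // forced_bos_token_id (assumed en_XX)
    None,         // prefix_allowed_tokens_fn
);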

View File

@@ -1148,9 +1148,13 @@ impl PrivateLanguageGenerator<BartForConditionalGeneration, RobertaVocab, Robert
         scores: &mut Tensor,
         current_length: i64,
         max_length: i64,
+        forced_bos_token_id: Option<i64>,
     ) {
         if current_length == 1 {
-            self.force_token_id_generation(scores, &[self.get_bos_id().unwrap()]);
+            self.force_token_id_generation(
+                scores,
+                &[forced_bos_token_id.unwrap_or_else(|| self.get_bos_id().unwrap())],
+            );
         } else if current_length == max_length - 1 {
             self.force_token_id_generation(scores, self.get_eos_ids().as_ref().unwrap());
         }
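Forcing a token works by masking the logits of every other vocabulary entry to negative infinity, so greedy and beam search can only select the allowed ids. A sketch of that mechanism with tch (the function name and shape handling are my assumptions, not the crate's exact `force_token_id_generation`):

use tch::Tensor;

// Set all logits except `allowed_ids` to -inf along the vocabulary axis (dim 1).
fn force_token_ids(scores: &mut Tensor, allowed_ids: &[i64]) {
    let banned: Vec<i64> = (0..scores.size()[1])
        .filter(|id| !allowed_ids.contains(id))
        .collect();
    let _ = scores.index_fill_(
        1,
        &Tensor::of_slice(&banned).to_device(scores.device()),
        f64::NEG_INFINITY,
    );
}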

View File

@@ -895,6 +895,7 @@ impl PrivateLanguageGenerator<MarianForConditionalGeneration, MarianVocab, Maria
         scores: &mut Tensor,
         current_length: i64,
         max_length: i64,
+        _forced_bos_token_id: Option<i64>,
     ) {
         let _ = scores.index_fill_(
             1,

View File

@@ -781,6 +781,7 @@ pub struct MBartGenerator {
     is_encoder_decoder: bool,
     vocab_size: i64,
     decoder_start_id: Option<i64>,
+    max_position_embeddings: i64,
 }
 
 impl MBartGenerator {
@@ -862,12 +863,12 @@ impl MBartGenerator {
         generate_config.validate();
         let mut var_store = nn::VarStore::new(device);
         let tokenizer = TokenizerOption::from_file(
-            ModelType::Bart,
+            ModelType::MBart,
             vocab_path.to_str().unwrap(),
             None,
             false,
             None,
             false,
             None,
         )?;
         let config = MBartConfig::from_file(config_path);
         let model = MBartForConditionalGeneration::new(&var_store.root(), &config);
@@ -882,6 +883,7 @@ impl MBartGenerator {
         let vocab_size = config.vocab_size;
         let is_encoder_decoder = true;
         let decoder_start_id = Some(2);
+        let max_position_embeddings = config.max_position_embeddings;
 
         Ok(MBartGenerator {
             model,
@@ -894,6 +896,7 @@ impl MBartGenerator {
             is_encoder_decoder,
             vocab_size,
             decoder_start_id,
+            max_position_embeddings,
         })
     }
 
@@ -940,14 +943,19 @@ impl PrivateLanguageGenerator<MBartForConditionalGeneration, MBart50Vocab, MBart
         self.decoder_start_id
     }
 
+    fn get_max_positions_embeddings(&self) -> i64 {
+        self.max_position_embeddings
+    }
+
     fn prepare_scores_for_generation(
         &self,
         scores: &mut Tensor,
         current_length: i64,
         max_length: i64,
+        forced_bos_token_id: Option<i64>,
     ) {
         if current_length == 1 {
-            self.force_token_id_generation(scores, &[self.get_bos_id().unwrap()]);
+            self.force_token_id_generation(scores, &[forced_bos_token_id.unwrap_or(250004)]);
         } else if current_length == max_length - 1 {
             self.force_token_id_generation(scores, self.get_eos_ids().as_ref().unwrap());
         }
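The MBart fallback of 250004 replaces the previous `get_bos_id()` lookup; to the best of my knowledge this id is the mBART-50 `en_XX` language code, keeping English as the default target. A caller wanting another language passes its code explicitly (the `fr_XX` id below is an assumption):

// Hypothetical: force French output from a configured MBartGenerator `model`.
let output = model.generate(
    Some(&["The quick brown fox jumps over the lazy dog"]),
    None, None, None, None,
    Some(250008), // forced_bos_token_id, assumed to be mBART-50 fr_XX
    None,
);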

View File

@@ -727,6 +727,7 @@ impl PrivateLanguageGenerator<PegasusForConditionalGeneration, PegasusVocab, Peg
         scores: &mut Tensor,
         current_length: i64,
         max_length: i64,
+        _forced_bos_token_id: Option<i64>,
     ) {
         if current_length == max_length - 1 {
             self.force_token_id_generation(scores, self.get_eos_ids().as_ref().unwrap());

View File

@@ -716,9 +716,15 @@ impl ConversationOption {
         attention_mask: Option<Tensor>,
     ) -> Vec<Vec<i64>> {
         match *self {
-            Self::GPT2(ref model) => {
-                model.generate_from_ids_and_past(input_ids, attention_mask, None, None, None, None)
-            }
+            Self::GPT2(ref model) => model.generate_from_ids_and_past(
+                input_ids,
+                attention_mask,
+                None,
+                None,
+                None,
+                None,
+                None,
+            ),
         }
     }
 }

View File

@@ -258,6 +258,7 @@ pub(crate) mod private_generation_utils {
        pub length_penalty: f64,
        pub num_beam_groups: Option<i64>,
        pub diversity_penalty: Option<f64>,
+        pub forced_bos_token_id: Option<i64>,
    }
 
    pub struct PreparedInput<'a> {
@@ -287,6 +288,7 @@ pub(crate) mod private_generation_utils {
            _scores: &mut Tensor,
            _current_length: i64,
            _max_length: i64,
+            _forced_bos_token_id: Option<i64>,
        ) {
        }
 
@@ -665,13 +667,13 @@ pub(crate) mod private_generation_utils {
                        f64::NEG_INFINITY,
                    );
                }
-                if self.is_encoder_decoder() & !gen_opt.do_sample {
-                    self.prepare_scores_for_generation(
-                        &mut next_token_logits,
-                        current_length,
-                        gen_opt.max_length,
-                    );
-                }
+                self.prepare_scores_for_generation(
+                    &mut next_token_logits,
+                    current_length,
+                    gen_opt.max_length,
+                    gen_opt.forced_bos_token_id,
+                );
 
                // Top-k and top-p sampling
                let next_token = if gen_opt.do_sample {
@@ -861,13 +863,13 @@ pub(crate) mod private_generation_utils {
                if gen_opt.temperature > 1f64 {
                    next_token_logits /= gen_opt.temperature;
                }
-                if self.is_encoder_decoder() & !gen_opt.do_sample {
-                    self.prepare_scores_for_generation(
-                        &mut next_token_logits,
-                        current_length,
-                        gen_opt.max_length,
-                    );
-                }
+                self.prepare_scores_for_generation(
+                    &mut next_token_logits,
+                    current_length,
+                    gen_opt.max_length,
+                    gen_opt.forced_bos_token_id,
+                );
                let mut scores = next_token_logits.log_softmax(-1, Float);
                // Do not allow eos token if min length is not reached
                if (gen_opt.eos_token_ids.is_some()) & (current_length < gen_opt.min_length) {
@@ -1274,6 +1276,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
        min_length: impl Into<Option<i64>>,
        max_length: impl Into<Option<i64>>,
        decoder_start_token_id: impl Into<Option<i64>>,
+        forced_bos_token_id: impl Into<Option<i64>>,
        prefix_allowed_tokens_fn: Option<&dyn Fn(i64, &Tensor) -> Vec<i64>>,
    ) -> Vec<String>
    where
@@ -1285,6 +1288,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
            min_length,
            max_length,
            decoder_start_token_id,
+            forced_bos_token_id,
            prefix_allowed_tokens_fn,
        );
        let mut output = Vec::with_capacity(generated.len());
@@ -1376,6 +1380,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
        min_length: impl Into<Option<i64>>,
        max_length: impl Into<Option<i64>>,
        decoder_start_token_id: impl Into<Option<i64>>,
+        forced_bos_token_id: impl Into<Option<i64>>,
        prefix_allowed_tokens_fn: Option<&dyn Fn(i64, &Tensor) -> Vec<i64>>,
    ) -> Vec<Vec<i64>>
    where
@@ -1412,6 +1417,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
            min_length,
            max_length,
            decoder_start_token_id,
+            forced_bos_token_id,
            prefix_allowed_tokens_fn,
        )
    }
@@ -1499,6 +1505,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
        min_length: impl Into<Option<i64>>,
        max_length: impl Into<Option<i64>>,
        decoder_start_token_id: impl Into<Option<i64>>,
+        forced_bos_token_id: impl Into<Option<i64>>,
        prefix_allowed_tokens_fn: Option<&dyn Fn(i64, &Tensor) -> Vec<i64>>,
    ) -> Vec<Vec<i64>> {
        let eos_token_ids = PrivateLanguageGenerator::get_eos_ids(self).clone();
@@ -1628,6 +1635,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
            length_penalty,
            num_beam_groups,
            diversity_penalty,
+            forced_bos_token_id: forced_bos_token_id.into(),
        };
 
        let decoded = no_grad(|| {
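Note the behavioural widening in the two large hunks above: `prepare_scores_for_generation` was previously gated on `self.is_encoder_decoder() & !gen_opt.do_sample` and is now called on every decoding step. This is safe because the trait's default implementation (the `@@ -287,6 +288,7 @@` hunk) is a no-op, so decoder-only models are unaffected, while encoder-decoder models now also get forced BOS/EOS tokens when sampling.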

View File

@@ -264,16 +264,16 @@ impl SummarizationOption {
    {
        match *self {
            Self::Bart(ref model) => {
-                model.generate(prompt_texts, attention_mask, None, None, None, None)
+                model.generate(prompt_texts, attention_mask, None, None, None, None, None)
            }
            Self::T5(ref model) => {
-                model.generate(prompt_texts, attention_mask, None, None, None, None)
+                model.generate(prompt_texts, attention_mask, None, None, None, None, None)
            }
            Self::ProphetNet(ref model) => {
-                model.generate(prompt_texts, attention_mask, None, None, None, None)
+                model.generate(prompt_texts, attention_mask, None, None, None, None, None)
            }
            Self::Pegasus(ref model) => {
-                model.generate(prompt_texts, attention_mask, None, None, None, None)
+                model.generate(prompt_texts, attention_mask, None, None, None, None, None)
            }
        }
    }

View File

@@ -256,6 +256,7 @@ impl TextGenerationOption {
                max_length,
                None,
                None,
+                None,
            ),
            Self::GPT2(ref model) => model.generate_indices(
                prompt_texts,
@@ -264,6 +265,7 @@ impl TextGenerationOption {
                max_length,
                None,
                None,
+                None,
            ),
            Self::GPTNeo(ref model) => model.generate_indices(
                prompt_texts,
@@ -272,6 +274,7 @@ impl TextGenerationOption {
                max_length,
                None,
                None,
+                None,
            ),
            Self::XLNet(ref model) => model.generate_indices(
                prompt_texts,
@@ -280,6 +283,7 @@ impl TextGenerationOption {
                max_length,
                None,
                None,
+                None,
            ),
            Self::Reformer(ref model) => model.generate_indices(
                prompt_texts,
@@ -288,6 +292,7 @@ impl TextGenerationOption {
                max_length,
                None,
                None,
+                None,
            ),
        }
    }

View File

@@ -675,10 +675,10 @@ impl TranslationOption {
    {
        match *self {
            Self::Marian(ref model) => {
-                model.generate(prompt_texts, attention_mask, None, None, None, None)
+                model.generate(prompt_texts, attention_mask, None, None, None, None, None)
            }
            Self::T5(ref model) => {
-                model.generate(prompt_texts, attention_mask, None, None, None, None)
+                model.generate(prompt_texts, attention_mask, None, None, None, None, None)
            }
        }
    }
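All pipeline call sites updated in this commit (conversation, summarization, text generation, translation) pass `None` for the new parameter, so their behaviour is unchanged; translation in particular keeps selecting Marian's target language via tokenizer prefixes rather than a forced BOS id. Downstream code calling `generate` directly only needs the extra argument, for example:

// Hypothetical downstream call site after this change (the Marian-style
// ">>fr<<" language prefix is an assumption about the model in use).
let translations = model.generate(
    Some(&[">>fr<< The quick brown fox jumps over the lazy dog"]),
    None, None, None, None,
    None, // forced_bos_token_id: not used by Marian
    None,
);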