Make max_length optional (#296)

* Made `max_length` an optional argument for generation methods and pipelines

* Updated changelog
guillaume-be 2022-11-15 19:20:51 +00:00 committed by GitHub
parent 5d2b107e99
commit 05367b4df2
31 changed files with 216 additions and 147 deletions
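In practice the change surfaces in the pipeline configurations updated below: `max_length` becomes an `Option<i64>`, where `Some(n)` caps the generated sequence and `None` removes the explicit limit. A minimal sketch of the new usage based on the GPT-2 text-generation example touched by this commit, assuming `TextGenerationConfig` keeps a `Default` implementation and that `generate` returns the decoded strings directly, as in the crate's examples at the time:

```rust
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel};

fn main() -> anyhow::Result<()> {
    // `max_length` is now optional: `Some(30)` caps the output at 30 tokens,
    // `None` lets generation run until the model emits an EOS token.
    let generate_config = TextGenerationConfig {
        model_type: ModelType::GPT2,
        max_length: Some(30),
        do_sample: false,
        num_beams: 1,
        ..Default::default()
    };
    let model = TextGenerationModel::new(generate_config)?;
    let output = model.generate(&["The dog"], None);
    println!("{:?}", output);
    Ok(())
}
```

Passing `max_length: None` instead leaves the output length bounded only by the EOS token; the guard added in `generation_utils` further down rejects the combination of no limit and no EOS token.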

View File

@ -8,13 +8,15 @@ All notable changes to this project will be documented in this file. The format
- Addition of Keyword/Keyphrases extraction pipeline based on KeyBERT (https://github.com/MaartenGr/KeyBERT)
## Changed
- Addition of type aliases for the controlled generation (`PrefixAllowedFunction`) and zero-shot classification (`ZeroShotTemplate`)
- (BREAKING) `merges_resource` now optional for all pipelines
- Allow mixing local and remote resources in pipelines
- Upgraded to `torch` 1.13 (via `tch` 0.9.0)
- Addition of type aliases for the controlled generation (`PrefixAllowedFunction`) and zero-shot classification (`ZeroShotTemplate`).
- (BREAKING) `merges_resource` now optional for all pipelines.
- Allow mixing local and remote resources in pipelines.
- Upgraded to `torch` 1.13 (via `tch` 0.9.0).
- (BREAKING) Made the `max_length` argument for generation methods and pipelines optional.
## Fixed
- Fixed configuration check for RoBERTa models for sentence classification.
- Fixed a bug causing the input prompt to be truncated for text generation if the prompt length was longer than `max_length`
## [0.18.0] - 2022-07-24
## Added

View File

@ -21,7 +21,7 @@ fn create_text_generation_model() -> TextGenerationModel {
Gpt2MergesResources::GPT2,
))),
min_length: 0,
max_length: 30,
max_length: Some(30),
do_sample: true,
early_stopping: false,
num_beams: 5,

View File

@ -19,7 +19,7 @@ fn main() -> anyhow::Result<()> {
// Set-up model
let generate_config = TextGenerationConfig {
model_type: ModelType::GPT2,
max_length: 30,
max_length: Some(30),
do_sample: false,
num_beams: 1,
temperature: 1.0,

View File

@ -43,7 +43,7 @@ fn main() -> anyhow::Result<()> {
vocab_resource,
merges_resource: Some(merges_resource),
min_length: 10,
max_length: 32,
max_length: Some(32),
do_sample: false,
early_stopping: true,
num_beams: 4,

View File

@ -40,7 +40,7 @@ fn main() -> anyhow::Result<()> {
vocab_resource,
merges_resource: None,
min_length: 100,
max_length: 100,
max_length: Some(100),
do_sample: true,
early_stopping: false,
num_beams: 3,

View File

@ -37,7 +37,7 @@ fn main() -> anyhow::Result<()> {
config_resource,
vocab_resource,
merges_resource: None,
max_length: 32,
max_length: Some(32),
do_sample: false,
num_beams: 3,
temperature: 1.0,

View File

@ -41,7 +41,7 @@ fn main() -> anyhow::Result<()> {
num_beams: 1,
length_penalty: 1.0,
min_length: 56,
max_length: 142,
max_length: Some(142),
device: Device::Cpu,
..Default::default()
};

View File

@ -1055,7 +1055,7 @@ impl BartGenerator {
/// # let weights_path = &home.as_path().join("model.ot");
/// let device = Device::cuda_if_available();
/// let generate_config = GenerateConfig {
/// max_length: 30,
/// max_length: Some(30),
/// do_sample: true,
/// num_beams: 5,
/// temperature: 1.1,
@ -1183,7 +1183,7 @@ impl PrivateLanguageGenerator<BartForConditionalGeneration, RobertaVocab, Robert
&self,
scores: &mut Tensor,
current_length: i64,
max_length: i64,
max_length: Option<i64>,
forced_bos_token_id: Option<i64>,
) {
if current_length == 1 {
@ -1191,8 +1191,10 @@ impl PrivateLanguageGenerator<BartForConditionalGeneration, RobertaVocab, Robert
scores,
&[forced_bos_token_id.unwrap_or_else(|| self.get_bos_id().unwrap())],
);
} else if current_length == max_length - 1 {
self.force_token_id_generation(scores, self.get_eos_ids().as_ref().unwrap());
} else if let Some(max_length) = max_length {
if current_length == max_length - 1 {
self.force_token_id_generation(scores, self.get_eos_ids().as_ref().unwrap());
}
}
}
@ -1231,7 +1233,7 @@ impl PrivateLanguageGenerator<BartForConditionalGeneration, RobertaVocab, Robert
fn encode_prompt_text<S>(
&self,
prompt_text: &[S],
max_len: i64,
max_len: Option<i64>,
pad_token_id: Option<i64>,
) -> Tensor
where
@ -1239,7 +1241,9 @@ impl PrivateLanguageGenerator<BartForConditionalGeneration, RobertaVocab, Robert
{
let tokens = self._get_tokenizer().encode_list(
prompt_text,
max_len as usize,
max_len
.map(|max_len| max_len as usize)
.unwrap_or(usize::MAX),
&TruncationStrategy::LongestFirst,
0,
);
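The `encode_prompt_text` implementations in this and the following files all map the optional length onto the `usize` truncation limit expected by the tokenizer. A standalone sketch of that conversion, with an illustrative helper name that is not part of the crate:

```rust
/// Convert an optional token budget into the `usize` limit passed to the
/// tokenizer: `None` means "no truncation", encoded here as `usize::MAX`.
fn truncation_limit(max_len: Option<i64>) -> usize {
    max_len
        .map(|max_len| max_len as usize)
        .unwrap_or(usize::MAX)
}

fn main() {
    assert_eq!(truncation_limit(Some(32)), 32);
    assert_eq!(truncation_limit(None), usize::MAX);
}
```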

View File

@ -695,7 +695,7 @@ impl GPT2Generator {
/// use rust_bert::pipelines::generation_utils::GenerateConfig;
///
/// let generate_config = GenerateConfig {
/// max_length: 30,
/// max_length: Some(30),
/// do_sample: true,
/// num_beams: 5,
/// temperature: 1.1,

View File

@ -670,7 +670,7 @@ impl GptNeoGenerator {
/// use rust_bert::pipelines::generation_utils::GenerateConfig;
///
/// let generate_config = GenerateConfig {
/// max_length: 30,
/// max_length: Some(30),
/// do_sample: true,
/// num_beams: 5,
/// temperature: 1.1,

View File

@ -604,7 +604,7 @@ impl M2M100Generator {
/// # let weights_path = &home.as_path().join("model.ot");
/// let device = Device::cuda_if_available();
/// let generate_config = GenerateConfig {
/// max_length: 30,
/// max_length: Some(30),
/// do_sample: true,
/// num_beams: 5,
/// temperature: 1.1,
@ -734,13 +734,15 @@ impl PrivateLanguageGenerator<M2M100ForConditionalGeneration, M2M100Vocab, M2M10
&self,
scores: &mut Tensor,
current_length: i64,
max_length: i64,
max_length: Option<i64>,
forced_bos_token_id: Option<i64>,
) {
if current_length == 1 {
self.force_token_id_generation(scores, &[forced_bos_token_id.unwrap_or(250004)]);
} else if current_length == max_length - 1 {
self.force_token_id_generation(scores, self.get_eos_ids().as_ref().unwrap());
} else if let Some(max_length) = max_length {
if current_length == max_length - 1 {
self.force_token_id_generation(scores, self.get_eos_ids().as_ref().unwrap());
}
}
}
@ -779,7 +781,7 @@ impl PrivateLanguageGenerator<M2M100ForConditionalGeneration, M2M100Vocab, M2M10
fn encode_prompt_text<S>(
&self,
prompt_text: &[S],
max_len: i64,
max_len: Option<i64>,
pad_token_id: Option<i64>,
) -> Tensor
where
@ -787,7 +789,9 @@ impl PrivateLanguageGenerator<M2M100ForConditionalGeneration, M2M100Vocab, M2M10
{
let tokens = self._get_tokenizer().encode_list(
prompt_text,
max_len as usize,
max_len
.map(|max_len| max_len as usize)
.unwrap_or(usize::MAX),
&TruncationStrategy::LongestFirst,
0,
);

View File

@ -824,7 +824,7 @@ impl MarianGenerator {
/// # let weights_path = &home.as_path().join("model.ot");
/// let device = Device::cuda_if_available();
/// let generate_config = GenerateConfig {
/// max_length: 512,
/// max_length: Some(512),
/// do_sample: true,
/// num_beams: 6,
/// temperature: 1.0,
@ -956,7 +956,7 @@ impl PrivateLanguageGenerator<MarianForConditionalGeneration, MarianVocab, Maria
&self,
scores: &mut Tensor,
current_length: i64,
max_length: i64,
max_length: Option<i64>,
_forced_bos_token_id: Option<i64>,
) {
let _ = scores.index_fill_(
@ -966,8 +966,10 @@ impl PrivateLanguageGenerator<MarianForConditionalGeneration, MarianVocab, Maria
.to_device(scores.device()),
f64::NEG_INFINITY,
);
if current_length == max_length - 1 {
self.force_token_id_generation(scores, self.get_eos_ids().as_ref().unwrap());
if let Some(max_length) = max_length {
if current_length == max_length - 1 {
self.force_token_id_generation(scores, self.get_eos_ids().as_ref().unwrap());
}
}
}
@ -1006,7 +1008,7 @@ impl PrivateLanguageGenerator<MarianForConditionalGeneration, MarianVocab, Maria
fn encode_prompt_text<S>(
&self,
prompt_text: &[S],
max_len: i64,
max_len: Option<i64>,
pad_token_id: Option<i64>,
) -> Tensor
where
@ -1014,7 +1016,9 @@ impl PrivateLanguageGenerator<MarianForConditionalGeneration, MarianVocab, Maria
{
let tokens = self._get_tokenizer().encode_list(
prompt_text,
max_len as usize,
max_len
.map(|max_len| max_len as usize)
.unwrap_or(usize::MAX),
&TruncationStrategy::LongestFirst,
0,
);

View File

@ -862,7 +862,7 @@ impl MBartGenerator {
/// # let weights_path = &home.as_path().join("model.ot");
/// let device = Device::cuda_if_available();
/// let generate_config = GenerateConfig {
/// max_length: 30,
/// max_length: Some(30),
/// do_sample: true,
/// num_beams: 5,
/// temperature: 1.1,
@ -983,13 +983,15 @@ impl PrivateLanguageGenerator<MBartForConditionalGeneration, MBart50Vocab, MBart
&self,
scores: &mut Tensor,
current_length: i64,
max_length: i64,
max_length: Option<i64>,
forced_bos_token_id: Option<i64>,
) {
if current_length == 1 {
self.force_token_id_generation(scores, &[forced_bos_token_id.unwrap_or(250004)]);
} else if current_length == max_length - 1 {
self.force_token_id_generation(scores, self.get_eos_ids().as_ref().unwrap());
} else if let Some(max_length) = max_length {
if current_length == max_length - 1 {
self.force_token_id_generation(scores, self.get_eos_ids().as_ref().unwrap());
}
}
}
@ -1028,7 +1030,7 @@ impl PrivateLanguageGenerator<MBartForConditionalGeneration, MBart50Vocab, MBart
fn encode_prompt_text<S>(
&self,
prompt_text: &[S],
max_len: i64,
max_len: Option<i64>,
pad_token_id: Option<i64>,
) -> Tensor
where
@ -1036,7 +1038,9 @@ impl PrivateLanguageGenerator<MBartForConditionalGeneration, MBart50Vocab, MBart
{
let tokens = self._get_tokenizer().encode_list(
prompt_text,
max_len as usize,
max_len
.map(|max_len| max_len as usize)
.unwrap_or(usize::MAX),
&TruncationStrategy::LongestFirst,
0,
);

View File

@ -457,7 +457,7 @@ impl OpenAIGenerator {
/// use rust_bert::openai_gpt::OpenAIGenerator;
/// use rust_bert::pipelines::generation_utils::GenerateConfig;
/// let generate_config = GenerateConfig {
/// max_length: 30,
/// max_length: Some(30),
/// do_sample: true,
/// num_beams: 5,
/// temperature: 1.1,

View File

@ -585,7 +585,7 @@ impl PegasusConditionalGenerator {
/// # let weights_path = &home.as_path().join("model.ot");
/// let device = Device::cuda_if_available();
/// let generate_config = GenerateConfig {
/// max_length: 30,
/// max_length: Some(30),
/// do_sample: true,
/// num_beams: 5,
/// temperature: 1.1,
@ -710,11 +710,13 @@ impl PrivateLanguageGenerator<PegasusForConditionalGeneration, PegasusVocab, Peg
&self,
scores: &mut Tensor,
current_length: i64,
max_length: i64,
max_length: Option<i64>,
_forced_bos_token_id: Option<i64>,
) {
if current_length == max_length - 1 {
self.force_token_id_generation(scores, self.get_eos_ids().as_ref().unwrap());
if let Some(max_length) = max_length {
if current_length == max_length - 1 {
self.force_token_id_generation(scores, self.get_eos_ids().as_ref().unwrap());
}
}
}
@ -753,7 +755,7 @@ impl PrivateLanguageGenerator<PegasusForConditionalGeneration, PegasusVocab, Peg
fn encode_prompt_text<S>(
&self,
prompt_text: &[S],
max_len: i64,
max_len: Option<i64>,
pad_token_id: Option<i64>,
) -> Tensor
where
@ -761,7 +763,9 @@ impl PrivateLanguageGenerator<PegasusForConditionalGeneration, PegasusVocab, Peg
{
let tokens = self._get_tokenizer().encode_list(
prompt_text,
max_len as usize,
max_len
.map(|max_len| max_len as usize)
.unwrap_or(usize::MAX),
&TruncationStrategy::LongestFirst,
0,
);
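Each `prepare_scores_for_generation` implementation now forces the EOS token only when a maximum length is actually configured. A minimal sketch of that guard in isolation, with an illustrative function name and the score manipulation itself omitted:

```rust
/// Decide whether EOS must be forced at this step: only when a maximum length
/// is set and the next token is the last one allowed.
fn should_force_eos(current_length: i64, max_length: Option<i64>) -> bool {
    if let Some(max_length) = max_length {
        current_length == max_length - 1
    } else {
        false
    }
}

fn main() {
    assert!(should_force_eos(29, Some(30)));
    assert!(!should_force_eos(29, Some(40)));
    // Without a maximum length, generation relies on the model emitting EOS itself.
    assert!(!should_force_eos(29, None));
}
```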

View File

@ -86,7 +86,7 @@ pub struct ConversationConfig {
/// Minimum sequence length (default: 0)
pub min_length: i64,
/// Maximum sequence length (default: 20)
pub max_length: i64,
pub max_length: Option<i64>,
/// Minimum free length available for generated responses (default: 32)
pub min_length_for_response: i64,
/// Sampling flag. If true, will perform top-k and/or nucleus sampling on generated tokens, otherwise greedy (deterministic) decoding (default: true)
@ -135,7 +135,7 @@ impl Default for ConversationConfig {
Gpt2MergesResources::DIALOGPT_MEDIUM,
))),
min_length: 0,
max_length: 1000,
max_length: Some(1000),
min_length_for_response: 64,
do_sample: true,
early_stopping: false,
@ -750,7 +750,7 @@ impl ConversationOption {
pub struct ConversationModel {
model: ConversationOption,
eos_token_id: i64,
max_allowed_context_length: i64,
max_allowed_context_length: Option<i64>,
device: Device,
}
@ -774,8 +774,9 @@ impl ConversationModel {
pub fn new(
conversation_config: ConversationConfig,
) -> Result<ConversationModel, RustBertError> {
let max_allowed_length =
conversation_config.max_length - conversation_config.min_length_for_response;
let max_allowed_length = conversation_config
.max_length
.map(|max_length| max_length - conversation_config.min_length_for_response);
let device = conversation_config.device;
let model = ConversationOption::new(conversation_config)?;
let eos_token_id = model.get_eos_id()?;
@ -921,17 +922,18 @@ impl ConversationModel {
let truncated_concatenated_inputs = concatenated_inputs
.iter()
.map(|input| {
if input.len() > self.max_allowed_context_length as usize {
.map(|input| match self.max_allowed_context_length {
Some(max_allowed_context_length)
if input.len() > max_allowed_context_length as usize =>
{
let start = self.get_truncated_input_index(
input,
self.max_allowed_context_length as usize,
max_allowed_context_length as usize,
pad_token,
);
&input[start..]
} else {
input.as_slice()
}
_ => input.as_slice(),
})
.collect::<Vec<&[i64]>>();
@ -1018,7 +1020,9 @@ impl ConversationModel {
.convert_tokens_to_ids(&prompt_tokens)
})
.map(|mut tokens| {
tokens.truncate(self.max_allowed_context_length as usize - 1);
if let Some(max_allowed_context_length) = self.max_allowed_context_length {
tokens.truncate(max_allowed_context_length as usize - 1);
}
tokens.push(self.eos_token_id);
tokens
})
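The conversation pipeline derives its context budget from the optional maximum length, so history truncation is skipped entirely when no limit is set. A small sketch of that computation using the default values shown above, with an illustrative helper name rather than the crate's internal field:

```rust
/// Remaining context budget for the conversation history: only defined when a
/// maximum generation length is configured.
fn max_allowed_context_length(
    max_length: Option<i64>,
    min_length_for_response: i64,
) -> Option<i64> {
    max_length.map(|max_length| max_length - min_length_for_response)
}

fn main() {
    // Mirrors the default ConversationConfig values (Some(1000), 64).
    assert_eq!(max_allowed_context_length(Some(1000), 64), Some(936));
    // With an unbounded max_length, the history is never truncated.
    assert_eq!(max_allowed_context_length(None, 64), None);
}
```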

View File

@ -107,7 +107,7 @@ pub struct GenerateConfig {
/// Minimum sequence length (default: 0)
pub min_length: i64,
/// Maximum sequence length (default: 20)
pub max_length: i64,
pub max_length: Option<i64>,
/// Sampling flag. If true, will perform top-k and/or nucleus sampling on generated tokens, otherwise greedy (deterministic) decoding (default: true)
pub do_sample: bool,
/// Early stopping flag indicating if the beam search should stop as soon as `num_beam` hypotheses have been generated (default: false)
@ -147,7 +147,7 @@ impl Default for GenerateConfig {
Gpt2MergesResources::GPT2,
))),
min_length: 0,
max_length: 20,
max_length: Some(56),
do_sample: true,
early_stopping: true,
num_beams: 5,
@ -234,7 +234,6 @@ pub(crate) mod private_generation_utils {
use rust_tokenizers::tokenizer::{truncate_sequences, Tokenizer, TruncationStrategy};
use rust_tokenizers::vocab::Vocab;
use rust_tokenizers::TokenIdsWithOffsets;
use tch::kind::Kind::{Bool, Float, Int64};
use tch::{nn, Device, Kind, Tensor};
use crate::pipelines::common::TokenizerOption;
@ -247,7 +246,7 @@ pub(crate) mod private_generation_utils {
pub struct InternalGenerateOptions<'a> {
pub min_length: i64,
pub max_length: i64,
pub max_length: Option<i64>,
pub do_sample: bool,
pub temperature: f64,
pub top_k: i64,
@ -299,7 +298,7 @@ pub(crate) mod private_generation_utils {
&self,
_scores: &mut Tensor,
_current_length: i64,
_max_length: i64,
_max_length: Option<i64>,
_forced_bos_token_id: Option<i64>,
) {
}
@ -328,7 +327,7 @@ pub(crate) mod private_generation_utils {
fn encode_prompt_text<S>(
&self,
prompt_text: &[S],
max_len: i64,
max_len: Option<i64>,
pad_token_id: Option<i64>,
) -> Tensor
where
@ -343,11 +342,15 @@ pub(crate) mod private_generation_utils {
let num_truncated_tokens = token_ids
.iter()
.map(|token_ids| {
if token_ids.len() > max_len as usize {
token_ids.len() - max_len as usize
} else {
0
}
max_len
.map(|max_len| {
if token_ids.len() > max_len as usize {
token_ids.len() - max_len as usize
} else {
0
}
})
.unwrap_or(0)
})
.collect::<Vec<usize>>();
@ -408,7 +411,7 @@ pub(crate) mod private_generation_utils {
let _ = next_token_logits.get(i).index_fill_(
0,
&Tensor::of_slice(&[token])
.to_kind(Int64)
.to_kind(Kind::Int64)
.to_device(next_token_logits.device()),
updated_value * repetition_penalty,
);
@ -416,7 +419,7 @@ pub(crate) mod private_generation_utils {
let _ = next_token_logits.get(i).index_fill_(
0,
&Tensor::of_slice(&[token])
.to_kind(Int64)
.to_kind(Kind::Int64)
.to_device(next_token_logits.device()),
updated_value / repetition_penalty,
);
@ -498,17 +501,21 @@ pub(crate) mod private_generation_utils {
.softmax(-1, sorted_logits.kind())
.cumsum(-1, sorted_logits.kind());
let mut sorted_indices_to_remove =
cumulative_probabilities.ge(top_p).to_kind(Int64);
cumulative_probabilities.ge(top_p).to_kind(Kind::Int64);
if min_tokens_to_keep > 1 {
let _ = sorted_indices_to_remove.index_fill_(
1,
&Tensor::arange_start(0, min_tokens_to_keep + 1, (Int64, logits.device())),
&Tensor::arange_start(
0,
min_tokens_to_keep + 1,
(Kind::Int64, logits.device()),
),
0,
);
}
let _ = sorted_indices_to_remove.index_copy_(
1,
&Tensor::arange_start(1, vocab_size, (Int64, logits.device())),
&Tensor::arange_start(1, vocab_size, (Kind::Int64, logits.device())),
&sorted_indices_to_remove
.slice(1, 0, vocab_size - 1, 1)
.copy(),
@ -516,13 +523,13 @@ pub(crate) mod private_generation_utils {
let _ = sorted_indices_to_remove.index_fill_(
1,
&Tensor::of_slice(&[0])
.to_kind(Int64)
.to_kind(Kind::Int64)
.to_device(sorted_indices_to_remove.device()),
0,
);
let indices_to_remove = sorted_indices_to_remove
.scatter(1, &sorted_indices, &sorted_indices_to_remove)
.to_kind(Bool);
.to_kind(Kind::Bool);
let _ = logits.masked_fill_(&indices_to_remove, f64::NEG_INFINITY);
}
}
@ -746,10 +753,9 @@ pub(crate) mod private_generation_utils {
output_scores: bool,
) -> GeneratedOutputWithScores {
let mut unfinished_sentences =
Tensor::ones(&[batch_size], (Int64, self.get_var_store().device()));
Tensor::ones(&[batch_size], (Kind::Int64, self.get_var_store().device()));
let mut sentence_lengths: Tensor =
Tensor::ones(&[batch_size], (Int64, self.get_var_store().device()))
* gen_opt.max_length as i64;
Tensor::ones(&[batch_size], (Kind::Int64, self.get_var_store().device()));
let (bad_word_ids_length_1, bad_word_ids_length_greater_than_1) =
self.split_bad_word_ids(gen_opt.bad_word_ids);
let mut static_bad_words_mask: Option<Tensor> = None;
@ -761,7 +767,7 @@ pub(crate) mod private_generation_utils {
let mut token_scores_output: Option<Vec<Tensor>> =
if output_scores { Some(vec![]) } else { None };
while current_length < gen_opt.max_length {
loop {
let prepared_input = self.prepare_inputs_for_generation(
input_ids.copy(),
encoder_outputs.as_ref(),
@ -902,11 +908,12 @@ pub(crate) mod private_generation_utils {
input_ids = Tensor::cat(&[input_ids, tokens_to_add.unsqueeze(-1)], -1);
if gen_opt.eos_token_ids.is_some() {
for eos_token_id in gen_opt.eos_token_ids.as_ref().unwrap() {
let sentence_with_eos = tokens_to_add.eq(*eos_token_id).to_kind(Int64);
let sentence_with_eos =
tokens_to_add.eq(*eos_token_id).to_kind(Kind::Int64);
let sentence_with_eos: Tensor = sentence_with_eos * &unfinished_sentences;
let _ = sentence_lengths.masked_fill_(
&sentence_with_eos
.to_kind(Bool)
.to_kind(Kind::Bool)
.to_device(sentence_lengths.device()),
current_length as i64 + 1,
);
@ -922,7 +929,7 @@ pub(crate) mod private_generation_utils {
attention_mask.as_ref(),
Tensor::ones(
&[*attention_mask.size().first().unwrap(), 1],
(Int64, attention_mask.device()),
(Kind::Int64, attention_mask.device()),
)
.as_ref(),
],
@ -930,6 +937,17 @@ pub(crate) mod private_generation_utils {
);
}
current_length += 1;
if let Some(max_length) = gen_opt.max_length {
if current_length >= max_length {
let _ = sentence_lengths.masked_fill_(
&unfinished_sentences
.to_kind(Kind::Bool)
.to_device(sentence_lengths.device()),
current_length as i64,
);
break;
}
}
}
let scores_output = token_scores_output.as_ref().map(|scores_tensor| {
(Tensor::stack(scores_tensor, 1).sum_dim_intlist(
@ -993,7 +1011,7 @@ pub(crate) mod private_generation_utils {
let vocab_size = self.get_vocab_size();
let beam_scores = Tensor::ones(
&[batch_size, gen_opt.num_beams],
(Float, self.get_var_store().device()),
(Kind::Float, self.get_var_store().device()),
) * -1e9;
let _ = beam_scores
.slice(1, 0, *beam_scores.size().last().unwrap(), num_sub_beams)
@ -1002,11 +1020,11 @@ pub(crate) mod private_generation_utils {
let mut beam_scores = beam_scores.view_(&[-1]);
let mut beam_tokens = Tensor::zeros(
&[batch_size * gen_opt.num_beams],
(Int64, self.get_var_store().device()),
(Kind::Int64, self.get_var_store().device()),
);
let mut beam_indices = Tensor::zeros(
&[batch_size * gen_opt.num_beams],
(Int64, self.get_var_store().device()),
(Kind::Int64, self.get_var_store().device()),
);
let mut saved_beam_scores: Option<Vec<Tensor>> =
if output_scores { Some(vec![]) } else { None };
@ -1019,7 +1037,7 @@ pub(crate) mod private_generation_utils {
let mut encoder_outputs = encoder_outputs;
let mut current_length = cur_len;
while current_length < gen_opt.max_length {
loop {
if num_beam_groups > 1 {
current_tokens = Tensor::zeros(
&[batch_size * gen_opt.num_beams],
@ -1209,19 +1227,19 @@ pub(crate) mod private_generation_utils {
let eos_token_ids = gen_opt.eos_token_ids.as_ref();
let beam_ids_tensor = &next_tokens.divide_scalar_mode(vocab_size, "floor");
let effective_beam_ids_tensor = (&next_tokens.ones_like().cumsum(0, Int64) - 1)
* group_size
+ beam_ids_tensor;
let effective_beam_ids_tensor =
(&next_tokens.ones_like().cumsum(0, Kind::Int64) - 1) * group_size
+ beam_ids_tensor;
let token_id_tensor = &next_tokens - beam_ids_tensor * vocab_size;
let (max_scores, _) = next_scores.max_dim(1, false);
let mut eos_mask = token_id_tensor.ones_like();
if let Some(eos_token_id) = eos_token_ids {
eos_mask -= token_id_tensor.eq(eos_token_id[0]).to_kind(Int64);
eos_mask -= token_id_tensor.eq(eos_token_id[0]).to_kind(Kind::Int64);
}
let eos_mask2 = eos_mask
.cumsum(1, Int64)
.cumsum(1, Kind::Int64)
.le(group_size)
.to_kind(Bool)
.to_kind(Kind::Bool)
.logical_and(&eos_mask);
let group_beam_scores = next_scores.masked_select(&eos_mask2);
@ -1321,6 +1339,13 @@ pub(crate) mod private_generation_utils {
],
-1,
);
current_length += 1;
if let Some(max_length) = gen_opt.max_length {
if current_length >= max_length {
break;
}
}
encoder_outputs = self.reorder_cache(&mut past, encoder_outputs, &beam_indices);
if !self.is_encoder_decoder() {
@ -1329,15 +1354,13 @@ pub(crate) mod private_generation_utils {
attention_mask.as_ref(),
Tensor::ones(
&[*attention_mask.size().first().unwrap(), 1],
(Int64, attention_mask.device()),
(Kind::Int64, attention_mask.device()),
)
.as_ref(),
],
-1,
);
}
current_length += 1;
}
let mut batch_index = 0i64;
@ -1377,7 +1400,7 @@ pub(crate) mod private_generation_utils {
};
let mut sentence_lengths =
Tensor::zeros(&[output_batch_size], (Int64, input_ids.device()));
Tensor::zeros(&[output_batch_size], (Kind::Int64, input_ids.device()));
let mut best_ids = vec![];
let mut scores_output = if output_scores {
@ -1421,11 +1444,14 @@ pub(crate) mod private_generation_utils {
}
}
}
let sentence_max_length =
min(i64::from(sentence_lengths.max()) + 1, gen_opt.max_length);
let sentence_max_length = gen_opt
.max_length
.map(|max_length| min(i64::from(sentence_lengths.max()) + 1, max_length))
.unwrap_or(i64::from(sentence_lengths.max()) + 1);
let mut decoded = input_ids.new_empty(
&[output_batch_size, sentence_max_length],
(Int64, input_ids.device()),
(Kind::Int64, input_ids.device()),
);
if i64::from(sentence_lengths.max()) != i64::from(sentence_lengths.min()) {
let _ = decoded.fill_(
@ -1440,12 +1466,15 @@ pub(crate) mod private_generation_utils {
&Tensor::arange_start(
0,
i64::from(sentence_lengths.get(hypothesis_index as i64)),
(Int64, input_ids.device()),
(Kind::Int64, input_ids.device()),
),
best_id,
);
let sentence_length = i64::from(sentence_lengths.get(hypothesis_index as i64));
if sentence_length < gen_opt.max_length {
let sentence_length_max = gen_opt
.max_length
.unwrap_or_else(|| i64::from(sentence_lengths.max()));
if sentence_length < sentence_length_max {
let _ = decoded.get(hypothesis_index as i64).index_fill_(
0,
&Tensor::of_slice(&[sentence_length]).to_device(input_ids.device()),
@ -1591,7 +1620,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
/// # let weights_path = &home.as_path().join("model.ot");
/// let device = Device::cuda_if_available();
/// let generate_config = GenerateConfig {
/// max_length: 30,
/// max_length: Some(30),
/// do_sample: true,
/// num_beams: 5,
/// temperature: 1.1,
@ -1698,7 +1727,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
/// # let weights_path = &home.as_path().join("model.ot");
/// let device = Device::cuda_if_available();
/// let generate_config = GenerateConfig {
/// max_length: 30,
/// max_length: Some(30),
/// do_sample: true,
/// num_beams: 5,
/// temperature: 1.1,
@ -1752,9 +1781,12 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
let eos_token_ids = self.get_eos_ids();
let config = self.get_config();
let max_length = unpack_config!(max_length, generate_options, config);
let max_length = generate_options.map_or(config.max_length, |generate_options| {
generate_options.max_length
});
let encoding_max_len = if self.is_encoder_decoder() {
self.get_max_positions_embeddings()
Some(self.get_max_positions_embeddings())
} else {
max_length
};
@ -1975,14 +2007,21 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
let max_length = if let Some(generate_options) = generate_options {
match (generate_options.max_length, generate_options.max_new_tokens) {
(Some(max_length), _) => max_length,
(None, Some(max_new_tokens)) => max_new_tokens + input_ids.size().last().unwrap(),
(Some(max_length), _) => Some(max_length),
(None, Some(max_new_tokens)) => {
Some(max_new_tokens + input_ids.size().last().unwrap())
}
(None, None) => config.max_length,
}
} else {
config.max_length
};
if max_length.is_none() & eos_token_ids.is_none() {
panic!("No maximum length given for a model without an EOS token. \
This would lead to an infinite generation loop. Please provide a `max_length` or `max_new_tokens`")
}
let gen_opt = InternalGenerateOptions {
min_length,
max_length,
@ -2083,7 +2122,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
/// # let weights_path = &home.as_path().join("model.ot");
/// let device = Device::cuda_if_available();
/// let generate_config = GenerateConfig {
/// max_length: 30,
/// max_length: Some(30),
/// do_sample: true,
/// num_beams: 5,
/// temperature: 1.1,
@ -2115,7 +2154,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
#[derive(Debug)]
struct BeamHypotheses {
max_length: i64,
max_length: Option<i64>,
length_penalty: f64,
early_stopping: bool,
num_beams: i64,
@ -2151,12 +2190,12 @@ impl Clone for BeamHypotheses {
impl BeamHypotheses {
fn new(
num_beams: i64,
max_length: i64,
max_length: Option<i64>,
length_penalty: f64,
early_stopping: bool,
) -> BeamHypotheses {
BeamHypotheses {
max_length: max_length - 1,
max_length: max_length.map(|max_length| max_length - 1),
length_penalty,
early_stopping,
num_beams,
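Within `generation_utils`, the effective limit for a call is resolved before decoding starts: an explicit `max_length` option wins, otherwise `max_new_tokens` is added to the prompt length, otherwise the config default applies, and the call panics when neither a limit nor an EOS token is available. A standalone sketch of that precedence, with illustrative function and parameter names:

```rust
/// Resolve the effective maximum length from per-call options and the config
/// default, mirroring the precedence introduced in this commit.
fn resolve_max_length(
    option_max_length: Option<i64>,
    option_max_new_tokens: Option<i64>,
    config_max_length: Option<i64>,
    prompt_length: i64,
    has_eos_token: bool,
) -> Option<i64> {
    let max_length = match (option_max_length, option_max_new_tokens) {
        (Some(max_length), _) => Some(max_length),
        (None, Some(max_new_tokens)) => Some(max_new_tokens + prompt_length),
        (None, None) => config_max_length,
    };
    if max_length.is_none() && !has_eos_token {
        panic!(
            "No maximum length given for a model without an EOS token. \
             This would lead to an infinite generation loop."
        );
    }
    max_length
}

fn main() {
    assert_eq!(resolve_max_length(Some(30), Some(10), Some(56), 5, true), Some(30));
    assert_eq!(resolve_max_length(None, Some(10), Some(56), 5, true), Some(15));
    assert_eq!(resolve_max_length(None, None, None, 5, true), None);
}
```

With the limit optional, the decoding loops themselves switch from `while current_length < max_length` to an unconditional `loop` that breaks once the resolved limit (if any) is reached, as shown in the hunks above.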

View File

@ -96,7 +96,7 @@ pub struct SummarizationConfig {
/// Minimum sequence length (default: 0)
pub min_length: i64,
/// Maximum sequence length (default: 20)
pub max_length: i64,
pub max_length: Option<i64>,
/// Sampling flag. If true, will perform top-k and/or nucleus sampling on generated tokens, otherwise greedy (deterministic) decoding (default: true)
pub do_sample: bool,
/// Early stopping flag indicating if the beam search should stop as soon as `num_beam` hypotheses have been generated (default: false)
@ -154,7 +154,7 @@ impl SummarizationConfig {
vocab_resource: Box::new(vocab_resource),
merges_resource: merges_resource.map(|r| Box::new(r) as Box<_>),
min_length: 56,
max_length: 142,
max_length: Some(142),
do_sample: false,
early_stopping: true,
num_beams: 3,

View File

@ -66,8 +66,8 @@ pub struct TextGenerationConfig {
pub merges_resource: Option<Box<dyn ResourceProvider + Send>>,
/// Minimum sequence length (default: 0)
pub min_length: i64,
/// Maximum sequence length (default: 20)
pub max_length: i64,
/// Maximum sequence length (default: 56)
pub max_length: Option<i64>,
/// Sampling flag. If true, will perform top-k and/or nucleus sampling on generated tokens, otherwise greedy (deterministic) decoding (default: true)
pub do_sample: bool,
/// Early stopping flag indicating if the beam search should stop as soon as `num_beam` hypotheses have been generated (default: false)
@ -125,7 +125,7 @@ impl TextGenerationConfig {
vocab_resource: Box::new(vocab_resource),
merges_resource: merges_resource.map(|r| Box::new(r) as Box<_>),
min_length: 0,
max_length: 20,
max_length: Some(56),
do_sample: true,
early_stopping: true,
num_beams: 5,
@ -326,7 +326,7 @@ pub struct TextGenerationModel {
prefix: Option<String>,
prefix_length: Option<i64>,
min_length: i64,
max_length: i64,
max_length: Option<i64>,
}
impl TextGenerationModel {
@ -445,7 +445,7 @@ with people, even a bishop, begging for his blessing. <eod> </s> <eos>"
self.model.generate_indices(
Some(&texts),
Some(self.min_length + prefix_length),
Some(self.max_length + prefix_length),
self.max_length.map(|max_length| max_length + prefix_length),
)
}
_ => panic!("Prefix length not defined but prefix provided!"),

View File

@ -387,8 +387,8 @@ pub struct TranslationConfig {
pub target_languages: HashSet<Language>,
/// Minimum sequence length (default: 0)
pub min_length: i64,
/// Maximum sequence length (default: 20)
pub max_length: i64,
/// Maximum sequence length (default: 512)
pub max_length: Option<i64>,
/// Sampling flag. If true, will perform top-k and/or nucleus sampling on generated tokens, otherwise greedy (deterministic) decoding (default: true)
pub do_sample: bool,
/// Early stopping flag indicating if the beam search should stop as soon as `num_beam` hypotheses have been generated (default: false)
@ -488,7 +488,7 @@ impl TranslationConfig {
target_languages: target_languages.as_ref().iter().cloned().collect(),
device,
min_length: 0,
max_length: 512,
max_length: Some(512),
do_sample: false,
early_stopping: true,
num_beams: 3,

View File

@ -926,7 +926,7 @@ impl ProphetNetConditionalGenerator {
/// # let weights_path = &home.as_path().join("model.ot");
/// let device = Device::cuda_if_available();
/// let generate_config = GenerateConfig {
/// max_length: 30,
/// max_length: Some(30),
/// do_sample: true,
/// num_beams: 5,
/// temperature: 1.1,
@ -1075,7 +1075,7 @@ impl
fn encode_prompt_text<S>(
&self,
prompt_text: &[S],
max_len: i64,
max_len: Option<i64>,
pad_token_id: Option<i64>,
) -> Tensor
where
@ -1083,7 +1083,9 @@ impl
{
let tokens = self._get_tokenizer().encode_list(
prompt_text,
max_len as usize,
max_len
.map(|max_len| max_len as usize)
.unwrap_or(usize::MAX),
&TruncationStrategy::LongestFirst,
0,
);

View File

@ -985,7 +985,7 @@ impl PrivateLanguageGenerator<T5ForConditionalGeneration, T5Vocab, T5Tokenizer>
fn encode_prompt_text<S>(
&self,
prompt_text: &[S],
max_len: i64,
max_len: Option<i64>,
pad_token_id: Option<i64>,
) -> Tensor
where
@ -993,7 +993,9 @@ impl PrivateLanguageGenerator<T5ForConditionalGeneration, T5Vocab, T5Tokenizer>
{
let tokens = self._get_tokenizer().encode_list(
prompt_text,
max_len as usize,
max_len
.map(|max_len| max_len as usize)
.unwrap_or(usize::MAX),
&TruncationStrategy::LongestFirst,
0,
);

View File

@ -39,7 +39,7 @@
//! config_resource,
//! vocab_resource,
//! merges_resource: None,
//! max_length: 56,
//! max_length: Some(56),
//! do_sample: true,
//! num_beams: 3,
//! temperature: 1.0,

View File

@ -1610,7 +1610,7 @@ impl XLNetGenerator {
/// use rust_bert::xlnet::XLNetGenerator;
///
/// let generate_config = GenerateConfig {
/// max_length: 30,
/// max_length: Some(30),
/// do_sample: true,
/// num_beams: 5,
/// temperature: 1.1,

View File

@ -97,7 +97,7 @@ fn bart_summarization_greedy() -> anyhow::Result<()> {
num_beams: 1,
length_penalty: 1.0,
min_length: 56,
max_length: 142,
max_length: Some(142),
device: Device::Cpu,
..Default::default()
};
@ -157,7 +157,7 @@ fn bart_summarization_beam_search() -> anyhow::Result<()> {
merges_resource: Some(merges_resource),
num_beams: 4,
min_length: 56,
max_length: 142,
max_length: Some(142),
length_penalty: 1.0,
device: Device::Cpu,
..Default::default()

View File

@ -121,7 +121,7 @@ fn gpt2_generation_greedy() -> anyhow::Result<()> {
config_resource,
vocab_resource,
merges_resource: Some(merges_resource),
max_length: 40,
max_length: Some(40),
do_sample: false,
num_beams: 1,
temperature: 1.1,
@ -153,7 +153,7 @@ fn gpt2_generation_beam_search() -> anyhow::Result<()> {
config_resource,
vocab_resource,
merges_resource: Some(merges_resource),
max_length: 20,
max_length: Some(20),
do_sample: false,
num_beams: 5,
temperature: 1.2,
@ -197,7 +197,7 @@ fn gpt2_generation_beam_search_multiple_prompts_without_padding() -> anyhow::Res
config_resource,
vocab_resource,
merges_resource: Some(merges_resource),
max_length: 20,
max_length: Some(20),
do_sample: false,
num_beams: 5,
temperature: 1.2,
@ -254,7 +254,7 @@ fn gpt2_generation_beam_search_multiple_prompts_with_padding() -> anyhow::Result
config_resource,
vocab_resource,
merges_resource: Some(merges_resource),
max_length: 20,
max_length: Some(20),
do_sample: false,
num_beams: 5,
temperature: 1.2,
@ -311,7 +311,7 @@ fn gpt2_diverse_beam_search_multiple_prompts_with_padding() -> anyhow::Result<()
vocab_resource,
merges_resource: Some(merges_resource),
min_length: 10,
max_length: 20,
max_length: Some(20),
do_sample: false,
num_beams: 6,
num_beam_groups: Some(3),
@ -378,7 +378,7 @@ fn gpt2_prefix_allowed_token_greedy() -> anyhow::Result<()> {
}
let generate_config = GenerateConfig {
max_length: 56,
max_length: Some(56),
model_resource,
config_resource,
vocab_resource,
@ -428,7 +428,7 @@ fn gpt2_bad_tokens_greedy() -> anyhow::Result<()> {
let model_resource = Box::new(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
let generate_config = GenerateConfig {
max_length: 36,
max_length: Some(36),
model_resource,
config_resource,
vocab_resource,
@ -494,7 +494,7 @@ fn gpt2_bad_tokens_beam_search() -> anyhow::Result<()> {
let model_resource = Box::new(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
let generate_config = GenerateConfig {
max_length: 36,
max_length: Some(36),
model_resource,
config_resource,
vocab_resource,
@ -575,7 +575,7 @@ fn gpt2_prefix_allowed_token_beam_search() -> anyhow::Result<()> {
}
let generate_config = GenerateConfig {
max_length: 32,
max_length: Some(32),
model_resource,
config_resource,
vocab_resource,
@ -625,7 +625,7 @@ fn gpt2_greedy_token_scores() -> anyhow::Result<()> {
let model_resource = Box::new(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
let generate_config = GenerateConfig {
max_length: 16,
max_length: Some(16),
model_resource,
config_resource,
vocab_resource,
@ -681,7 +681,7 @@ fn gpt2_beam_search_token_scores() -> anyhow::Result<()> {
let model_resource = Box::new(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
let generate_config = GenerateConfig {
max_length: 16,
max_length: Some(16),
model_resource,
config_resource,
vocab_resource,
@ -812,7 +812,7 @@ fn dialogpt_multiple_multi_turn_conversation() -> anyhow::Result<()> {
fn dialogpt_multiple_multi_turn_conversation_with_truncation() -> anyhow::Result<()> {
// Set-up conversation model
let conversation_config = ConversationConfig {
max_length: 36,
max_length: Some(36),
min_length_for_response: 24,
do_sample: false,
device: Device::Cpu,

View File

@ -130,7 +130,7 @@ fn test_generation_gpt_neo() -> anyhow::Result<()> {
vocab_resource,
merges_resource: Some(merges_resource),
min_length: 10,
max_length: 32,
max_length: Some(32),
do_sample: false,
early_stopping: true,
num_beams: 4,

View File

@ -123,7 +123,7 @@ fn openai_gpt_generation_greedy() -> anyhow::Result<()> {
config_resource,
vocab_resource,
merges_resource: Some(merges_resource),
max_length: 40,
max_length: Some(40),
do_sample: false,
num_beams: 1,
top_p: 1.0,
@ -165,7 +165,7 @@ fn openai_gpt_generation_beam_search() -> anyhow::Result<()> {
config_resource,
vocab_resource,
merges_resource: Some(merges_resource),
max_length: 20,
max_length: Some(20),
do_sample: false,
early_stopping: true,
num_beams: 5,
@ -218,7 +218,7 @@ fn openai_gpt_generation_beam_search_multiple_prompts_without_padding() -> anyho
config_resource,
vocab_resource,
merges_resource: Some(merges_resource),
max_length: 20,
max_length: Some(20),
do_sample: false,
early_stopping: true,
num_beams: 5,
@ -287,7 +287,7 @@ fn openai_gpt_generation_beam_search_multiple_prompts_with_padding() -> anyhow::
config_resource,
vocab_resource,
merges_resource: Some(merges_resource),
max_length: 20,
max_length: Some(20),
do_sample: false,
num_beams: 5,
temperature: 2.0,

View File

@ -50,7 +50,7 @@ fn test_generation_reformer() -> anyhow::Result<()> {
vocab_resource,
merges_resource: None,
min_length: 100,
max_length: 100,
max_length: Some(100),
do_sample: false,
early_stopping: true,
no_repeat_ngram_size: 3,

View File

@ -70,7 +70,7 @@ fn test_summarization_t5() -> anyhow::Result<()> {
vocab_resource: Box::new(RemoteResource::from_pretrained(T5VocabResources::T5_SMALL)),
merges_resource: None,
min_length: 30,
max_length: 200,
max_length: Some(200),
early_stopping: true,
num_beams: 4,
length_penalty: 2.0,

View File

@ -212,7 +212,7 @@ fn xlnet_generation_beam_search() -> anyhow::Result<()> {
config_resource,
vocab_resource,
merges_resource: None,
max_length: 32,
max_length: Some(32),
do_sample: false,
num_beams: 3,
temperature: 1.0,