Make max_length optional (#296)

* Made `max_length` an optional argument for generation methods and pipelines

* Updated changelog
Author: guillaume-be
Date: 2022-11-15 19:20:51 +00:00 (committed via GitHub)
Parent: 5d2b107e99
Commit: 05367b4df2
31 changed files with 216 additions and 147 deletions

View File

@ -8,13 +8,15 @@ All notable changes to this project will be documented in this file. The format
 - Addition of Keyword/Keyphrases extraction pipeline based on KeyBERT (https://github.com/MaartenGr/KeyBERT)
 ## Changed
-- Addition of type aliases for the controlled generation (`PrefixAllowedFunction`) and zero-shot classification (`ZeroShotTemplate`)
+- Addition of type aliases for the controlled generation (`PrefixAllowedFunction`) and zero-shot classification (`ZeroShotTemplate`).
-- (BREAKING) `merges_resource` now optional for all pipelines
+- (BREAKING) `merges_resource` now optional for all pipelines.
-- Allow mixing local and remote resources in pipelines
+- Allow mixing local and remote resources in pipelines.
-- Upgraded to `torch` 1.13 (via `tch` 0.9.0)
+- Upgraded to `torch` 1.13 (via `tch` 0.9.0).
+- (BREAKING) Made the `max_length` argument for generation methods and pipelines optional.
 ## Fixed
 - Fixed configuration check for RoBERTa models for sentence classification.
+- Fixed a bug causing the input prompt to be truncated for text generation if the prompt length was longer than `max_length`
 ## [0.18.0] - 2022-07-24
 ## Added
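Note: for downstream users the breaking part of this commit is mechanical: every explicit `max_length` in a generation or pipeline config must now be wrapped in `Some`, or set to `None` to drop the cap. A minimal before/after sketch using `GenerateConfig` (only the affected field shown, the rest assumed to come from the `Default` impl changed later in this diff):

use rust_bert::pipelines::generation_utils::GenerateConfig;

// before this commit
let config = GenerateConfig {
    max_length: 30,
    ..Default::default()
};

// after this commit
let config = GenerateConfig {
    max_length: Some(30), // or None to generate until the model emits an EOS token
    ..Default::default()
};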

View File

@ -21,7 +21,7 @@ fn create_text_generation_model() -> TextGenerationModel {
Gpt2MergesResources::GPT2, Gpt2MergesResources::GPT2,
))), ))),
min_length: 0, min_length: 0,
max_length: 30, max_length: Some(30),
do_sample: true, do_sample: true,
early_stopping: false, early_stopping: false,
num_beams: 5, num_beams: 5,

View File

@ -19,7 +19,7 @@ fn main() -> anyhow::Result<()> {
// Set-up model // Set-up model
let generate_config = TextGenerationConfig { let generate_config = TextGenerationConfig {
model_type: ModelType::GPT2, model_type: ModelType::GPT2,
max_length: 30, max_length: Some(30),
do_sample: false, do_sample: false,
num_beams: 1, num_beams: 1,
temperature: 1.0, temperature: 1.0,

View File

@ -43,7 +43,7 @@ fn main() -> anyhow::Result<()> {
vocab_resource, vocab_resource,
merges_resource: Some(merges_resource), merges_resource: Some(merges_resource),
min_length: 10, min_length: 10,
max_length: 32, max_length: Some(32),
do_sample: false, do_sample: false,
early_stopping: true, early_stopping: true,
num_beams: 4, num_beams: 4,

View File

@ -40,7 +40,7 @@ fn main() -> anyhow::Result<()> {
vocab_resource, vocab_resource,
merges_resource: None, merges_resource: None,
min_length: 100, min_length: 100,
max_length: 100, max_length: Some(100),
do_sample: true, do_sample: true,
early_stopping: false, early_stopping: false,
num_beams: 3, num_beams: 3,

View File

@ -37,7 +37,7 @@ fn main() -> anyhow::Result<()> {
config_resource, config_resource,
vocab_resource, vocab_resource,
merges_resource: None, merges_resource: None,
max_length: 32, max_length: Some(32),
do_sample: false, do_sample: false,
num_beams: 3, num_beams: 3,
temperature: 1.0, temperature: 1.0,

View File

@ -41,7 +41,7 @@ fn main() -> anyhow::Result<()> {
num_beams: 1, num_beams: 1,
length_penalty: 1.0, length_penalty: 1.0,
min_length: 56, min_length: 56,
max_length: 142, max_length: Some(142),
device: Device::Cpu, device: Device::Cpu,
..Default::default() ..Default::default()
}; };

View File

@ -1055,7 +1055,7 @@ impl BartGenerator {
/// # let weights_path = &home.as_path().join("model.ot"); /// # let weights_path = &home.as_path().join("model.ot");
/// let device = Device::cuda_if_available(); /// let device = Device::cuda_if_available();
/// let generate_config = GenerateConfig { /// let generate_config = GenerateConfig {
/// max_length: 30, /// max_length: Some(30),
/// do_sample: true, /// do_sample: true,
/// num_beams: 5, /// num_beams: 5,
/// temperature: 1.1, /// temperature: 1.1,
@ -1183,7 +1183,7 @@ impl PrivateLanguageGenerator<BartForConditionalGeneration, RobertaVocab, Robert
&self, &self,
scores: &mut Tensor, scores: &mut Tensor,
current_length: i64, current_length: i64,
max_length: i64, max_length: Option<i64>,
forced_bos_token_id: Option<i64>, forced_bos_token_id: Option<i64>,
) { ) {
if current_length == 1 { if current_length == 1 {
@ -1191,8 +1191,10 @@ impl PrivateLanguageGenerator<BartForConditionalGeneration, RobertaVocab, Robert
scores, scores,
&[forced_bos_token_id.unwrap_or_else(|| self.get_bos_id().unwrap())], &[forced_bos_token_id.unwrap_or_else(|| self.get_bos_id().unwrap())],
); );
} else if current_length == max_length - 1 { } else if let Some(max_length) = max_length {
self.force_token_id_generation(scores, self.get_eos_ids().as_ref().unwrap()); if current_length == max_length - 1 {
self.force_token_id_generation(scores, self.get_eos_ids().as_ref().unwrap());
}
} }
} }
@ -1231,7 +1233,7 @@ impl PrivateLanguageGenerator<BartForConditionalGeneration, RobertaVocab, Robert
fn encode_prompt_text<S>( fn encode_prompt_text<S>(
&self, &self,
prompt_text: &[S], prompt_text: &[S],
max_len: i64, max_len: Option<i64>,
pad_token_id: Option<i64>, pad_token_id: Option<i64>,
) -> Tensor ) -> Tensor
where where
@ -1239,7 +1241,9 @@ impl PrivateLanguageGenerator<BartForConditionalGeneration, RobertaVocab, Robert
{ {
let tokens = self._get_tokenizer().encode_list( let tokens = self._get_tokenizer().encode_list(
prompt_text, prompt_text,
max_len as usize, max_len
.map(|max_len| max_len as usize)
.unwrap_or(usize::MAX),
&TruncationStrategy::LongestFirst, &TruncationStrategy::LongestFirst,
0, 0,
); );
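Note: `encode_prompt_text` now receives `Option<i64>` and falls back to an effectively unlimited truncation length for the tokenizer. The conversion used above can be read as a tiny standalone helper; a sketch (the helper name is illustrative, not part of the crate):

// Converts an optional generation cap into the `usize` limit expected by
// the tokenizer's `encode_list`; `None` means "do not truncate the prompt".
fn effective_encode_limit(max_len: Option<i64>) -> usize {
    max_len.map(|max_len| max_len as usize).unwrap_or(usize::MAX)
}
// effective_encode_limit(Some(32)) == 32
// effective_encode_limit(None) == usize::MAX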

View File

@ -695,7 +695,7 @@ impl GPT2Generator {
/// use rust_bert::pipelines::generation_utils::GenerateConfig; /// use rust_bert::pipelines::generation_utils::GenerateConfig;
/// ///
/// let generate_config = GenerateConfig { /// let generate_config = GenerateConfig {
/// max_length: 30, /// max_length: Some(30),
/// do_sample: true, /// do_sample: true,
/// num_beams: 5, /// num_beams: 5,
/// temperature: 1.1, /// temperature: 1.1,

View File

@ -670,7 +670,7 @@ impl GptNeoGenerator {
/// use rust_bert::pipelines::generation_utils::GenerateConfig; /// use rust_bert::pipelines::generation_utils::GenerateConfig;
/// ///
/// let generate_config = GenerateConfig { /// let generate_config = GenerateConfig {
/// max_length: 30, /// max_length: Some(30),
/// do_sample: true, /// do_sample: true,
/// num_beams: 5, /// num_beams: 5,
/// temperature: 1.1, /// temperature: 1.1,

View File

@ -604,7 +604,7 @@ impl M2M100Generator {
/// # let weights_path = &home.as_path().join("model.ot"); /// # let weights_path = &home.as_path().join("model.ot");
/// let device = Device::cuda_if_available(); /// let device = Device::cuda_if_available();
/// let generate_config = GenerateConfig { /// let generate_config = GenerateConfig {
/// max_length: 30, /// max_length: Some(30),
/// do_sample: true, /// do_sample: true,
/// num_beams: 5, /// num_beams: 5,
/// temperature: 1.1, /// temperature: 1.1,
@ -734,13 +734,15 @@ impl PrivateLanguageGenerator<M2M100ForConditionalGeneration, M2M100Vocab, M2M10
&self, &self,
scores: &mut Tensor, scores: &mut Tensor,
current_length: i64, current_length: i64,
max_length: i64, max_length: Option<i64>,
forced_bos_token_id: Option<i64>, forced_bos_token_id: Option<i64>,
) { ) {
if current_length == 1 { if current_length == 1 {
self.force_token_id_generation(scores, &[forced_bos_token_id.unwrap_or(250004)]); self.force_token_id_generation(scores, &[forced_bos_token_id.unwrap_or(250004)]);
} else if current_length == max_length - 1 { } else if let Some(max_length) = max_length {
self.force_token_id_generation(scores, self.get_eos_ids().as_ref().unwrap()); if current_length == max_length - 1 {
self.force_token_id_generation(scores, self.get_eos_ids().as_ref().unwrap());
}
} }
} }
@ -779,7 +781,7 @@ impl PrivateLanguageGenerator<M2M100ForConditionalGeneration, M2M100Vocab, M2M10
fn encode_prompt_text<S>( fn encode_prompt_text<S>(
&self, &self,
prompt_text: &[S], prompt_text: &[S],
max_len: i64, max_len: Option<i64>,
pad_token_id: Option<i64>, pad_token_id: Option<i64>,
) -> Tensor ) -> Tensor
where where
@ -787,7 +789,9 @@ impl PrivateLanguageGenerator<M2M100ForConditionalGeneration, M2M100Vocab, M2M10
{ {
let tokens = self._get_tokenizer().encode_list( let tokens = self._get_tokenizer().encode_list(
prompt_text, prompt_text,
max_len as usize, max_len
.map(|max_len| max_len as usize)
.unwrap_or(usize::MAX),
&TruncationStrategy::LongestFirst, &TruncationStrategy::LongestFirst,
0, 0,
); );
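Note: the `prepare_scores_for_generation` overrides now force the EOS token only when a maximum length is actually set; with `max_length: None`, ending the sequence is left entirely to the model emitting EOS on its own. The control flow above, reduced to a standalone sketch (function name is illustrative):

fn eos_should_be_forced(current_length: i64, max_length: Option<i64>) -> bool {
    // Only the last position before the cap forces EOS; without a cap, never force it.
    match max_length {
        Some(max_length) => current_length == max_length - 1,
        None => false,
    }
}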

View File

@ -824,7 +824,7 @@ impl MarianGenerator {
/// # let weights_path = &home.as_path().join("model.ot"); /// # let weights_path = &home.as_path().join("model.ot");
/// let device = Device::cuda_if_available(); /// let device = Device::cuda_if_available();
/// let generate_config = GenerateConfig { /// let generate_config = GenerateConfig {
/// max_length: 512, /// max_length: Some(512),
/// do_sample: true, /// do_sample: true,
/// num_beams: 6, /// num_beams: 6,
/// temperature: 1.0, /// temperature: 1.0,
@ -956,7 +956,7 @@ impl PrivateLanguageGenerator<MarianForConditionalGeneration, MarianVocab, Maria
&self, &self,
scores: &mut Tensor, scores: &mut Tensor,
current_length: i64, current_length: i64,
max_length: i64, max_length: Option<i64>,
_forced_bos_token_id: Option<i64>, _forced_bos_token_id: Option<i64>,
) { ) {
let _ = scores.index_fill_( let _ = scores.index_fill_(
@ -966,8 +966,10 @@ impl PrivateLanguageGenerator<MarianForConditionalGeneration, MarianVocab, Maria
.to_device(scores.device()), .to_device(scores.device()),
f64::NEG_INFINITY, f64::NEG_INFINITY,
); );
if current_length == max_length - 1 { if let Some(max_length) = max_length {
self.force_token_id_generation(scores, self.get_eos_ids().as_ref().unwrap()); if current_length == max_length - 1 {
self.force_token_id_generation(scores, self.get_eos_ids().as_ref().unwrap());
}
} }
} }
@ -1006,7 +1008,7 @@ impl PrivateLanguageGenerator<MarianForConditionalGeneration, MarianVocab, Maria
fn encode_prompt_text<S>( fn encode_prompt_text<S>(
&self, &self,
prompt_text: &[S], prompt_text: &[S],
max_len: i64, max_len: Option<i64>,
pad_token_id: Option<i64>, pad_token_id: Option<i64>,
) -> Tensor ) -> Tensor
where where
@ -1014,7 +1016,9 @@ impl PrivateLanguageGenerator<MarianForConditionalGeneration, MarianVocab, Maria
{ {
let tokens = self._get_tokenizer().encode_list( let tokens = self._get_tokenizer().encode_list(
prompt_text, prompt_text,
max_len as usize, max_len
.map(|max_len| max_len as usize)
.unwrap_or(usize::MAX),
&TruncationStrategy::LongestFirst, &TruncationStrategy::LongestFirst,
0, 0,
); );

View File

@ -862,7 +862,7 @@ impl MBartGenerator {
/// # let weights_path = &home.as_path().join("model.ot"); /// # let weights_path = &home.as_path().join("model.ot");
/// let device = Device::cuda_if_available(); /// let device = Device::cuda_if_available();
/// let generate_config = GenerateConfig { /// let generate_config = GenerateConfig {
/// max_length: 30, /// max_length: Some(30),
/// do_sample: true, /// do_sample: true,
/// num_beams: 5, /// num_beams: 5,
/// temperature: 1.1, /// temperature: 1.1,
@ -983,13 +983,15 @@ impl PrivateLanguageGenerator<MBartForConditionalGeneration, MBart50Vocab, MBart
&self, &self,
scores: &mut Tensor, scores: &mut Tensor,
current_length: i64, current_length: i64,
max_length: i64, max_length: Option<i64>,
forced_bos_token_id: Option<i64>, forced_bos_token_id: Option<i64>,
) { ) {
if current_length == 1 { if current_length == 1 {
self.force_token_id_generation(scores, &[forced_bos_token_id.unwrap_or(250004)]); self.force_token_id_generation(scores, &[forced_bos_token_id.unwrap_or(250004)]);
} else if current_length == max_length - 1 { } else if let Some(max_length) = max_length {
self.force_token_id_generation(scores, self.get_eos_ids().as_ref().unwrap()); if current_length == max_length - 1 {
self.force_token_id_generation(scores, self.get_eos_ids().as_ref().unwrap());
}
} }
} }
@ -1028,7 +1030,7 @@ impl PrivateLanguageGenerator<MBartForConditionalGeneration, MBart50Vocab, MBart
fn encode_prompt_text<S>( fn encode_prompt_text<S>(
&self, &self,
prompt_text: &[S], prompt_text: &[S],
max_len: i64, max_len: Option<i64>,
pad_token_id: Option<i64>, pad_token_id: Option<i64>,
) -> Tensor ) -> Tensor
where where
@ -1036,7 +1038,9 @@ impl PrivateLanguageGenerator<MBartForConditionalGeneration, MBart50Vocab, MBart
{ {
let tokens = self._get_tokenizer().encode_list( let tokens = self._get_tokenizer().encode_list(
prompt_text, prompt_text,
max_len as usize, max_len
.map(|max_len| max_len as usize)
.unwrap_or(usize::MAX),
&TruncationStrategy::LongestFirst, &TruncationStrategy::LongestFirst,
0, 0,
); );

View File

@ -457,7 +457,7 @@ impl OpenAIGenerator {
/// use rust_bert::openai_gpt::OpenAIGenerator; /// use rust_bert::openai_gpt::OpenAIGenerator;
/// use rust_bert::pipelines::generation_utils::GenerateConfig; /// use rust_bert::pipelines::generation_utils::GenerateConfig;
/// let generate_config = GenerateConfig { /// let generate_config = GenerateConfig {
/// max_length: 30, /// max_length: Some(30),
/// do_sample: true, /// do_sample: true,
/// num_beams: 5, /// num_beams: 5,
/// temperature: 1.1, /// temperature: 1.1,

View File

@ -585,7 +585,7 @@ impl PegasusConditionalGenerator {
/// # let weights_path = &home.as_path().join("model.ot"); /// # let weights_path = &home.as_path().join("model.ot");
/// let device = Device::cuda_if_available(); /// let device = Device::cuda_if_available();
/// let generate_config = GenerateConfig { /// let generate_config = GenerateConfig {
/// max_length: 30, /// max_length: Some(30),
/// do_sample: true, /// do_sample: true,
/// num_beams: 5, /// num_beams: 5,
/// temperature: 1.1, /// temperature: 1.1,
@ -710,11 +710,13 @@ impl PrivateLanguageGenerator<PegasusForConditionalGeneration, PegasusVocab, Peg
&self, &self,
scores: &mut Tensor, scores: &mut Tensor,
current_length: i64, current_length: i64,
max_length: i64, max_length: Option<i64>,
_forced_bos_token_id: Option<i64>, _forced_bos_token_id: Option<i64>,
) { ) {
if current_length == max_length - 1 { if let Some(max_length) = max_length {
self.force_token_id_generation(scores, self.get_eos_ids().as_ref().unwrap()); if current_length == max_length - 1 {
self.force_token_id_generation(scores, self.get_eos_ids().as_ref().unwrap());
}
} }
} }
@ -753,7 +755,7 @@ impl PrivateLanguageGenerator<PegasusForConditionalGeneration, PegasusVocab, Peg
fn encode_prompt_text<S>( fn encode_prompt_text<S>(
&self, &self,
prompt_text: &[S], prompt_text: &[S],
max_len: i64, max_len: Option<i64>,
pad_token_id: Option<i64>, pad_token_id: Option<i64>,
) -> Tensor ) -> Tensor
where where
@ -761,7 +763,9 @@ impl PrivateLanguageGenerator<PegasusForConditionalGeneration, PegasusVocab, Peg
{ {
let tokens = self._get_tokenizer().encode_list( let tokens = self._get_tokenizer().encode_list(
prompt_text, prompt_text,
max_len as usize, max_len
.map(|max_len| max_len as usize)
.unwrap_or(usize::MAX),
&TruncationStrategy::LongestFirst, &TruncationStrategy::LongestFirst,
0, 0,
); );

View File

@ -86,7 +86,7 @@ pub struct ConversationConfig {
/// Minimum sequence length (default: 0) /// Minimum sequence length (default: 0)
pub min_length: i64, pub min_length: i64,
/// Maximum sequence length (default: 20) /// Maximum sequence length (default: 20)
pub max_length: i64, pub max_length: Option<i64>,
/// Minimum free length available for generated responses (default: 32) /// Minimum free length available for generated responses (default: 32)
pub min_length_for_response: i64, pub min_length_for_response: i64,
/// Sampling flag. If true, will perform top-k and/or nucleus sampling on generated tokens, otherwise greedy (deterministic) decoding (default: true) /// Sampling flag. If true, will perform top-k and/or nucleus sampling on generated tokens, otherwise greedy (deterministic) decoding (default: true)
@ -135,7 +135,7 @@ impl Default for ConversationConfig {
Gpt2MergesResources::DIALOGPT_MEDIUM, Gpt2MergesResources::DIALOGPT_MEDIUM,
))), ))),
min_length: 0, min_length: 0,
max_length: 1000, max_length: Some(1000),
min_length_for_response: 64, min_length_for_response: 64,
do_sample: true, do_sample: true,
early_stopping: false, early_stopping: false,
@ -750,7 +750,7 @@ impl ConversationOption {
pub struct ConversationModel { pub struct ConversationModel {
model: ConversationOption, model: ConversationOption,
eos_token_id: i64, eos_token_id: i64,
max_allowed_context_length: i64, max_allowed_context_length: Option<i64>,
device: Device, device: Device,
} }
@ -774,8 +774,9 @@ impl ConversationModel {
pub fn new( pub fn new(
conversation_config: ConversationConfig, conversation_config: ConversationConfig,
) -> Result<ConversationModel, RustBertError> { ) -> Result<ConversationModel, RustBertError> {
let max_allowed_length = let max_allowed_length = conversation_config
conversation_config.max_length - conversation_config.min_length_for_response; .max_length
.map(|max_length| max_length - conversation_config.min_length_for_response);
let device = conversation_config.device; let device = conversation_config.device;
let model = ConversationOption::new(conversation_config)?; let model = ConversationOption::new(conversation_config)?;
let eos_token_id = model.get_eos_id()?; let eos_token_id = model.get_eos_id()?;
@ -921,17 +922,18 @@ impl ConversationModel {
let truncated_concatenated_inputs = concatenated_inputs let truncated_concatenated_inputs = concatenated_inputs
.iter() .iter()
.map(|input| { .map(|input| match self.max_allowed_context_length {
if input.len() > self.max_allowed_context_length as usize { Some(max_allowed_context_length)
if input.len() > max_allowed_context_length as usize =>
{
let start = self.get_truncated_input_index( let start = self.get_truncated_input_index(
input, input,
self.max_allowed_context_length as usize, max_allowed_context_length as usize,
pad_token, pad_token,
); );
&input[start..] &input[start..]
} else {
input.as_slice()
} }
_ => input.as_slice(),
}) })
.collect::<Vec<&[i64]>>(); .collect::<Vec<&[i64]>>();
@ -1018,7 +1020,9 @@ impl ConversationModel {
.convert_tokens_to_ids(&prompt_tokens) .convert_tokens_to_ids(&prompt_tokens)
}) })
.map(|mut tokens| { .map(|mut tokens| {
tokens.truncate(self.max_allowed_context_length as usize - 1); if let Some(max_allowed_context_length) = self.max_allowed_context_length {
tokens.truncate(max_allowed_context_length as usize - 1);
}
tokens.push(self.eos_token_id); tokens.push(self.eos_token_id);
tokens tokens
}) })
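Note: the conversation pipeline now only truncates the dialogue history when a context budget exists; `max_allowed_context_length` is `None` whenever `max_length` is `None`, and the history is then passed through untouched. A reduced sketch of that branch (the real code finds a sentence boundary via `get_truncated_input_index`; this sketch simply keeps the tail of the history):

fn truncate_history(history: &[i64], max_allowed_context_length: Option<i64>) -> &[i64] {
    match max_allowed_context_length {
        Some(limit) if history.len() > limit as usize => {
            // Keep only the most recent `limit` tokens of the conversation.
            &history[history.len() - limit as usize..]
        }
        _ => history,
    }
}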

View File

@ -107,7 +107,7 @@ pub struct GenerateConfig {
/// Minimum sequence length (default: 0) /// Minimum sequence length (default: 0)
pub min_length: i64, pub min_length: i64,
/// Maximum sequence length (default: 20) /// Maximum sequence length (default: 20)
pub max_length: i64, pub max_length: Option<i64>,
/// Sampling flag. If true, will perform top-k and/or nucleus sampling on generated tokens, otherwise greedy (deterministic) decoding (default: true) /// Sampling flag. If true, will perform top-k and/or nucleus sampling on generated tokens, otherwise greedy (deterministic) decoding (default: true)
pub do_sample: bool, pub do_sample: bool,
/// Early stopping flag indicating if the beam search should stop as soon as `num_beam` hypotheses have been generated (default: false) /// Early stopping flag indicating if the beam search should stop as soon as `num_beam` hypotheses have been generated (default: false)
@ -147,7 +147,7 @@ impl Default for GenerateConfig {
Gpt2MergesResources::GPT2, Gpt2MergesResources::GPT2,
))), ))),
min_length: 0, min_length: 0,
max_length: 20, max_length: Some(56),
do_sample: true, do_sample: true,
early_stopping: true, early_stopping: true,
num_beams: 5, num_beams: 5,
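Note: with the field now an `Option`, `GenerateConfig` can also be built with no length cap at all, relying on the model's EOS token (and on the panic guard added further down for models without one). The crate-wide default stays a concrete value, `Some(56)` in the `Default` impl above, so unbounded generation is always an explicit opt-in. A sketch:

// Rely solely on the model's EOS token instead of a hard cap:
let open_ended = GenerateConfig {
    max_length: None,
    ..Default::default()
};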
@ -234,7 +234,6 @@ pub(crate) mod private_generation_utils {
use rust_tokenizers::tokenizer::{truncate_sequences, Tokenizer, TruncationStrategy}; use rust_tokenizers::tokenizer::{truncate_sequences, Tokenizer, TruncationStrategy};
use rust_tokenizers::vocab::Vocab; use rust_tokenizers::vocab::Vocab;
use rust_tokenizers::TokenIdsWithOffsets; use rust_tokenizers::TokenIdsWithOffsets;
use tch::kind::Kind::{Bool, Float, Int64};
use tch::{nn, Device, Kind, Tensor}; use tch::{nn, Device, Kind, Tensor};
use crate::pipelines::common::TokenizerOption; use crate::pipelines::common::TokenizerOption;
@ -247,7 +246,7 @@ pub(crate) mod private_generation_utils {
pub struct InternalGenerateOptions<'a> { pub struct InternalGenerateOptions<'a> {
pub min_length: i64, pub min_length: i64,
pub max_length: i64, pub max_length: Option<i64>,
pub do_sample: bool, pub do_sample: bool,
pub temperature: f64, pub temperature: f64,
pub top_k: i64, pub top_k: i64,
@ -299,7 +298,7 @@ pub(crate) mod private_generation_utils {
&self, &self,
_scores: &mut Tensor, _scores: &mut Tensor,
_current_length: i64, _current_length: i64,
_max_length: i64, _max_length: Option<i64>,
_forced_bos_token_id: Option<i64>, _forced_bos_token_id: Option<i64>,
) { ) {
} }
@ -328,7 +327,7 @@ pub(crate) mod private_generation_utils {
fn encode_prompt_text<S>( fn encode_prompt_text<S>(
&self, &self,
prompt_text: &[S], prompt_text: &[S],
max_len: i64, max_len: Option<i64>,
pad_token_id: Option<i64>, pad_token_id: Option<i64>,
) -> Tensor ) -> Tensor
where where
@ -343,11 +342,15 @@ pub(crate) mod private_generation_utils {
let num_truncated_tokens = token_ids let num_truncated_tokens = token_ids
.iter() .iter()
.map(|token_ids| { .map(|token_ids| {
if token_ids.len() > max_len as usize { max_len
token_ids.len() - max_len as usize .map(|max_len| {
} else { if token_ids.len() > max_len as usize {
0 token_ids.len() - max_len as usize
} } else {
0
}
})
.unwrap_or(0)
}) })
.collect::<Vec<usize>>(); .collect::<Vec<usize>>();
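Note: the number of prompt tokens to drop is now computed per sequence only when a limit exists; `None` always yields zero truncation. The expression above, written as a standalone sketch (equivalent to the if/else in the diff, helper name illustrative):

fn tokens_to_truncate(prompt_len: usize, max_len: Option<i64>) -> usize {
    max_len
        .map(|max_len| prompt_len.saturating_sub(max_len as usize))
        .unwrap_or(0)
}
// tokens_to_truncate(100, Some(32)) == 68
// tokens_to_truncate(100, None) == 0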
@ -408,7 +411,7 @@ pub(crate) mod private_generation_utils {
let _ = next_token_logits.get(i).index_fill_( let _ = next_token_logits.get(i).index_fill_(
0, 0,
&Tensor::of_slice(&[token]) &Tensor::of_slice(&[token])
.to_kind(Int64) .to_kind(Kind::Int64)
.to_device(next_token_logits.device()), .to_device(next_token_logits.device()),
updated_value * repetition_penalty, updated_value * repetition_penalty,
); );
@ -416,7 +419,7 @@ pub(crate) mod private_generation_utils {
let _ = next_token_logits.get(i).index_fill_( let _ = next_token_logits.get(i).index_fill_(
0, 0,
&Tensor::of_slice(&[token]) &Tensor::of_slice(&[token])
.to_kind(Int64) .to_kind(Kind::Int64)
.to_device(next_token_logits.device()), .to_device(next_token_logits.device()),
updated_value / repetition_penalty, updated_value / repetition_penalty,
); );
@ -498,17 +501,21 @@ pub(crate) mod private_generation_utils {
.softmax(-1, sorted_logits.kind()) .softmax(-1, sorted_logits.kind())
.cumsum(-1, sorted_logits.kind()); .cumsum(-1, sorted_logits.kind());
let mut sorted_indices_to_remove = let mut sorted_indices_to_remove =
cumulative_probabilities.ge(top_p).to_kind(Int64); cumulative_probabilities.ge(top_p).to_kind(Kind::Int64);
if min_tokens_to_keep > 1 { if min_tokens_to_keep > 1 {
let _ = sorted_indices_to_remove.index_fill_( let _ = sorted_indices_to_remove.index_fill_(
1, 1,
&Tensor::arange_start(0, min_tokens_to_keep + 1, (Int64, logits.device())), &Tensor::arange_start(
0,
min_tokens_to_keep + 1,
(Kind::Int64, logits.device()),
),
0, 0,
); );
} }
let _ = sorted_indices_to_remove.index_copy_( let _ = sorted_indices_to_remove.index_copy_(
1, 1,
&Tensor::arange_start(1, vocab_size, (Int64, logits.device())), &Tensor::arange_start(1, vocab_size, (Kind::Int64, logits.device())),
&sorted_indices_to_remove &sorted_indices_to_remove
.slice(1, 0, vocab_size - 1, 1) .slice(1, 0, vocab_size - 1, 1)
.copy(), .copy(),
@ -516,13 +523,13 @@ pub(crate) mod private_generation_utils {
let _ = sorted_indices_to_remove.index_fill_( let _ = sorted_indices_to_remove.index_fill_(
1, 1,
&Tensor::of_slice(&[0]) &Tensor::of_slice(&[0])
.to_kind(Int64) .to_kind(Kind::Int64)
.to_device(sorted_indices_to_remove.device()), .to_device(sorted_indices_to_remove.device()),
0, 0,
); );
let indices_to_remove = sorted_indices_to_remove let indices_to_remove = sorted_indices_to_remove
.scatter(1, &sorted_indices, &sorted_indices_to_remove) .scatter(1, &sorted_indices, &sorted_indices_to_remove)
.to_kind(Bool); .to_kind(Kind::Bool);
let _ = logits.masked_fill_(&indices_to_remove, f64::NEG_INFINITY); let _ = logits.masked_fill_(&indices_to_remove, f64::NEG_INFINITY);
} }
} }
@ -746,10 +753,9 @@ pub(crate) mod private_generation_utils {
output_scores: bool, output_scores: bool,
) -> GeneratedOutputWithScores { ) -> GeneratedOutputWithScores {
let mut unfinished_sentences = let mut unfinished_sentences =
Tensor::ones(&[batch_size], (Int64, self.get_var_store().device())); Tensor::ones(&[batch_size], (Kind::Int64, self.get_var_store().device()));
let mut sentence_lengths: Tensor = let mut sentence_lengths: Tensor =
Tensor::ones(&[batch_size], (Int64, self.get_var_store().device())) Tensor::ones(&[batch_size], (Kind::Int64, self.get_var_store().device()));
* gen_opt.max_length as i64;
let (bad_word_ids_length_1, bad_word_ids_length_greater_than_1) = let (bad_word_ids_length_1, bad_word_ids_length_greater_than_1) =
self.split_bad_word_ids(gen_opt.bad_word_ids); self.split_bad_word_ids(gen_opt.bad_word_ids);
let mut static_bad_words_mask: Option<Tensor> = None; let mut static_bad_words_mask: Option<Tensor> = None;
@ -761,7 +767,7 @@ pub(crate) mod private_generation_utils {
let mut token_scores_output: Option<Vec<Tensor>> = let mut token_scores_output: Option<Vec<Tensor>> =
if output_scores { Some(vec![]) } else { None }; if output_scores { Some(vec![]) } else { None };
while current_length < gen_opt.max_length { loop {
let prepared_input = self.prepare_inputs_for_generation( let prepared_input = self.prepare_inputs_for_generation(
input_ids.copy(), input_ids.copy(),
encoder_outputs.as_ref(), encoder_outputs.as_ref(),
@ -902,11 +908,12 @@ pub(crate) mod private_generation_utils {
input_ids = Tensor::cat(&[input_ids, tokens_to_add.unsqueeze(-1)], -1); input_ids = Tensor::cat(&[input_ids, tokens_to_add.unsqueeze(-1)], -1);
if gen_opt.eos_token_ids.is_some() { if gen_opt.eos_token_ids.is_some() {
for eos_token_id in gen_opt.eos_token_ids.as_ref().unwrap() { for eos_token_id in gen_opt.eos_token_ids.as_ref().unwrap() {
let sentence_with_eos = tokens_to_add.eq(*eos_token_id).to_kind(Int64); let sentence_with_eos =
tokens_to_add.eq(*eos_token_id).to_kind(Kind::Int64);
let sentence_with_eos: Tensor = sentence_with_eos * &unfinished_sentences; let sentence_with_eos: Tensor = sentence_with_eos * &unfinished_sentences;
let _ = sentence_lengths.masked_fill_( let _ = sentence_lengths.masked_fill_(
&sentence_with_eos &sentence_with_eos
.to_kind(Bool) .to_kind(Kind::Bool)
.to_device(sentence_lengths.device()), .to_device(sentence_lengths.device()),
current_length as i64 + 1, current_length as i64 + 1,
); );
@ -922,7 +929,7 @@ pub(crate) mod private_generation_utils {
attention_mask.as_ref(), attention_mask.as_ref(),
Tensor::ones( Tensor::ones(
&[*attention_mask.size().first().unwrap(), 1], &[*attention_mask.size().first().unwrap(), 1],
(Int64, attention_mask.device()), (Kind::Int64, attention_mask.device()),
) )
.as_ref(), .as_ref(),
], ],
@ -930,6 +937,17 @@ pub(crate) mod private_generation_utils {
); );
} }
current_length += 1; current_length += 1;
if let Some(max_length) = gen_opt.max_length {
if current_length >= max_length {
let _ = sentence_lengths.masked_fill_(
&unfinished_sentences
.to_kind(Kind::Bool)
.to_device(sentence_lengths.device()),
current_length as i64,
);
break;
}
}
} }
let scores_output = token_scores_output.as_ref().map(|scores_tensor| { let scores_output = token_scores_output.as_ref().map(|scores_tensor| {
(Tensor::stack(scores_tensor, 1).sum_dim_intlist( (Tensor::stack(scores_tensor, 1).sum_dim_intlist(
@ -993,7 +1011,7 @@ pub(crate) mod private_generation_utils {
let vocab_size = self.get_vocab_size(); let vocab_size = self.get_vocab_size();
let beam_scores = Tensor::ones( let beam_scores = Tensor::ones(
&[batch_size, gen_opt.num_beams], &[batch_size, gen_opt.num_beams],
(Float, self.get_var_store().device()), (Kind::Float, self.get_var_store().device()),
) * -1e9; ) * -1e9;
let _ = beam_scores let _ = beam_scores
.slice(1, 0, *beam_scores.size().last().unwrap(), num_sub_beams) .slice(1, 0, *beam_scores.size().last().unwrap(), num_sub_beams)
@ -1002,11 +1020,11 @@ pub(crate) mod private_generation_utils {
let mut beam_scores = beam_scores.view_(&[-1]); let mut beam_scores = beam_scores.view_(&[-1]);
let mut beam_tokens = Tensor::zeros( let mut beam_tokens = Tensor::zeros(
&[batch_size * gen_opt.num_beams], &[batch_size * gen_opt.num_beams],
(Int64, self.get_var_store().device()), (Kind::Int64, self.get_var_store().device()),
); );
let mut beam_indices = Tensor::zeros( let mut beam_indices = Tensor::zeros(
&[batch_size * gen_opt.num_beams], &[batch_size * gen_opt.num_beams],
(Int64, self.get_var_store().device()), (Kind::Int64, self.get_var_store().device()),
); );
let mut saved_beam_scores: Option<Vec<Tensor>> = let mut saved_beam_scores: Option<Vec<Tensor>> =
if output_scores { Some(vec![]) } else { None }; if output_scores { Some(vec![]) } else { None };
@ -1019,7 +1037,7 @@ pub(crate) mod private_generation_utils {
let mut encoder_outputs = encoder_outputs; let mut encoder_outputs = encoder_outputs;
let mut current_length = cur_len; let mut current_length = cur_len;
while current_length < gen_opt.max_length { loop {
if num_beam_groups > 1 { if num_beam_groups > 1 {
current_tokens = Tensor::zeros( current_tokens = Tensor::zeros(
&[batch_size * gen_opt.num_beams], &[batch_size * gen_opt.num_beams],
@ -1209,19 +1227,19 @@ pub(crate) mod private_generation_utils {
let eos_token_ids = gen_opt.eos_token_ids.as_ref(); let eos_token_ids = gen_opt.eos_token_ids.as_ref();
let beam_ids_tensor = &next_tokens.divide_scalar_mode(vocab_size, "floor"); let beam_ids_tensor = &next_tokens.divide_scalar_mode(vocab_size, "floor");
let effective_beam_ids_tensor = (&next_tokens.ones_like().cumsum(0, Int64) - 1) let effective_beam_ids_tensor =
* group_size (&next_tokens.ones_like().cumsum(0, Kind::Int64) - 1) * group_size
+ beam_ids_tensor; + beam_ids_tensor;
let token_id_tensor = &next_tokens - beam_ids_tensor * vocab_size; let token_id_tensor = &next_tokens - beam_ids_tensor * vocab_size;
let (max_scores, _) = next_scores.max_dim(1, false); let (max_scores, _) = next_scores.max_dim(1, false);
let mut eos_mask = token_id_tensor.ones_like(); let mut eos_mask = token_id_tensor.ones_like();
if let Some(eos_token_id) = eos_token_ids { if let Some(eos_token_id) = eos_token_ids {
eos_mask -= token_id_tensor.eq(eos_token_id[0]).to_kind(Int64); eos_mask -= token_id_tensor.eq(eos_token_id[0]).to_kind(Kind::Int64);
} }
let eos_mask2 = eos_mask let eos_mask2 = eos_mask
.cumsum(1, Int64) .cumsum(1, Kind::Int64)
.le(group_size) .le(group_size)
.to_kind(Bool) .to_kind(Kind::Bool)
.logical_and(&eos_mask); .logical_and(&eos_mask);
let group_beam_scores = next_scores.masked_select(&eos_mask2); let group_beam_scores = next_scores.masked_select(&eos_mask2);
@ -1321,6 +1339,13 @@ pub(crate) mod private_generation_utils {
], ],
-1, -1,
); );
current_length += 1;
if let Some(max_length) = gen_opt.max_length {
if current_length >= max_length {
break;
}
}
encoder_outputs = self.reorder_cache(&mut past, encoder_outputs, &beam_indices); encoder_outputs = self.reorder_cache(&mut past, encoder_outputs, &beam_indices);
if !self.is_encoder_decoder() { if !self.is_encoder_decoder() {
@ -1329,15 +1354,13 @@ pub(crate) mod private_generation_utils {
attention_mask.as_ref(), attention_mask.as_ref(),
Tensor::ones( Tensor::ones(
&[*attention_mask.size().first().unwrap(), 1], &[*attention_mask.size().first().unwrap(), 1],
(Int64, attention_mask.device()), (Kind::Int64, attention_mask.device()),
) )
.as_ref(), .as_ref(),
], ],
-1, -1,
); );
} }
current_length += 1;
} }
let mut batch_index = 0i64; let mut batch_index = 0i64;
@ -1377,7 +1400,7 @@ pub(crate) mod private_generation_utils {
}; };
let mut sentence_lengths = let mut sentence_lengths =
Tensor::zeros(&[output_batch_size], (Int64, input_ids.device())); Tensor::zeros(&[output_batch_size], (Kind::Int64, input_ids.device()));
let mut best_ids = vec![]; let mut best_ids = vec![];
let mut scores_output = if output_scores { let mut scores_output = if output_scores {
@ -1421,11 +1444,14 @@ pub(crate) mod private_generation_utils {
} }
} }
} }
let sentence_max_length = let sentence_max_length = gen_opt
min(i64::from(sentence_lengths.max()) + 1, gen_opt.max_length); .max_length
.map(|max_length| min(i64::from(sentence_lengths.max()) + 1, max_length))
.unwrap_or(i64::from(sentence_lengths.max()) + 1);
let mut decoded = input_ids.new_empty( let mut decoded = input_ids.new_empty(
&[output_batch_size, sentence_max_length], &[output_batch_size, sentence_max_length],
(Int64, input_ids.device()), (Kind::Int64, input_ids.device()),
); );
if i64::from(sentence_lengths.max()) != i64::from(sentence_lengths.min()) { if i64::from(sentence_lengths.max()) != i64::from(sentence_lengths.min()) {
let _ = decoded.fill_( let _ = decoded.fill_(
@ -1440,12 +1466,15 @@ pub(crate) mod private_generation_utils {
&Tensor::arange_start( &Tensor::arange_start(
0, 0,
i64::from(sentence_lengths.get(hypothesis_index as i64)), i64::from(sentence_lengths.get(hypothesis_index as i64)),
(Int64, input_ids.device()), (Kind::Int64, input_ids.device()),
), ),
best_id, best_id,
); );
let sentence_length = i64::from(sentence_lengths.get(hypothesis_index as i64)); let sentence_length = i64::from(sentence_lengths.get(hypothesis_index as i64));
if sentence_length < gen_opt.max_length { let sentence_length_max = gen_opt
.max_length
.unwrap_or_else(|| i64::from(sentence_lengths.max()));
if sentence_length < sentence_length_max {
let _ = decoded.get(hypothesis_index as i64).index_fill_( let _ = decoded.get(hypothesis_index as i64).index_fill_(
0, 0,
&Tensor::of_slice(&[sentence_length]).to_device(input_ids.device()), &Tensor::of_slice(&[sentence_length]).to_device(input_ids.device()),
@ -1591,7 +1620,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
/// # let weights_path = &home.as_path().join("model.ot"); /// # let weights_path = &home.as_path().join("model.ot");
/// let device = Device::cuda_if_available(); /// let device = Device::cuda_if_available();
/// let generate_config = GenerateConfig { /// let generate_config = GenerateConfig {
/// max_length: 30, /// max_length: Some(30),
/// do_sample: true, /// do_sample: true,
/// num_beams: 5, /// num_beams: 5,
/// temperature: 1.1, /// temperature: 1.1,
@ -1698,7 +1727,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
/// # let weights_path = &home.as_path().join("model.ot"); /// # let weights_path = &home.as_path().join("model.ot");
/// let device = Device::cuda_if_available(); /// let device = Device::cuda_if_available();
/// let generate_config = GenerateConfig { /// let generate_config = GenerateConfig {
/// max_length: 30, /// max_length: Some(30),
/// do_sample: true, /// do_sample: true,
/// num_beams: 5, /// num_beams: 5,
/// temperature: 1.1, /// temperature: 1.1,
@ -1752,9 +1781,12 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
let eos_token_ids = self.get_eos_ids(); let eos_token_ids = self.get_eos_ids();
let config = self.get_config(); let config = self.get_config();
let max_length = unpack_config!(max_length, generate_options, config);
let max_length = generate_options.map_or(config.max_length, |generate_options| {
generate_options.max_length
});
let encoding_max_len = if self.is_encoder_decoder() { let encoding_max_len = if self.is_encoder_decoder() {
self.get_max_positions_embeddings() Some(self.get_max_positions_embeddings())
} else { } else {
max_length max_length
}; };
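Note: for encoder-decoder models the prompt is still encoded up to the model's positional-embedding limit, so an encoding cap always exists even when `max_length` is `None`; only decoder-only models inherit the (possibly absent) generation cap for prompt encoding. A reduced sketch of that selection (free function for illustration only):

fn encoding_limit(
    is_encoder_decoder: bool,
    max_position_embeddings: i64,
    max_length: Option<i64>,
) -> Option<i64> {
    if is_encoder_decoder {
        // The encoder input is bounded by the model's position embeddings.
        Some(max_position_embeddings)
    } else {
        // Decoder-only models: prompt and generated text share the same cap, if any.
        max_length
    }
}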
@ -1975,14 +2007,21 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
let max_length = if let Some(generate_options) = generate_options { let max_length = if let Some(generate_options) = generate_options {
match (generate_options.max_length, generate_options.max_new_tokens) { match (generate_options.max_length, generate_options.max_new_tokens) {
(Some(max_length), _) => max_length, (Some(max_length), _) => Some(max_length),
(None, Some(max_new_tokens)) => max_new_tokens + input_ids.size().last().unwrap(), (None, Some(max_new_tokens)) => {
Some(max_new_tokens + input_ids.size().last().unwrap())
}
(None, None) => config.max_length, (None, None) => config.max_length,
} }
} else { } else {
config.max_length config.max_length
}; };
if max_length.is_none() & eos_token_ids.is_none() {
panic!("No maximum length given for a model without an EOS token. \
This would lead to an infinite generation loop. Please provide a `max_length` or `max_new_tokens`")
}
let gen_opt = InternalGenerateOptions { let gen_opt = InternalGenerateOptions {
min_length, min_length,
max_length, max_length,
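Note: per-call `GenerateOptions` still take precedence: an explicit `max_length` wins, otherwise `max_new_tokens` is converted into an absolute cap by adding the prompt length, and only then does the config default apply. If the result is `None` and the model defines no EOS token, generation would never terminate, hence the new panic. A standalone sketch of the resolution logic shown above (hypothetical free function, mirroring the diff):

fn resolve_max_length(
    option_max_length: Option<i64>,
    option_max_new_tokens: Option<i64>,
    prompt_length: i64,
    config_max_length: Option<i64>,
    has_eos_token: bool,
) -> Option<i64> {
    let max_length = match (option_max_length, option_max_new_tokens) {
        (Some(max_length), _) => Some(max_length),
        (None, Some(max_new_tokens)) => Some(max_new_tokens + prompt_length),
        (None, None) => config_max_length,
    };
    if max_length.is_none() && !has_eos_token {
        panic!(
            "No maximum length given for a model without an EOS token. \
             This would lead to an infinite generation loop. \
             Please provide a `max_length` or `max_new_tokens`"
        );
    }
    max_length
}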
@ -2083,7 +2122,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
/// # let weights_path = &home.as_path().join("model.ot"); /// # let weights_path = &home.as_path().join("model.ot");
/// let device = Device::cuda_if_available(); /// let device = Device::cuda_if_available();
/// let generate_config = GenerateConfig { /// let generate_config = GenerateConfig {
/// max_length: 30, /// max_length: Some(30),
/// do_sample: true, /// do_sample: true,
/// num_beams: 5, /// num_beams: 5,
/// temperature: 1.1, /// temperature: 1.1,
@ -2115,7 +2154,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
#[derive(Debug)] #[derive(Debug)]
struct BeamHypotheses { struct BeamHypotheses {
max_length: i64, max_length: Option<i64>,
length_penalty: f64, length_penalty: f64,
early_stopping: bool, early_stopping: bool,
num_beams: i64, num_beams: i64,
@ -2151,12 +2190,12 @@ impl Clone for BeamHypotheses {
impl BeamHypotheses { impl BeamHypotheses {
fn new( fn new(
num_beams: i64, num_beams: i64,
max_length: i64, max_length: Option<i64>,
length_penalty: f64, length_penalty: f64,
early_stopping: bool, early_stopping: bool,
) -> BeamHypotheses { ) -> BeamHypotheses {
BeamHypotheses { BeamHypotheses {
max_length: max_length - 1, max_length: max_length.map(|max_length| max_length - 1),
length_penalty, length_penalty,
early_stopping, early_stopping,
num_beams, num_beams,
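Note: both decoding loops in this file were switched from `while current_length < gen_opt.max_length` to an unconditional `loop`, with the length check moved into an explicit `break` that only fires when a cap exists; otherwise termination relies on every sequence emitting EOS (greedy search) or every beam hypothesis being closed (beam search). `BeamHypotheses` carries the cap as an `Option` as well, storing `max_length - 1` when present, as before. The new stop condition, reduced to a sketch (hypothetical helper, not part of the crate):

fn should_stop(
    current_length: i64,
    max_length: Option<i64>,
    all_sequences_finished: bool,
) -> bool {
    // Stop when every sequence has finished, or when a configured cap is reached.
    all_sequences_finished
        || max_length.map_or(false, |max_length| current_length >= max_length)
}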

View File

@ -96,7 +96,7 @@ pub struct SummarizationConfig {
/// Minimum sequence length (default: 0) /// Minimum sequence length (default: 0)
pub min_length: i64, pub min_length: i64,
/// Maximum sequence length (default: 20) /// Maximum sequence length (default: 20)
pub max_length: i64, pub max_length: Option<i64>,
/// Sampling flag. If true, will perform top-k and/or nucleus sampling on generated tokens, otherwise greedy (deterministic) decoding (default: true) /// Sampling flag. If true, will perform top-k and/or nucleus sampling on generated tokens, otherwise greedy (deterministic) decoding (default: true)
pub do_sample: bool, pub do_sample: bool,
/// Early stopping flag indicating if the beam search should stop as soon as `num_beam` hypotheses have been generated (default: false) /// Early stopping flag indicating if the beam search should stop as soon as `num_beam` hypotheses have been generated (default: false)
@ -154,7 +154,7 @@ impl SummarizationConfig {
vocab_resource: Box::new(vocab_resource), vocab_resource: Box::new(vocab_resource),
merges_resource: merges_resource.map(|r| Box::new(r) as Box<_>), merges_resource: merges_resource.map(|r| Box::new(r) as Box<_>),
min_length: 56, min_length: 56,
max_length: 142, max_length: Some(142),
do_sample: false, do_sample: false,
early_stopping: true, early_stopping: true,
num_beams: 3, num_beams: 3,

View File

@ -66,8 +66,8 @@ pub struct TextGenerationConfig {
pub merges_resource: Option<Box<dyn ResourceProvider + Send>>, pub merges_resource: Option<Box<dyn ResourceProvider + Send>>,
/// Minimum sequence length (default: 0) /// Minimum sequence length (default: 0)
pub min_length: i64, pub min_length: i64,
/// Maximum sequence length (default: 20) /// Maximum sequence length (default: 56)
pub max_length: i64, pub max_length: Option<i64>,
/// Sampling flag. If true, will perform top-k and/or nucleus sampling on generated tokens, otherwise greedy (deterministic) decoding (default: true) /// Sampling flag. If true, will perform top-k and/or nucleus sampling on generated tokens, otherwise greedy (deterministic) decoding (default: true)
pub do_sample: bool, pub do_sample: bool,
/// Early stopping flag indicating if the beam search should stop as soon as `num_beam` hypotheses have been generated (default: false) /// Early stopping flag indicating if the beam search should stop as soon as `num_beam` hypotheses have been generated (default: false)
@ -125,7 +125,7 @@ impl TextGenerationConfig {
vocab_resource: Box::new(vocab_resource), vocab_resource: Box::new(vocab_resource),
merges_resource: merges_resource.map(|r| Box::new(r) as Box<_>), merges_resource: merges_resource.map(|r| Box::new(r) as Box<_>),
min_length: 0, min_length: 0,
max_length: 20, max_length: Some(56),
do_sample: true, do_sample: true,
early_stopping: true, early_stopping: true,
num_beams: 5, num_beams: 5,
@ -326,7 +326,7 @@ pub struct TextGenerationModel {
prefix: Option<String>, prefix: Option<String>,
prefix_length: Option<i64>, prefix_length: Option<i64>,
min_length: i64, min_length: i64,
max_length: i64, max_length: Option<i64>,
} }
impl TextGenerationModel { impl TextGenerationModel {
@ -445,7 +445,7 @@ with people, even a bishop, begging for his blessing. <eod> </s> <eos>"
self.model.generate_indices( self.model.generate_indices(
Some(&texts), Some(&texts),
Some(self.min_length + prefix_length), Some(self.min_length + prefix_length),
Some(self.max_length + prefix_length), self.max_length.map(|max_length| max_length + prefix_length),
) )
} }
_ => panic!("Prefix length not defined but prefix provided!"), _ => panic!("Prefix length not defined but prefix provided!"),
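Note: when a prefix is prepended for text generation, the length cap has to be shifted by the prefix length so that the usable output budget stays the same; with no cap there is nothing to shift. The mapping used above as a one-line sketch (helper name illustrative):

fn shifted_cap(max_length: Option<i64>, prefix_length: i64) -> Option<i64> {
    // Some(32) with a 5-token prefix becomes Some(37); None stays None.
    max_length.map(|max_length| max_length + prefix_length)
}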

View File

@ -387,8 +387,8 @@ pub struct TranslationConfig {
pub target_languages: HashSet<Language>, pub target_languages: HashSet<Language>,
/// Minimum sequence length (default: 0) /// Minimum sequence length (default: 0)
pub min_length: i64, pub min_length: i64,
/// Maximum sequence length (default: 20) /// Maximum sequence length (default: 512)
pub max_length: i64, pub max_length: Option<i64>,
/// Sampling flag. If true, will perform top-k and/or nucleus sampling on generated tokens, otherwise greedy (deterministic) decoding (default: true) /// Sampling flag. If true, will perform top-k and/or nucleus sampling on generated tokens, otherwise greedy (deterministic) decoding (default: true)
pub do_sample: bool, pub do_sample: bool,
/// Early stopping flag indicating if the beam search should stop as soon as `num_beam` hypotheses have been generated (default: false) /// Early stopping flag indicating if the beam search should stop as soon as `num_beam` hypotheses have been generated (default: false)
@ -488,7 +488,7 @@ impl TranslationConfig {
target_languages: target_languages.as_ref().iter().cloned().collect(), target_languages: target_languages.as_ref().iter().cloned().collect(),
device, device,
min_length: 0, min_length: 0,
max_length: 512, max_length: Some(512),
do_sample: false, do_sample: false,
early_stopping: true, early_stopping: true,
num_beams: 3, num_beams: 3,
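Note: the translation config keeps a concrete default (`Some(512)` above), but since these translation generators define EOS token ids (they are unwrapped in `prepare_scores_for_generation`), `max_length: None` is also a valid choice here and the panic guard in `generation_utils` will not trigger. A minimal sketch, assuming a mutable `TranslationConfig` already built via `TranslationConfig::new(...)`:

// Stop at the model's EOS token instead of a fixed cap:
translation_config.max_length = None;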

View File

@ -926,7 +926,7 @@ impl ProphetNetConditionalGenerator {
/// # let weights_path = &home.as_path().join("model.ot"); /// # let weights_path = &home.as_path().join("model.ot");
/// let device = Device::cuda_if_available(); /// let device = Device::cuda_if_available();
/// let generate_config = GenerateConfig { /// let generate_config = GenerateConfig {
/// max_length: 30, /// max_length: Some(30),
/// do_sample: true, /// do_sample: true,
/// num_beams: 5, /// num_beams: 5,
/// temperature: 1.1, /// temperature: 1.1,
@ -1075,7 +1075,7 @@ impl
fn encode_prompt_text<S>( fn encode_prompt_text<S>(
&self, &self,
prompt_text: &[S], prompt_text: &[S],
max_len: i64, max_len: Option<i64>,
pad_token_id: Option<i64>, pad_token_id: Option<i64>,
) -> Tensor ) -> Tensor
where where
@ -1083,7 +1083,9 @@ impl
{ {
let tokens = self._get_tokenizer().encode_list( let tokens = self._get_tokenizer().encode_list(
prompt_text, prompt_text,
max_len as usize, max_len
.map(|max_len| max_len as usize)
.unwrap_or(usize::MAX),
&TruncationStrategy::LongestFirst, &TruncationStrategy::LongestFirst,
0, 0,
); );

View File

@ -985,7 +985,7 @@ impl PrivateLanguageGenerator<T5ForConditionalGeneration, T5Vocab, T5Tokenizer>
fn encode_prompt_text<S>( fn encode_prompt_text<S>(
&self, &self,
prompt_text: &[S], prompt_text: &[S],
max_len: i64, max_len: Option<i64>,
pad_token_id: Option<i64>, pad_token_id: Option<i64>,
) -> Tensor ) -> Tensor
where where
@ -993,7 +993,9 @@ impl PrivateLanguageGenerator<T5ForConditionalGeneration, T5Vocab, T5Tokenizer>
{ {
let tokens = self._get_tokenizer().encode_list( let tokens = self._get_tokenizer().encode_list(
prompt_text, prompt_text,
max_len as usize, max_len
.map(|max_len| max_len as usize)
.unwrap_or(usize::MAX),
&TruncationStrategy::LongestFirst, &TruncationStrategy::LongestFirst,
0, 0,
); );

View File

@ -39,7 +39,7 @@
//! config_resource, //! config_resource,
//! vocab_resource, //! vocab_resource,
//! merges_resource: None, //! merges_resource: None,
//! max_length: 56, //! max_length: Some(56),
//! do_sample: true, //! do_sample: true,
//! num_beams: 3, //! num_beams: 3,
//! temperature: 1.0, //! temperature: 1.0,

View File

@ -1610,7 +1610,7 @@ impl XLNetGenerator {
/// use rust_bert::xlnet::XLNetGenerator; /// use rust_bert::xlnet::XLNetGenerator;
/// ///
/// let generate_config = GenerateConfig { /// let generate_config = GenerateConfig {
/// max_length: 30, /// max_length: Some(30),
/// do_sample: true, /// do_sample: true,
/// num_beams: 5, /// num_beams: 5,
/// temperature: 1.1, /// temperature: 1.1,

View File

@ -97,7 +97,7 @@ fn bart_summarization_greedy() -> anyhow::Result<()> {
num_beams: 1, num_beams: 1,
length_penalty: 1.0, length_penalty: 1.0,
min_length: 56, min_length: 56,
max_length: 142, max_length: Some(142),
device: Device::Cpu, device: Device::Cpu,
..Default::default() ..Default::default()
}; };
@ -157,7 +157,7 @@ fn bart_summarization_beam_search() -> anyhow::Result<()> {
merges_resource: Some(merges_resource), merges_resource: Some(merges_resource),
num_beams: 4, num_beams: 4,
min_length: 56, min_length: 56,
max_length: 142, max_length: Some(142),
length_penalty: 1.0, length_penalty: 1.0,
device: Device::Cpu, device: Device::Cpu,
..Default::default() ..Default::default()

View File

@ -121,7 +121,7 @@ fn gpt2_generation_greedy() -> anyhow::Result<()> {
config_resource, config_resource,
vocab_resource, vocab_resource,
merges_resource: Some(merges_resource), merges_resource: Some(merges_resource),
max_length: 40, max_length: Some(40),
do_sample: false, do_sample: false,
num_beams: 1, num_beams: 1,
temperature: 1.1, temperature: 1.1,
@ -153,7 +153,7 @@ fn gpt2_generation_beam_search() -> anyhow::Result<()> {
config_resource, config_resource,
vocab_resource, vocab_resource,
merges_resource: Some(merges_resource), merges_resource: Some(merges_resource),
max_length: 20, max_length: Some(20),
do_sample: false, do_sample: false,
num_beams: 5, num_beams: 5,
temperature: 1.2, temperature: 1.2,
@ -197,7 +197,7 @@ fn gpt2_generation_beam_search_multiple_prompts_without_padding() -> anyhow::Res
config_resource, config_resource,
vocab_resource, vocab_resource,
merges_resource: Some(merges_resource), merges_resource: Some(merges_resource),
max_length: 20, max_length: Some(20),
do_sample: false, do_sample: false,
num_beams: 5, num_beams: 5,
temperature: 1.2, temperature: 1.2,
@ -254,7 +254,7 @@ fn gpt2_generation_beam_search_multiple_prompts_with_padding() -> anyhow::Result
config_resource, config_resource,
vocab_resource, vocab_resource,
merges_resource: Some(merges_resource), merges_resource: Some(merges_resource),
max_length: 20, max_length: Some(20),
do_sample: false, do_sample: false,
num_beams: 5, num_beams: 5,
temperature: 1.2, temperature: 1.2,
@ -311,7 +311,7 @@ fn gpt2_diverse_beam_search_multiple_prompts_with_padding() -> anyhow::Result<()
vocab_resource, vocab_resource,
merges_resource: Some(merges_resource), merges_resource: Some(merges_resource),
min_length: 10, min_length: 10,
max_length: 20, max_length: Some(20),
do_sample: false, do_sample: false,
num_beams: 6, num_beams: 6,
num_beam_groups: Some(3), num_beam_groups: Some(3),
@ -378,7 +378,7 @@ fn gpt2_prefix_allowed_token_greedy() -> anyhow::Result<()> {
} }
let generate_config = GenerateConfig { let generate_config = GenerateConfig {
max_length: 56, max_length: Some(56),
model_resource, model_resource,
config_resource, config_resource,
vocab_resource, vocab_resource,
@ -428,7 +428,7 @@ fn gpt2_bad_tokens_greedy() -> anyhow::Result<()> {
let model_resource = Box::new(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2)); let model_resource = Box::new(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
let generate_config = GenerateConfig { let generate_config = GenerateConfig {
max_length: 36, max_length: Some(36),
model_resource, model_resource,
config_resource, config_resource,
vocab_resource, vocab_resource,
@ -494,7 +494,7 @@ fn gpt2_bad_tokens_beam_search() -> anyhow::Result<()> {
let model_resource = Box::new(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2)); let model_resource = Box::new(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
let generate_config = GenerateConfig { let generate_config = GenerateConfig {
max_length: 36, max_length: Some(36),
model_resource, model_resource,
config_resource, config_resource,
vocab_resource, vocab_resource,
@ -575,7 +575,7 @@ fn gpt2_prefix_allowed_token_beam_search() -> anyhow::Result<()> {
} }
let generate_config = GenerateConfig { let generate_config = GenerateConfig {
max_length: 32, max_length: Some(32),
model_resource, model_resource,
config_resource, config_resource,
vocab_resource, vocab_resource,
@ -625,7 +625,7 @@ fn gpt2_greedy_token_scores() -> anyhow::Result<()> {
let model_resource = Box::new(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2)); let model_resource = Box::new(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
let generate_config = GenerateConfig { let generate_config = GenerateConfig {
max_length: 16, max_length: Some(16),
model_resource, model_resource,
config_resource, config_resource,
vocab_resource, vocab_resource,
@ -681,7 +681,7 @@ fn gpt2_beam_search_token_scores() -> anyhow::Result<()> {
let model_resource = Box::new(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2)); let model_resource = Box::new(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
let generate_config = GenerateConfig { let generate_config = GenerateConfig {
max_length: 16, max_length: Some(16),
model_resource, model_resource,
config_resource, config_resource,
vocab_resource, vocab_resource,
@ -812,7 +812,7 @@ fn dialogpt_multiple_multi_turn_conversation() -> anyhow::Result<()> {
fn dialogpt_multiple_multi_turn_conversation_with_truncation() -> anyhow::Result<()> { fn dialogpt_multiple_multi_turn_conversation_with_truncation() -> anyhow::Result<()> {
// Set-up conversation model // Set-up conversation model
let conversation_config = ConversationConfig { let conversation_config = ConversationConfig {
max_length: 36, max_length: Some(36),
min_length_for_response: 24, min_length_for_response: 24,
do_sample: false, do_sample: false,
device: Device::Cpu, device: Device::Cpu,
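Note: for the truncation test above, the budget arithmetic from `ConversationModel::new` gives `max_allowed_context_length = Some(36 - 24) = Some(12)`, so conversation histories longer than 12 tokens are truncated before generation. A worked check of that computation (as it would appear in a test or `main`):

let max_length: Option<i64> = Some(36);
let min_length_for_response: i64 = 24;
let max_allowed_context_length =
    max_length.map(|max_length| max_length - min_length_for_response);
assert_eq!(max_allowed_context_length, Some(12));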

View File

@ -130,7 +130,7 @@ fn test_generation_gpt_neo() -> anyhow::Result<()> {
vocab_resource, vocab_resource,
merges_resource: Some(merges_resource), merges_resource: Some(merges_resource),
min_length: 10, min_length: 10,
max_length: 32, max_length: Some(32),
do_sample: false, do_sample: false,
early_stopping: true, early_stopping: true,
num_beams: 4, num_beams: 4,

View File

@ -123,7 +123,7 @@ fn openai_gpt_generation_greedy() -> anyhow::Result<()> {
config_resource, config_resource,
vocab_resource, vocab_resource,
merges_resource: Some(merges_resource), merges_resource: Some(merges_resource),
max_length: 40, max_length: Some(40),
do_sample: false, do_sample: false,
num_beams: 1, num_beams: 1,
top_p: 1.0, top_p: 1.0,
@ -165,7 +165,7 @@ fn openai_gpt_generation_beam_search() -> anyhow::Result<()> {
config_resource, config_resource,
vocab_resource, vocab_resource,
merges_resource: Some(merges_resource), merges_resource: Some(merges_resource),
max_length: 20, max_length: Some(20),
do_sample: false, do_sample: false,
early_stopping: true, early_stopping: true,
num_beams: 5, num_beams: 5,
@ -218,7 +218,7 @@ fn openai_gpt_generation_beam_search_multiple_prompts_without_padding() -> anyho
config_resource, config_resource,
vocab_resource, vocab_resource,
merges_resource: Some(merges_resource), merges_resource: Some(merges_resource),
max_length: 20, max_length: Some(20),
do_sample: false, do_sample: false,
early_stopping: true, early_stopping: true,
num_beams: 5, num_beams: 5,
@ -287,7 +287,7 @@ fn openai_gpt_generation_beam_search_multiple_prompts_with_padding() -> anyhow::
config_resource, config_resource,
vocab_resource, vocab_resource,
merges_resource: Some(merges_resource), merges_resource: Some(merges_resource),
max_length: 20, max_length: Some(20),
do_sample: false, do_sample: false,
num_beams: 5, num_beams: 5,
temperature: 2.0, temperature: 2.0,

View File

@ -50,7 +50,7 @@ fn test_generation_reformer() -> anyhow::Result<()> {
vocab_resource, vocab_resource,
merges_resource: None, merges_resource: None,
min_length: 100, min_length: 100,
max_length: 100, max_length: Some(100),
do_sample: false, do_sample: false,
early_stopping: true, early_stopping: true,
no_repeat_ngram_size: 3, no_repeat_ngram_size: 3,

View File

@ -70,7 +70,7 @@ fn test_summarization_t5() -> anyhow::Result<()> {
vocab_resource: Box::new(RemoteResource::from_pretrained(T5VocabResources::T5_SMALL)), vocab_resource: Box::new(RemoteResource::from_pretrained(T5VocabResources::T5_SMALL)),
merges_resource: None, merges_resource: None,
min_length: 30, min_length: 30,
max_length: 200, max_length: Some(200),
early_stopping: true, early_stopping: true,
num_beams: 4, num_beams: 4,
length_penalty: 2.0, length_penalty: 2.0,

View File

@ -212,7 +212,7 @@ fn xlnet_generation_beam_search() -> anyhow::Result<()> {
config_resource, config_resource,
vocab_resource, vocab_resource,
merges_resource: None, merges_resource: None,
max_length: 32, max_length: Some(32),
do_sample: false, do_sample: false,
num_beams: 3, num_beams: 3,
temperature: 1.0, temperature: 1.0,