Updated tests and docstrings

This commit is contained in:
Guillaume B 2021-06-03 10:17:52 +02:00
parent 6dd2510e30
commit d401fea891
3 changed files with 126 additions and 1 deletions

View File

@ -41,7 +41,7 @@ fn main() -> anyhow::Result<()> {
// Define input
let input = ["translate English to German: This sentence will get translated to German"];
let output = t5_model.generate(Some(input.to_vec()), None, None, None, None);
let output = t5_model.generate(Some(input.to_vec()), None, None, None, None, None);
println!("{:?}", output);
Ok(())

View File

@ -45,6 +45,7 @@
//! min_length,
//! max_length,
//! decoder_start_id,
//! None,
//! );
//! # Ok(())
//! # }
@ -1183,6 +1184,10 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
///
/// * `prompt_texts` - `Option<Vec<&str>>` Optional vector of text prompts. An empty prompt to the model may be passed if the model implements a `bos_id`.
/// * `attention_mask` - `Option<Tensor>` Optional attention mask to hide portions of the prompt.
/// * `min_length` - `impl Into<Option<i64>>` Optional minimum output sequence length
/// * `max_length` - `impl Into<Option<i64>>` Optional maximum output sequence length
/// * `decoder_start_token_id` - `impl Into<Option<i64>>` Optional decoder start token id
/// * `prefix_allowed_tokens_fn` - `Option<&dyn Fn(i64, &Tensor) -> Vec<i64>>` Optional function to control the generation process. The function should take a `batch_id` (i64) and a tensor of token_ids already generated and return a `Vec<i64>` of allowed tokens.
///
/// # Returns
/// * `Vec<String>` Vector of generated strings based on the prompts of length *number_of_prompts* x *num_return_sequences*.
@ -1195,6 +1200,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
/// # fn main() -> anyhow::Result<()> {
/// use rust_bert::gpt2::GPT2Generator;
/// use rust_bert::pipelines::generation_utils::{GenerateConfig, LanguageGenerator};
/// use tch::Tensor;
/// # let mut home: PathBuf = dirs::home_dir().unwrap();
/// # home.push("rustbert");
/// # home.push("gpt2");
@ -1220,12 +1226,30 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
/// let max_length = 128;
/// let decoder_start_token_id = None;
///
/// //Example custom function for fine-grained generation control
/// fn force_one_paragraph(_batch_id: i64, previous_token_ids: &Tensor) -> Vec<i64> {
/// let paragraph_tokens = [198, 628];
///
/// for paragraph_token in paragraph_tokens.iter() {
/// if previous_token_ids
/// .iter::<i64>()
/// .unwrap()
/// .collect::<Vec<i64>>()
/// .contains(paragraph_token)
/// {
/// return vec![50256];
/// }
/// }
/// (0..50255).collect()
/// }
///
/// let output = gpt2_generator.generate(
/// Some(vec![input_context, second_input_context]),
/// attention_mask,
/// min_length,
/// max_length,
/// decoder_start_token_id,
/// Some(&force_one_paragraph)
/// );
/// # Ok(())
/// # }
@ -1276,6 +1300,10 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
///
/// * `prompt_texts` - `Option<Vec<&str>>` Optional vector of text prompts. An empty prompt to the model may be passed if the model implements a `bos_id`.
/// * `attention_mask` - `Option<Tensor>` Optional attention mask to hide portions of the prompt.
/// * `min_length` - `impl Into<Option<i64>>` Optional minimum output sequence length
/// * `max_length` - `impl Into<Option<i64>>` Optional maximum output sequence length
/// * `decoder_start_token_id` - `impl Into<Option<i64>>` Optional decoder start token id
/// * `prefix_allowed_tokens_fn` - `Option<&dyn Fn(i64, &Tensor) -> Vec<i64>>` Optional function to control the generation process. The function should take a `batch_id` (i64) and a tensor of token_ids already generated and return a `Vec<i64>` of allowed tokens.
///
/// # Returns
/// * `Vec<Vec<i64>>` Vector of Vector of generated token indices based on the prompts of length *number_of_prompts* x *num_return_sequences*.
@ -1288,6 +1316,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
/// # fn main() -> anyhow::Result<()> {
/// use rust_bert::gpt2::GPT2Generator;
/// use rust_bert::pipelines::generation_utils::{GenerateConfig, LanguageGenerator};
/// use tch::Tensor;
/// # let mut home: PathBuf = dirs::home_dir().unwrap();
/// # home.push("rustbert");
/// # home.push("gpt2");
@ -1312,12 +1341,30 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
/// let max_length = 128;
/// let decoder_start_token_id = None;
///
/// //Example custom function for fine-grained generation control
/// fn force_one_paragraph(_batch_id: i64, previous_token_ids: &Tensor) -> Vec<i64> {
/// let paragraph_tokens = [198, 628];
///
/// for paragraph_token in paragraph_tokens.iter() {
/// if previous_token_ids
/// .iter::<i64>()
/// .unwrap()
/// .collect::<Vec<i64>>()
/// .contains(paragraph_token)
/// {
/// return vec![50256];
/// }
/// }
/// (0..50255).collect()
/// }
///
/// let output = gpt2_generator.generate_indices(
/// Some(vec![input_context, second_input_context]),
/// attention_mask,
/// min_length,
/// max_length,
/// decoder_start_token_id,
/// Some(&force_one_paragraph),
/// );
/// # Ok(())
/// # }
@ -1369,6 +1416,82 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
)
}
/// Generate token indices given a list of indices (useful when the input has been pre-tokenized).
/// Returns a list of output tokens that need to be decoded using a tokenizer.
///
/// # Arguments
///
/// * `input_ids` - `Tensor` pre-tokenized and encoded input for generation.
/// * `attention_mask` - `Option<Tensor>` Optional attention mask to hide portions of the prompt.
/// * `min_length` - `impl Into<Option<i64>>` Optional minimum output sequence length
/// * `max_length` - `impl Into<Option<i64>>` Optional maximum output sequence length
/// * `decoder_start_token_id` - `impl Into<Option<i64>>` Optional decoder start token id
/// * `prefix_allowed_tokens_fn` - `Option<&dyn Fn(i64, &Tensor) -> Vec<i64>>` Optional function to control the generation process. The function should take a `batch_id` (i64) and a tensor of token_ids already generated and return a `Vec<i64>` of allowed tokens.
///
/// # Returns
/// * `Vec<Vec<i64>>` Vector of Vector of generated token indices based on the prompts of length *number_of_prompts* x *num_return_sequences*.
///
/// # Example
///
/// ```no_run
/// # use std::path::PathBuf;
/// # use tch::Device;
/// # fn main() -> anyhow::Result<()> {
/// use rust_bert::gpt2::GPT2Generator;
/// use rust_bert::pipelines::generation_utils::{GenerateConfig, LanguageGenerator};
/// use tch::Tensor;
/// # let mut home: PathBuf = dirs::home_dir().unwrap();
/// # home.push("rustbert");
/// # home.push("gpt2");
/// # let config_path = &home.as_path().join("config.json");
/// # let vocab_path = &home.as_path().join("vocab.txt");
/// # let merges_path = &home.as_path().join("merges.txt");
/// # let weights_path = &home.as_path().join("model.ot");
/// let device = Device::cuda_if_available();
/// let generate_config = GenerateConfig {
/// max_length: 30,
/// do_sample: true,
/// num_beams: 5,
/// temperature: 1.1,
/// num_return_sequences: 3,
/// ..Default::default()
/// };
/// let mut gpt2_generator = GPT2Generator::new(generate_config)?;
/// let input_context = "The dog";
/// let second_input_context = "The cat was";
/// let attention_mask = None;
/// let min_length = 32;
/// let max_length = 128;
/// let decoder_start_token_id = None;
///
/// //Example custom function for fine-grained generation control
/// fn force_one_paragraph(_batch_id: i64, previous_token_ids: &Tensor) -> Vec<i64> {
/// let paragraph_tokens = [198, 628];
///
/// for paragraph_token in paragraph_tokens.iter() {
/// if previous_token_ids
/// .iter::<i64>()
/// .unwrap()
/// .collect::<Vec<i64>>()
/// .contains(paragraph_token)
/// {
/// return vec![50256];
/// }
/// }
/// (0..50255).collect()
/// }
///
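/// // NOTE(review): this example calls `generate_indices` with text prompts, whereas the
/// // documented arguments describe a pre-tokenized `input_ids` tensor - TODO confirm and
/// // replace with a call to `generate_from_ids_and_past`.
/// let output = gpt2_generator.generate_indices(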
/// Some(vec![input_context, second_input_context]),
/// attention_mask,
/// min_length,
/// max_length,
/// decoder_start_token_id,
/// Some(&force_one_paragraph),
/// );
/// # Ok(())
/// # }
/// ```
fn generate_from_ids_and_past(
&self,
input_ids: Tensor,

View File

@ -414,6 +414,7 @@ fn gpt2_prefix_allowed_token_greedy() -> anyhow::Result<()> {
merges_resource,
do_sample: false,
num_beams: 1,
device: Device::Cpu,
..Default::default()
};
let model = GPT2Generator::new(generate_config)?;
@ -478,6 +479,7 @@ fn gpt2_prefix_allowed_token_beam_search() -> anyhow::Result<()> {
merges_resource,
do_sample: false,
num_beams: 3,
device: Device::Cpu,
..Default::default()
};
let model = GPT2Generator::new(generate_config)?;