Make generics less generic. (#189)

* Make generics less generic.

Fix examples, tests and docs.

* Address outstanding issues

* Take less ownership where possible

* Fix up some clippy warnings

* Update tokenizer crate version

Co-authored-by: Guillaume Becquin <guillaume.becquin@gmail.com>
Authored by sftse on 2021-11-07 09:42:56 +01:00, committed by GitHub
parent 1a56594483
commit e297f395af
31 changed files with 147 additions and 158 deletions
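The heart of the change: public APIs previously took `S: AsRef<[&'a str]>`, which pins the element type to `&str` and forces callers holding owned `String`s to re-collect into a `Vec<&str>`. The new signatures take a plain borrowed slice `&[S]` with `S: AsRef<str>`. A minimal sketch of the before/after pattern (the function names here are illustrative, not items from the crate):

// Before: elements must literally be `&str`.
fn old_api<'a, S>(input: S)
where
    S: AsRef<[&'a str]>,
{
    for text in input.as_ref() {
        println!("{}", text);
    }
}

// After: one borrowed slice of anything string-like; `&[&str]`,
// `&[String]`, and `&Vec<String>` (via deref coercion) all work.
fn new_api<S>(input: &[S])
where
    S: AsRef<str>,
{
    for text in input {
        println!("{}", text.as_ref());
    }
}

fn main() {
    old_api(["a", "b"]); // ok, but owned Strings would need re-collecting
    let owned = vec!["a".to_string(), "b".to_string()];
    new_api(&owned); // no intermediate Vec<&str>
    new_api(&["a", "b"]);
}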

View File

@@ -57,7 +57,7 @@ all-tests = []
 features = ["doc-only"]
 
 [dependencies]
-rust_tokenizers = "~6.2.5"
+rust_tokenizers = "~7.0.0"
 tch = "~0.6.1"
 serde_json = "1.0.68"
 serde = { version = "1.0.130", features = ["derive"] }

View File

@@ -53,9 +53,9 @@ fn main() -> anyhow::Result<()> {
     let source_sentence = "This sentence will be translated in multiple languages.";
 
     let mut outputs = Vec::new();
-    outputs.extend(model.translate([source_sentence], Language::English, Language::French)?);
-    outputs.extend(model.translate([source_sentence], Language::English, Language::Spanish)?);
-    outputs.extend(model.translate([source_sentence], Language::English, Language::Hindi)?);
+    outputs.extend(model.translate(&[source_sentence], Language::English, Language::French)?);
+    outputs.extend(model.translate(&[source_sentence], Language::English, Language::Spanish)?);
+    outputs.extend(model.translate(&[source_sentence], Language::English, Language::Hindi)?);
 
     for sentence in outputs {
         println!("{}", sentence);

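The call sites above change from `[source_sentence]` (an array passed by value) to `&[source_sentence]` because `translate` now expects a borrowed slice. A standalone sketch of what the new shape accepts (a stand-in function with the same bounds as the updated `translate`, not the rust-bert item itself):

// Stand-in mirroring `texts: &[S] where S: AsRef<str> + Sync`.
fn translate<S: AsRef<str> + Sync>(texts: &[S]) -> Vec<String> {
    texts.iter().map(|t| t.as_ref().to_uppercase()).collect()
}

fn main() {
    // Array literals need an explicit borrow now...
    let a = translate(&["This sentence will be translated."]);
    // ...and owned Strings pass straight through, with no re-collecting.
    let owned = vec![String::from("Another sentence.")];
    let b = translate(&owned);
    println!("{:?} {:?}", a, b);
}
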
View File

@@ -53,9 +53,9 @@ fn main() -> anyhow::Result<()> {
     let source_sentence = "This sentence will be translated in multiple languages.";
 
     let mut outputs = Vec::new();
-    outputs.extend(model.translate([source_sentence], Language::English, Language::French)?);
-    outputs.extend(model.translate([source_sentence], Language::English, Language::Spanish)?);
-    outputs.extend(model.translate([source_sentence], Language::English, Language::Hindi)?);
+    outputs.extend(model.translate(&[source_sentence], Language::English, Language::French)?);
+    outputs.extend(model.translate(&[source_sentence], Language::English, Language::Spanish)?);
+    outputs.extend(model.translate(&[source_sentence], Language::English, Language::Hindi)?);
 
     for sentence in outputs {
         println!("{}", sentence);

View File

@@ -56,9 +56,9 @@ fn main() -> anyhow::Result<()> {
     let source_sentence = "This sentence will be translated in multiple languages.";
 
     let mut outputs = Vec::new();
-    outputs.extend(model.translate([source_sentence], Language::English, Language::French)?);
-    outputs.extend(model.translate([source_sentence], Language::English, Language::German)?);
-    outputs.extend(model.translate([source_sentence], Language::English, Language::Romanian)?);
+    outputs.extend(model.translate(&[source_sentence], Language::English, Language::French)?);
+    outputs.extend(model.translate(&[source_sentence], Language::English, Language::German)?);
+    outputs.extend(model.translate(&[source_sentence], Language::English, Language::Romanian)?);
 
     for sentence in outputs {
         println!("{}", sentence);

View File

@@ -1195,17 +1195,17 @@ impl PrivateLanguageGenerator<BartForConditionalGeneration, RobertaVocab, Robert
         }
     }
 
-    fn encode_prompt_text<'a, S>(
+    fn encode_prompt_text<S>(
         &self,
-        prompt_text: S,
+        prompt_text: &[S],
         max_len: i64,
         pad_token_id: Option<i64>,
     ) -> Tensor
     where
-        S: AsRef<[&'a str]>,
+        S: AsRef<str> + Sync,
    {
         let tokens = self._get_tokenizer().encode_list(
-            prompt_text.as_ref(),
+            prompt_text,
             max_len as usize,
             &TruncationStrategy::LongestFirst,
             0,

View File

@@ -798,17 +798,17 @@ impl PrivateLanguageGenerator<M2M100ForConditionalGeneration, M2M100Vocab, M2M10
         }
     }
 
-    fn encode_prompt_text<'a, S>(
+    fn encode_prompt_text<S>(
         &self,
-        prompt_text: S,
+        prompt_text: &[S],
         max_len: i64,
         pad_token_id: Option<i64>,
     ) -> Tensor
     where
-        S: AsRef<[&'a str]>,
+        S: AsRef<str> + Sync,
     {
         let tokens = self._get_tokenizer().encode_list(
-            prompt_text.as_ref(),
+            prompt_text,
             max_len as usize,
             &TruncationStrategy::LongestFirst,
             0,

View File

@@ -976,17 +976,17 @@ impl PrivateLanguageGenerator<MarianForConditionalGeneration, MarianVocab, Maria
         }
     }
 
-    fn encode_prompt_text<'a, T>(
+    fn encode_prompt_text<S>(
         &self,
-        prompt_text: T,
+        prompt_text: &[S],
         max_len: i64,
         pad_token_id: Option<i64>,
     ) -> Tensor
     where
-        T: AsRef<[&'a str]>,
+        S: AsRef<str> + Sync,
     {
         let tokens = self._get_tokenizer().encode_list(
-            prompt_text.as_ref(),
+            prompt_text,
             max_len as usize,
             &TruncationStrategy::LongestFirst,
             0,

View File

@@ -1008,17 +1008,17 @@ impl PrivateLanguageGenerator<MBartForConditionalGeneration, MBart50Vocab, MBart
         }
     }
 
-    fn encode_prompt_text<'a, S>(
+    fn encode_prompt_text<S>(
         &self,
-        prompt_text: S,
+        prompt_text: &[S],
         max_len: i64,
         pad_token_id: Option<i64>,
     ) -> Tensor
     where
-        S: AsRef<[&'a str]>,
+        S: AsRef<str> + Sync,
     {
         let tokens = self._get_tokenizer().encode_list(
-            prompt_text.as_ref(),
+            prompt_text,
             max_len as usize,
             &TruncationStrategy::LongestFirst,
             0,

View File

@@ -766,17 +766,17 @@ impl PrivateLanguageGenerator<PegasusForConditionalGeneration, PegasusVocab, Peg
         }
     }
 
-    fn encode_prompt_text<'a, S>(
+    fn encode_prompt_text<S>(
         &self,
-        prompt_text: S,
+        prompt_text: &[S],
         max_len: i64,
         pad_token_id: Option<i64>,
     ) -> Tensor
     where
-        S: AsRef<[&'a str]>,
+        S: AsRef<str> + Sync,
     {
         let tokens = self._get_tokenizer().encode_list(
-            prompt_text.as_ref(),
+            prompt_text,
             max_len as usize,
             &TruncationStrategy::LongestFirst,
             0,

View File

@@ -494,13 +494,16 @@ impl TokenizerOption {
     }
 
     /// Interface method
-    pub fn encode_list(
+    pub fn encode_list<S>(
         &self,
-        text_list: &[&str],
+        text_list: &[S],
         max_len: usize,
         truncation_strategy: &TruncationStrategy,
         stride: usize,
-    ) -> Vec<TokenizedInput> {
+    ) -> Vec<TokenizedInput>
+    where
+        S: AsRef<str> + Sync,
+    {
         match *self {
             Self::Bert(ref tokenizer) => MultiThreadedTokenizer::encode_list(
                 tokenizer,
@@ -809,7 +812,10 @@ impl TokenizerOption {
     }
 
     /// Interface method to tokenization
-    pub fn tokenize_list(&self, text: &[&str]) -> Vec<Vec<String>> {
+    pub fn tokenize_list<S>(&self, text: &[S]) -> Vec<Vec<String>>
+    where
+        S: AsRef<str> + Sync,
+    {
         match *self {
             Self::Bert(ref tokenizer) => MultiThreadedTokenizer::tokenize_list(tokenizer, text),
             Self::Roberta(ref tokenizer) => MultiThreadedTokenizer::tokenize_list(tokenizer, text),
@@ -837,7 +843,7 @@ impl TokenizerOption {
     /// Interface method to decoding
     pub fn decode(
         &self,
-        token_ids: Vec<i64>,
+        token_ids: &[i64],
         skip_special_tokens: bool,
         clean_up_tokenization_spaces: bool,
     ) -> String {
@@ -964,10 +970,9 @@ impl TokenizerOption {
     }
 
     /// Interface method to convert tokens to ids
-    pub fn convert_tokens_to_ids<S, ST>(&self, tokens: S) -> Vec<i64>
+    pub fn convert_tokens_to_ids<S>(&self, tokens: &[S]) -> Vec<i64>
     where
-        S: AsRef<[ST]>,
-        ST: AsRef<str>,
+        S: AsRef<str>,
     {
         match *self {
             Self::Bert(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),

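The `Sync` bound on these interface methods exists because the multi-threaded tokenizer variants fan the slice out across threads, so shared references to the elements must be usable from several threads at once. A self-contained sketch of that constraint using rayon (an assumption about the parallelism style; rust_tokenizers' `MultiThreadedTokenizer` imposes the equivalent requirement):

use rayon::prelude::*;

// `par_iter()` over `&[S]` requires `S: Sync`, because `&S` is shared
// across worker threads. Whitespace splitting stands in for real
// tokenization here.
fn tokenize_list<S>(texts: &[S]) -> Vec<Vec<String>>
where
    S: AsRef<str> + Sync,
{
    texts
        .par_iter()
        .map(|t| t.as_ref().split_whitespace().map(String::from).collect())
        .collect()
}

fn main() {
    let out = tokenize_list(&["a b c", "d e"]);
    assert_eq!(out[1], vec!["d".to_string(), "e".to_string()]);
}
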
View File

@@ -416,22 +416,20 @@ impl Conversation {
     /// let _ = conversation_manager
     ///     .get(&conversation_1_id)
     ///     .unwrap()
-    ///     .load_from_history(history, encoded_history);
+    ///     .load_from_history(&history, &encoded_history);
     /// # Ok(())
     /// # }
     /// ```
-    pub fn load_from_history<ST, SI, STR, SIN>(&mut self, texts: ST, ids: SI)
+    pub fn load_from_history<S, SI>(&mut self, texts: &[S], ids: &[SI])
     where
-        ST: AsRef<[STR]>,
-        SI: AsRef<[SIN]>,
-        STR: AsRef<str>,
-        SIN: AsRef<[i64]>,
+        S: AsRef<str>,
+        SI: AsRef<[i64]>,
     {
-        for (round_text, round_ids) in texts.as_ref().iter().zip(ids.as_ref().iter()) {
+        for (round_text, round_ids) in texts.iter().zip(ids.iter()) {
            self.append(round_text.as_ref(), round_ids.as_ref());
         }
 
-        if texts.as_ref().len() / 2 == 1 {
+        if texts.len() / 2 == 1 {
             self.history.pop();
         }
     }
@@ -840,11 +838,11 @@ impl ConversationModel {
             let generated_response = &generated_sequence[input_length - removed_padding.0..];
             conversation
                 .generated_responses
-                .push(self.model.get_tokenizer().decode(
-                    generated_response.to_vec(),
-                    true,
-                    true,
-                ));
+                .push(
+                    self.model
+                        .get_tokenizer()
+                        .decode(generated_response, true, true),
+                );
             conversation.history.push(conversation_promp_ids);
             conversation.history.push(generated_response.to_vec());
             conversation.mark_processed();

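Switching `decode` to `&[i64]` is the "take less ownership" theme in miniature: the generated response is already a sub-slice of the full sequence, so it can be decoded without the `to_vec()` copy the old owning signature forced. A toy illustration (the `decode` body is a placeholder; only the signature shape matters):

// Borrowing signature: any sub-slice decodes with no extra allocation.
fn decode(token_ids: &[i64]) -> String {
    token_ids
        .iter()
        .map(|id| id.to_string())
        .collect::<Vec<_>>()
        .join(" ")
}

fn main() {
    let generated_sequence = vec![5i64, 7, 9, 11];
    let input_length = 1;
    // Old: decode(generated_sequence[input_length..].to_vec()) copied first.
    let response = decode(&generated_sequence[input_length..]);
    println!("{}", response);
}
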
View File

@@ -41,7 +41,7 @@
 //! let input_context = "The dog";
 //! let second_input_context = "The cat was";
 //! let output = gpt2_generator.generate(
-//!     Some(vec![input_context, second_input_context]),
+//!     Some(&[input_context, second_input_context]),
 //!     None,
 //!     min_length,
 //!     max_length,
@@ -319,16 +319,16 @@ pub(crate) mod private_generation_utils {
         }
     }
 
-    fn encode_prompt_text<'a, S>(
+    fn encode_prompt_text<S>(
         &self,
-        prompt_text: S,
+        prompt_text: &[S],
         max_len: i64,
         pad_token_id: Option<i64>,
     ) -> Tensor
     where
-        S: AsRef<[&'a str]>,
+        S: AsRef<str> + Sync,
     {
-        let tokens = self._get_tokenizer().tokenize_list(prompt_text.as_ref());
+        let tokens = self._get_tokenizer().tokenize_list(prompt_text);
         let token_ids = tokens
             .into_iter()
             .map(|prompt_tokens| self._get_tokenizer().convert_tokens_to_ids(&prompt_tokens))
@@ -1500,7 +1500,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
     /// }
     ///
     /// let output = gpt2_generator.generate(
-    ///     Some(vec![input_context, second_input_context]),
+    ///     Some(&[input_context, second_input_context]),
     ///     attention_mask,
     ///     min_length,
     ///     max_length,
@@ -1526,9 +1526,9 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
     /// ]
     /// # ;
     /// ```
-    fn generate<'a, S>(
+    fn generate<S>(
         &self,
-        prompt_texts: Option<S>,
+        prompt_texts: Option<&[S]>,
         attention_mask: Option<Tensor>,
         min_length: impl Into<Option<i64>>,
         max_length: impl Into<Option<i64>>,
@@ -1539,7 +1539,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
         output_scores: bool,
     ) -> Vec<GeneratedTextOutput>
     where
-        S: AsRef<[&'a str]>,
+        S: AsRef<str> + Sync,
     {
         let indices_outputs = self.generate_indices(
             prompt_texts,
@@ -1557,7 +1557,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
             output.push(GeneratedTextOutput {
                 text: self
                     ._get_tokenizer()
-                    .decode(generated_sequence.indices, true, true),
+                    .decode(&generated_sequence.indices, true, true),
                 score: generated_sequence.score,
             });
         }
@@ -1632,7 +1632,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
     /// }
     ///
     /// let output = gpt2_generator.generate_indices(
-    ///     Some(vec![input_context, second_input_context]),
+    ///     Some(&[input_context, second_input_context]),
     ///     attention_mask,
     ///     min_length,
     ///     max_length,
@@ -1645,9 +1645,9 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
     /// # Ok(())
     /// # }
     /// ```
-    fn generate_indices<'a, S>(
+    fn generate_indices<S>(
         &self,
-        prompt_texts: Option<S>,
+        prompt_texts: Option<&[S]>,
         attention_mask: Option<Tensor>,
         min_length: impl Into<Option<i64>>,
         max_length: impl Into<Option<i64>>,
@@ -1658,7 +1658,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
         output_scores: bool,
     ) -> Vec<GeneratedIndicesOutput>
     where
-        S: AsRef<[&'a str]>,
+        S: AsRef<str> + Sync,
     {
         let eos_token_ids = PrivateLanguageGenerator::get_eos_ids(self).clone();
@@ -1767,7 +1767,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
     /// }
     ///
     /// let output = gpt2_generator.generate_indices(
-    ///     Some(vec![input_context, second_input_context]),
+    ///     Some(&[input_context, second_input_context]),
     ///     attention_mask,
     ///     min_length,
     ///     max_length,

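`prompt_texts` moves from `Option<S>` to `Option<&[S]>`, so callers keep ownership of their prompt collection, and `None` still works with a turbofish to pin the otherwise-unused type parameter. A sketch of the call-site difference (stand-in function, not the trait method):

fn generate<S: AsRef<str> + Sync>(prompt_texts: Option<&[S]>) -> usize {
    prompt_texts.map_or(0, |prompts| prompts.len())
}

fn main() {
    let prompts = vec!["The dog", "The cat was"];
    // Old style: Some(vec![...]) handed the Vec over; new style borrows.
    assert_eq!(generate(Some(&prompts)), 2);
    // With no prompts the element type is unconstrained, hence the turbofish.
    assert_eq!(generate::<&str>(None), 0);
}
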
View File

@@ -191,9 +191,9 @@ impl NERModel {
     /// # Ok(())
     /// # }
     /// ```
-    pub fn predict<'a, S>(&self, input: S) -> Vec<Vec<Entity>>
+    pub fn predict<S>(&self, input: &[S]) -> Vec<Vec<Entity>>
     where
-        S: AsRef<[&'a str]>,
+        S: AsRef<str>,
     {
         self.token_classification_model
             .predict(input, true, false)

View File

@@ -195,9 +195,9 @@ impl POSModel {
     /// # Ok(())
     /// # }
     /// ```
-    pub fn predict<'a, S>(&self, input: S) -> Vec<Vec<POSTag>>
+    pub fn predict<S>(&self, input: &[S]) -> Vec<Vec<POSTag>>
     where
-        S: AsRef<[&'a str]>,
+        S: AsRef<str>,
     {
         self.token_classification_model
             .predict(input, true, false)

View File

@@ -254,13 +254,13 @@ impl SummarizationOption {
     }
 
     /// Interface method to generate() of the particular models.
-    pub fn generate<'a, S>(
+    pub fn generate<S>(
         &self,
-        prompt_texts: Option<S>,
+        prompt_texts: Option<&[S]>,
         attention_mask: Option<Tensor>,
     ) -> Vec<String>
     where
-        S: AsRef<[&'a str]>,
+        S: AsRef<str> + Sync,
     {
         match *self {
             Self::Bart(ref model) => model
@@ -406,22 +406,18 @@
     /// # }
     /// ```
     /// (New sample credits: [WikiNews](https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b))
-    pub fn summarize<'a, S>(&self, texts: S) -> Vec<String>
+    pub fn summarize<S>(&self, texts: &[S]) -> Vec<String>
     where
-        S: AsRef<[&'a str]>,
+        S: AsRef<str> + Sync,
     {
         match &self.prefix {
             None => self.model.generate(Some(texts), None),
             Some(prefix) => {
                 let texts = texts
-                    .as_ref()
                     .iter()
-                    .map(|text| format!("{}{}", prefix, text))
+                    .map(|text| format!("{}{}", prefix, text.as_ref()))
                     .collect::<Vec<String>>();
-                self.model.generate(
-                    Some(texts.iter().map(|x| &**x).collect::<Vec<&str>>()),
-                    None,
-                )
+                self.model.generate(Some(&texts), None)
             }
         }
     }

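The prefix branch shows the payoff: after prepending the prefix you hold a `Vec<String>`, and since `String: AsRef<str>`, `Some(&texts)` now goes straight into `generate` where the old code had to re-collect with `texts.iter().map(|x| &**x).collect::<Vec<&str>>()`. A compact sketch of the pattern (stand-in functions, not the crate's items):

fn generate<S: AsRef<str> + Sync>(texts: Option<&[S]>) -> Vec<String> {
    texts.map_or_else(Vec::new, |t| {
        t.iter().map(|x| x.as_ref().to_string()).collect()
    })
}

fn summarize(texts: &[&str], prefix: Option<&str>) -> Vec<String> {
    match prefix {
        None => generate(Some(texts)),
        Some(prefix) => {
            let texts: Vec<String> = texts
                .iter()
                .map(|text| format!("{}{}", prefix, text))
                .collect();
            generate(Some(&texts)) // &Vec<String> coerces to &[String]
        }
    }
}

fn main() {
    println!("{:?}", summarize(&["long article body"], Some("summarize: ")));
}
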
View File

@@ -238,15 +238,15 @@ impl TextGenerationOption {
     }
 
     /// Interface method to generate() of the particular models.
-    pub fn generate_indices<'a, S>(
+    pub fn generate_indices<S>(
         &self,
-        prompt_texts: Option<S>,
+        prompt_texts: Option<&[S]>,
         attention_mask: Option<Tensor>,
         min_length: Option<i64>,
         max_length: Option<i64>,
     ) -> Vec<Vec<i64>>
     where
-        S: AsRef<[&'a str]>,
+        S: AsRef<str> + Sync,
     {
         match *self {
             Self::GPT(ref model) => model
@@ -418,9 +418,9 @@ with people, even a bishop, begging for his blessing. <eod> </s> <eos>"
     /// # Ok(())
     /// # }
     /// ```
-    pub fn generate<'a, S>(&self, texts: S, prefix: impl Into<Option<&'a str>>) -> Vec<String>
+    pub fn generate<'a, S>(&self, texts: &[S], prefix: impl Into<Option<&'a str>>) -> Vec<String>
     where
-        S: AsRef<[&'a str]>,
+        S: AsRef<str> + Sync,
     {
         let (prefix, prefix_length) = match (prefix.into(), &self.prefix) {
             (Some(query_prefix), _) => (
@@ -436,10 +436,10 @@ with people, even a bishop, begging for his blessing. <eod> </s> <eos>"
         let texts = texts
-            .as_ref()
             .iter()
-            .map(|text| format!("{} {}", prefix, text))
+            .map(|text| format!("{} {}", prefix, text.as_ref()))
             .collect::<Vec<String>>();
         self.model.generate_indices(
-            Some(texts.iter().map(|x| &**x).collect::<Vec<&str>>()),
+            Some(&texts),
             None,
             Some(self.min_length + prefix_length),
             Some(self.max_length + prefix_length),
@@ -451,14 +451,7 @@ with people, even a bishop, begging for his blessing. <eod> </s> <eos>"
         let mut output = Vec::with_capacity(generated_indices.len());
         for generated_sequence in generated_indices {
             output.push(self.model.get_tokenizer().decode(
-                if prefix_length.is_some() {
-                    generated_sequence
-                        .into_iter()
-                        .skip(prefix_length.unwrap_or(0) as usize)
-                        .collect::<Vec<i64>>()
-                } else {
-                    generated_sequence
-                },
+                &generated_sequence[prefix_length.unwrap_or(0) as usize..],
                 true,
                 true,
             ));

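The last hunk replaces an eight-line conditional skip-and-collect with a single range slice; `prefix_length.unwrap_or(0)` makes the no-prefix case fall out naturally as `[0..]`. The two forms agree, but the slice borrows instead of allocating:

fn main() {
    let generated_sequence = vec![1i64, 2, 3, 4];
    let prefix_length: Option<i64> = Some(2);

    // Old shape: conditionally skip, collecting into a fresh Vec<i64>.
    let old: Vec<i64> = generated_sequence
        .iter()
        .copied()
        .skip(prefix_length.unwrap_or(0) as usize)
        .collect();

    // New shape: one borrowed sub-slice, no allocation.
    let new = &generated_sequence[prefix_length.unwrap_or(0) as usize..];

    assert_eq!(old.as_slice(), new);
}
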
View File

@@ -776,17 +776,16 @@ impl TokenClassificationModel {
     /// # Ok(())
     /// # }
     /// ```
-    pub fn predict<'a, S>(
+    pub fn predict<S>(
         &self,
-        input: S,
+        input: &[S],
         consolidate_sub_tokens: bool,
         return_special: bool,
     ) -> Vec<Vec<Token>>
     where
-        S: AsRef<[&'a str]>,
+        S: AsRef<str>,
     {
         let mut features: Vec<InputFeature> = input
-            .as_ref()
             .iter()
             .enumerate()
             .map(|(example_index, example)| self.generate_features(example, example_index))
@@ -794,7 +793,7 @@ impl TokenClassificationModel {
             .collect();
 
         let mut example_tokens_map: HashMap<usize, Vec<Token>> = HashMap::new();
-        for example_idx in 0..input.as_ref().len() {
+        for example_idx in 0..input.len() {
             example_tokens_map.insert(example_idx, Vec::new());
         }
         let mut start = 0usize;
@@ -820,7 +819,8 @@ impl TokenClassificationModel {
             let labels = label_indices.get(sentence_idx);
             let feature = &features[sentence_idx as usize];
             let sentence_reference_flag = &feature.reference_feature;
-            let original_chars = input.as_ref()[feature.example_index]
+            let original_chars = input[feature.example_index]
+                .as_ref()
                 .chars()
                 .collect::<Vec<char>>();
             let mut word_idx: u16 = 0;
@@ -935,19 +935,19 @@ impl TokenClassificationModel {
         let text = match offsets {
             None => match self.tokenizer {
                 TokenizerOption::Bert(ref tokenizer) => {
-                    Tokenizer::decode(tokenizer, vec![token_id], false, false)
+                    Tokenizer::decode(tokenizer, &[token_id], false, false)
                 }
                 TokenizerOption::Roberta(ref tokenizer) => {
-                    Tokenizer::decode(tokenizer, vec![token_id], false, false)
+                    Tokenizer::decode(tokenizer, &[token_id], false, false)
                 }
                 TokenizerOption::XLMRoberta(ref tokenizer) => {
-                    Tokenizer::decode(tokenizer, vec![token_id], false, false)
+                    Tokenizer::decode(tokenizer, &[token_id], false, false)
                 }
                 TokenizerOption::Albert(ref tokenizer) => {
-                    Tokenizer::decode(tokenizer, vec![token_id], false, false)
+                    Tokenizer::decode(tokenizer, &[token_id], false, false)
                 }
                 TokenizerOption::XLNet(ref tokenizer) => {
-                    Tokenizer::decode(tokenizer, vec![token_id], false, false)
+                    Tokenizer::decode(tokenizer, &[token_id], false, false)
                 }
                 _ => panic!(
                     "Token classification not implemented for {:?}!",

View File

@@ -55,9 +55,13 @@
 //! let source_sentence = "This sentence will be translated in multiple languages.";
 //!
 //! let mut outputs = Vec::new();
-//! outputs.extend(model.translate([source_sentence], Language::English, Language::French)?);
-//! outputs.extend(model.translate([source_sentence], Language::English, Language::Spanish)?);
-//! outputs.extend(model.translate([source_sentence], Language::English, Language::Hindi)?);
+//! outputs.extend(model.translate(&[source_sentence], Language::English, Language::French)?);
+//! outputs.extend(model.translate(
+//!     &[source_sentence],
+//!     Language::English,
+//!     Language::Spanish,
+//! )?);
+//! outputs.extend(model.translate(&[source_sentence], Language::English, Language::Hindi)?);
 //!
 //! for sentence in outputs {
 //!     println!("{}", sentence);

View File

@@ -586,8 +586,7 @@ impl TranslationOption {
         if !supported_source_languages.contains(source_language) {
             return Err(RustBertError::ValueError(format!(
                 "{} not in list of supported languages: {:?}",
-                source_language.to_string(),
-                supported_source_languages
+                source_language, supported_source_languages
             )));
         }
     }
@@ -596,8 +595,7 @@
         if !supported_target_languages.contains(target_language) {
             return Err(RustBertError::ValueError(format!(
                 "{} not in list of supported languages: {:?}",
-                target_language.to_string(),
-                supported_target_languages
+                target_language, supported_target_languages
             )));
         }
     }
@@ -630,7 +628,7 @@
             Some(format!(
                 "translate {} to {}:",
                 match source_language {
-                    Some(value) => value.to_string(),
+                    Some(value) => value,
                     None => {
                         return Err(RustBertError::ValueError(
                             "Missing source language for T5".to_string(),
@@ -638,7 +636,7 @@
                     }
                 },
                 match target_language {
-                    Some(value) => value.to_string(),
+                    Some(value) => value,
                     None => {
                         return Err(RustBertError::ValueError(
                             "Missing target language for T5".to_string(),
@@ -665,7 +663,7 @@
             )),
             if let Some(target_language) = target_language {
                 Some(
-                    model._get_tokenizer().convert_tokens_to_ids([format!(
+                    model._get_tokenizer().convert_tokens_to_ids(&[format!(
                         ">>{}<<",
                         target_language.get_iso_639_1_code()
                     )])[0],
@@ -705,9 +703,8 @@
             if let Some(target_language) = target_language {
                 let language_code = target_language.get_iso_639_1_code();
                 Some(
-                    model
-                        ._get_tokenizer()
-                        .convert_tokens_to_ids([match language_code.len() {
+                    model._get_tokenizer().convert_tokens_to_ids(&[
+                        match language_code.len() {
                             2 => format!(">>{}.<<", language_code),
                             3 => format!(">>{}<<", language_code),
                             _ => {
@@ -715,7 +712,8 @@
                                     "Invalid ISO 639-3 code".to_string(),
                                 ));
                             }
-                        }])[0],
+                        },
+                    ])[0],
                 )
             } else {
                 return Err(RustBertError::ValueError(format!(
@@ -730,14 +728,14 @@
     }
 
     /// Interface method to generate() of the particular models.
-    pub fn generate<'a, S>(
+    pub fn generate<S>(
         &self,
-        prompt_texts: Option<S>,
+        prompt_texts: Option<&[S]>,
         attention_mask: Option<Tensor>,
         forced_bos_token_id: Option<i64>,
     ) -> Vec<String>
     where
-        S: AsRef<[&'a str]>,
+        S: AsRef<str> + Sync,
     {
         match *self {
             Self::Marian(ref model) => model
@@ -927,14 +925,14 @@
     /// # Ok(())
     /// # }
     /// ```
-    pub fn translate<'a, S>(
+    pub fn translate<S>(
         &self,
-        texts: S,
+        texts: &[S],
         source_language: impl Into<Option<Language>>,
         target_language: impl Into<Option<Language>>,
     ) -> Result<Vec<String>, RustBertError>
     where
-        S: AsRef<[&'a str]>,
+        S: AsRef<str> + Sync,
     {
         let (prefix, forced_bos_token_id) = self.model.validate_and_get_prefix_and_forced_bos_id(
             source_language.into().as_ref(),
@@ -946,15 +944,10 @@
         Ok(match prefix {
             Some(value) => {
                 let texts = texts
-                    .as_ref()
                     .iter()
-                    .map(|&v| format!("{}{}", value, v))
+                    .map(|v| format!("{}{}", value, v.as_ref()))
                     .collect::<Vec<String>>();
-                self.model.generate(
-                    Some(texts.iter().map(AsRef::as_ref).collect::<Vec<&str>>()),
-                    None,
-                    forced_bos_token_id,
-                )
+                self.model.generate(Some(&texts), None, forced_bos_token_id)
             }
             None => self.model.generate(Some(texts), None, forced_bos_token_id),
         })

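Alongside the ownership changes, this file picks up one of the clippy-style cleanups from the commit message: `Language` implements `Display`, so `format!` can take the value directly and the `.to_string()` calls were redundant. A minimal demonstration with a hypothetical enum (not the crate's `Language`):

use std::fmt;

enum Language {
    English,
}

impl fmt::Display for Language {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(match self {
            Language::English => "English",
        })
    }
}

fn main() {
    let unsupported = Language::English;
    // `{}` invokes Display; an explicit `.to_string()` adds nothing.
    println!("{} not in list of supported languages", unsupported);
}
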
View File

@@ -1060,17 +1060,17 @@ impl
         }
     }
 
-    fn encode_prompt_text<'a, S>(
+    fn encode_prompt_text<S>(
         &self,
-        prompt_text: S,
+        prompt_text: &[S],
         max_len: i64,
         pad_token_id: Option<i64>,
     ) -> Tensor
     where
-        S: AsRef<[&'a str]>,
+        S: AsRef<str> + Sync,
     {
         let tokens = self._get_tokenizer().encode_list(
-            prompt_text.as_ref(),
+            prompt_text,
             max_len as usize,
             &TruncationStrategy::LongestFirst,
             0,

View File

@@ -855,17 +855,17 @@ impl PrivateLanguageGenerator<T5ForConditionalGeneration, T5Vocab, T5Tokenizer>
         }
     }
 
-    fn encode_prompt_text<'a, S>(
+    fn encode_prompt_text<S>(
         &self,
-        prompt_text: S,
+        prompt_text: &[S],
         max_len: i64,
         pad_token_id: Option<i64>,
     ) -> Tensor
     where
-        S: AsRef<[&'a str]>,
+        S: AsRef<str> + Sync,
     {
         let tokens = self._get_tokenizer().encode_list(
-            prompt_text.as_ref(),
+            prompt_text,
             max_len as usize,
             &TruncationStrategy::LongestFirst,
             0,

View File

@@ -80,7 +80,7 @@ fn distilgpt2_lm_model() -> anyhow::Result<()> {
         .get(-1)
         .argmax(-1, true)
         .int64_value(&[0]);
-    let next_word = tokenizer.decode(vec![next_word_id], true, true);
+    let next_word = tokenizer.decode(&[next_word_id], true, true);
 
     assert_eq!(model_output.lm_logits.size(), vec!(1, 11, 50257));
     match model_output.cache {

View File

@@ -82,7 +82,7 @@ fn gpt2_lm_model() -> anyhow::Result<()> {
         .get(-1)
         .argmax(-1, true)
         .int64_value(&[0]);
-    let next_word = tokenizer.decode(vec![next_word_id], true, true);
+    let next_word = tokenizer.decode(&[next_word_id], true, true);
 
     assert_eq!(model_output.lm_logits.size(), vec!(1, 4, 50257));
     match model_output.cache {

View File

@@ -72,7 +72,7 @@ fn gpt_neo_lm() -> anyhow::Result<()> {
         .get(-1)
         .argmax(-1, true)
         .int64_value(&[0]);
-    let next_word = tokenizer.decode(vec![next_word_id], true, true);
+    let next_word = tokenizer.decode(&[next_word_id], true, true);
     let next_score = model_output
         .lm_logits
         .get(0)

View File

@@ -274,7 +274,7 @@ fn longformer_for_multiple_choice() -> anyhow::Result<()> {
     let prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced.";
     let inputs = ["Very positive sentence", "Second sentence input"];
     let tokenized_input = tokenizer.encode_pair_list(
-        inputs
+        &inputs
             .iter()
             .map(|&inp| (prompt, inp))
             .collect::<Vec<(&str, &str)>>(),

View File

@@ -107,9 +107,9 @@ fn m2m100_translation() -> anyhow::Result<()> {
     let source_sentence = "This sentence will be translated in multiple languages.";
 
     let mut outputs = Vec::new();
-    outputs.extend(model.translate([source_sentence], Language::English, Language::French)?);
-    outputs.extend(model.translate([source_sentence], Language::English, Language::Spanish)?);
-    outputs.extend(model.translate([source_sentence], Language::English, Language::Hindi)?);
+    outputs.extend(model.translate(&[source_sentence], Language::English, Language::French)?);
+    outputs.extend(model.translate(&[source_sentence], Language::English, Language::Spanish)?);
+    outputs.extend(model.translate(&[source_sentence], Language::English, Language::Hindi)?);
 
     assert_eq!(outputs.len(), 3);
     assert_eq!(

View File

@@ -73,9 +73,9 @@ fn mbart_translation() -> anyhow::Result<()> {
     let source_sentence = "This sentence will be translated in multiple languages.";
 
     let mut outputs = Vec::new();
-    outputs.extend(model.translate([source_sentence], Language::English, Language::French)?);
-    outputs.extend(model.translate([source_sentence], Language::English, Language::Spanish)?);
-    outputs.extend(model.translate([source_sentence], Language::English, Language::Hindi)?);
+    outputs.extend(model.translate(&[source_sentence], Language::English, Language::French)?);
+    outputs.extend(model.translate(&[source_sentence], Language::English, Language::Spanish)?);
+    outputs.extend(model.translate(&[source_sentence], Language::English, Language::Hindi)?);
 
     assert_eq!(outputs.len(), 3);
     assert_eq!(

View File

@@ -182,7 +182,7 @@ fn mobilebert_for_multiple_choice() -> anyhow::Result<()> {
     let prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced.";
     let inputs = ["Very positive sentence", "Second sentence input"];
     let tokenized_input = tokenizer.encode_pair_list(
-        inputs
+        &inputs
            .iter()
            .map(|&inp| (prompt, inp))
            .collect::<Vec<(&str, &str)>>(),

View File

@@ -83,7 +83,7 @@ fn openai_gpt_lm_model() -> anyhow::Result<()> {
         .get(-1)
         .argmax(-1, true)
         .int64_value(&[0]);
-    let next_word = tokenizer.decode(vec![next_word_id], true, true);
+    let next_word = tokenizer.decode(&[next_word_id], true, true);
 
     assert_eq!(model_output.lm_logits.size(), vec!(1, 6, 40478));
     assert!(

View File

@@ -44,9 +44,9 @@ fn test_translation_t5() -> anyhow::Result<()> {
     let source_sentence = "This sentence will be translated in multiple languages.";
 
     let mut outputs = Vec::new();
-    outputs.extend(model.translate([source_sentence], Language::English, Language::French)?);
-    outputs.extend(model.translate([source_sentence], Language::English, Language::German)?);
-    outputs.extend(model.translate([source_sentence], Language::English, Language::Romanian)?);
+    outputs.extend(model.translate(&[source_sentence], Language::English, Language::French)?);
+    outputs.extend(model.translate(&[source_sentence], Language::English, Language::German)?);
+    outputs.extend(model.translate(&[source_sentence], Language::English, Language::Romanian)?);
 
     assert_eq!(outputs.len(), 3);
     assert_eq!(

View File

@@ -331,7 +331,7 @@ fn xlnet_for_multiple_choice() -> anyhow::Result<()> {
     let prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced.";
     let inputs = ["Very positive sentence", "Second sentence input"];
     let tokenized_input = tokenizer.encode_pair_list(
-        inputs
+        &inputs
            .iter()
            .map(|&inp| (prompt, inp))
            .collect::<Vec<(&str, &str)>>(),