Generalization of input types for pipelines

This commit is contained in:
Guillaume B 2020-10-11 16:18:44 +02:00
parent 426430ae0b
commit 97ee8ee928
31 changed files with 248 additions and 284 deletions

View File

@ -78,7 +78,7 @@ fn bench_squad(c: &mut Criterion) {
}
// Define input
let mut squad_path = PathBuf::from(env::var("squad_dataset")
.expect("Please set the \"squad_dataset\" environment variable pointing to the SQuAD dataset folder"));
.expect("Please set the \"squad_dataset\" environment variable pointing to the SQuAD dataset folder"));
squad_path.push("dev-v2.0.json");
let mut qa_inputs = squad_processor(squad_path);
qa_inputs.truncate(1000);

View File

@ -52,8 +52,7 @@ fn main() -> anyhow::Result<()> {
"Looks like one [MASK] is missing",
"It was a very nice and [MASK] day",
];
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let tokenized_input = tokenizer.encode_list(&input, 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())

View File

@ -54,8 +54,7 @@ fn main() -> anyhow::Result<()> {
// Credits: WikiNews, CC BY 2.5 license (https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b)
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 1024, &TruncationStrategy::LongestFirst, 0);
let tokenized_input = tokenizer.encode_list(&input, 1024, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())

View File

@ -47,8 +47,7 @@ fn main() -> anyhow::Result<()> {
"Looks like one [MASK] is missing",
"It was a very nice and [MASK] day",
];
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let tokenized_input = tokenizer.encode_list(&input, 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())

View File

@ -50,8 +50,7 @@ fn main() -> anyhow::Result<()> {
"Looks like one thing is missing",
"It\'s like comparing oranges to apples",
];
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let tokenized_input = tokenizer.encode_list(input, 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())

View File

@ -51,7 +51,7 @@ fn main() -> anyhow::Result<()> {
let input = ["One Two Three Ten Five Six Seven Eight"];
let tokenized_input = MultiThreadedTokenizer::encode_list(
&tokenizer,
input.to_vec(),
&input,
128,
&TruncationStrategy::LongestFirst,
0,

View File

@ -51,8 +51,7 @@ fn main() -> anyhow::Result<()> {
"Looks like one [MASK] is missing",
"It was a very nice and [MASK] day",
];
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let tokenized_input = tokenizer.encode_list(&input, 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())

View File

@ -51,8 +51,7 @@ fn main() -> anyhow::Result<()> {
// Define input
let input = ["One two three four five six seven eight nine ten eleven"];
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let tokenized_input = tokenizer.encode_list(&input, 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())

View File

@ -56,8 +56,7 @@ fn main() -> anyhow::Result<()> {
// Define input
let input = ["Wondering what the next word will"];
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let tokenized_input = tokenizer.encode_list(&input, 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())

View File

@ -60,8 +60,7 @@ fn main() -> anyhow::Result<()> {
"<pad> Looks like one thing is missing",
"It\'s like comparing oranges to apples",
];
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let tokenized_input = tokenizer.encode_list(&input, 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())

View File

@ -49,8 +49,7 @@ fn main() -> anyhow::Result<()> {
// Define input
let input = ["One two three four"];
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let tokenized_input = tokenizer.encode_list(&input, 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())

View File

@ -32,7 +32,7 @@ use rust_tokenizers::tokenizer::{
use rust_tokenizers::vocab::{
AlbertVocab, BertVocab, MarianVocab, RobertaVocab, T5Vocab, XLMRobertaVocab, XLNetVocab,
};
use rust_tokenizers::{Mask, Offset, OffsetSize, TokenizedInput};
use rust_tokenizers::{TokenIdsWithOffsets, TokenizedInput};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::Path;
@ -273,7 +273,7 @@ impl TokenizerOption {
/// Interface method
pub fn encode_list(
&self,
text_list: Vec<&str>,
text_list: &[&str],
max_len: usize,
truncation_strategy: &TruncationStrategy,
stride: usize,
@ -330,7 +330,7 @@ impl TokenizerOption {
/// Interface method for pair encoding
pub fn encode_pair_list(
&self,
text_pair_list: Vec<(&str, &str)>,
text_pair_list: &[(&str, &str)],
max_len: usize,
truncation_strategy: &TruncationStrategy,
stride: usize,
@ -400,110 +400,61 @@ impl TokenizerOption {
/// Interface method to build input with special tokens
pub fn build_input_with_special_tokens(
&self,
tokens_1: Vec<i64>,
tokens_2: Option<Vec<i64>>,
offsets_1: Vec<Option<Offset>>,
offsets_2: Option<Vec<Option<Offset>>>,
original_offsets_1: Vec<Vec<OffsetSize>>,
original_offsets_2: Option<Vec<Vec<OffsetSize>>>,
mask_1: Vec<Mask>,
mask_2: Option<Vec<Mask>>,
token_ids_with_offsets_1: TokenIdsWithOffsets,
token_ids_with_offsets_2: Option<TokenIdsWithOffsets>,
) -> TokenizedInput {
let (token_ids, segment_ids, special_tokens_mask, token_offsets, reference_offsets, mask) =
match *self {
Self::Bert(ref tokenizer) => tokenizer.build_input_with_special_tokens(
tokens_1,
tokens_2,
offsets_1,
offsets_2,
original_offsets_1,
original_offsets_2,
mask_1,
mask_2,
),
Self::Roberta(ref tokenizer) => tokenizer.build_input_with_special_tokens(
tokens_1,
tokens_2,
offsets_1,
offsets_2,
original_offsets_1,
original_offsets_2,
mask_1,
mask_2,
),
Self::XLMRoberta(ref tokenizer) => tokenizer.build_input_with_special_tokens(
tokens_1,
tokens_2,
offsets_1,
offsets_2,
original_offsets_1,
original_offsets_2,
mask_1,
mask_2,
),
Self::Marian(ref tokenizer) => tokenizer.build_input_with_special_tokens(
tokens_1,
tokens_2,
offsets_1,
offsets_2,
original_offsets_1,
original_offsets_2,
mask_1,
mask_2,
),
Self::T5(ref tokenizer) => tokenizer.build_input_with_special_tokens(
tokens_1,
tokens_2,
offsets_1,
offsets_2,
original_offsets_1,
original_offsets_2,
mask_1,
mask_2,
),
Self::Albert(ref tokenizer) => tokenizer.build_input_with_special_tokens(
tokens_1,
tokens_2,
offsets_1,
offsets_2,
original_offsets_1,
original_offsets_2,
mask_1,
mask_2,
),
Self::XLNet(ref tokenizer) => tokenizer.build_input_with_special_tokens(
tokens_1,
tokens_2,
offsets_1,
offsets_2,
original_offsets_1,
original_offsets_2,
mask_1,
mask_2,
),
};
let token_ids_with_special_tokens = match *self {
Self::Bert(ref tokenizer) => tokenizer.build_input_with_special_tokens(
token_ids_with_offsets_1,
token_ids_with_offsets_2,
),
Self::Roberta(ref tokenizer) => tokenizer.build_input_with_special_tokens(
token_ids_with_offsets_1,
token_ids_with_offsets_2,
),
Self::XLMRoberta(ref tokenizer) => tokenizer.build_input_with_special_tokens(
token_ids_with_offsets_1,
token_ids_with_offsets_2,
),
Self::Marian(ref tokenizer) => tokenizer.build_input_with_special_tokens(
token_ids_with_offsets_1,
token_ids_with_offsets_2,
),
Self::T5(ref tokenizer) => tokenizer.build_input_with_special_tokens(
token_ids_with_offsets_1,
token_ids_with_offsets_2,
),
Self::Albert(ref tokenizer) => tokenizer.build_input_with_special_tokens(
token_ids_with_offsets_1,
token_ids_with_offsets_2,
),
Self::XLNet(ref tokenizer) => tokenizer.build_input_with_special_tokens(
token_ids_with_offsets_1,
token_ids_with_offsets_2,
),
};
TokenizedInput {
token_ids,
segment_ids,
special_tokens_mask,
token_ids: token_ids_with_special_tokens.token_ids,
segment_ids: token_ids_with_special_tokens.segment_ids,
special_tokens_mask: token_ids_with_special_tokens.special_tokens_mask,
overflowing_tokens: vec![],
num_truncated_tokens: 0,
token_offsets,
reference_offsets,
mask,
token_offsets: token_ids_with_special_tokens.token_offsets,
reference_offsets: token_ids_with_special_tokens.reference_offsets,
mask: token_ids_with_special_tokens.mask,
}
}
/// Interface method to convert tokens to ids
pub fn convert_tokens_to_ids(&self, tokens: &[String]) -> Vec<i64> {
match *self {
Self::Bert(ref tokenizer) => tokenizer.convert_tokens_to_ids(&tokens.into()),
Self::Roberta(ref tokenizer) => tokenizer.convert_tokens_to_ids(&tokens.into()),
Self::Marian(ref tokenizer) => tokenizer.convert_tokens_to_ids(&tokens.into()),
Self::T5(ref tokenizer) => tokenizer.convert_tokens_to_ids(&tokens.into()),
Self::XLMRoberta(ref tokenizer) => tokenizer.convert_tokens_to_ids(&tokens.into()),
Self::Albert(ref tokenizer) => tokenizer.convert_tokens_to_ids(&tokens.into()),
Self::XLNet(ref tokenizer) => tokenizer.convert_tokens_to_ids(&tokens.into()),
Self::Bert(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens.into()),
Self::Roberta(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens.into()),
Self::Marian(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens.into()),
Self::T5(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens.into()),
Self::XLMRoberta(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens.into()),
Self::Albert(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens.into()),
Self::XLNet(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens.into()),
}
}

View File

@ -671,7 +671,7 @@ impl ConversationModel {
.map(|c| &c.history)
.collect_vec();
let prompt_ids = self.encode_prompts(texts.as_slice());
let prompt_ids = self.encode_prompts(texts.as_ref());
let input_tensor = self.concat_input_history(prompt_ids, history);
let input_length = *input_tensor.size().last().unwrap() as usize;
let mut generated = self.model.generate_from_ids_and_past(input_tensor, None);
@ -791,7 +791,7 @@ impl ConversationModel {
fn encode_prompts(&self, texts: &[&str]) -> Vec<Vec<i64>> {
// Encode the user prompt into token ids
let tokens = self.model.get_tokenizer().tokenize_list(texts.to_vec());
let tokens = self.model.get_tokenizer().tokenize_list(texts);
tokens
.into_iter()

View File

@ -766,14 +766,17 @@ impl PrivateLanguageGenerator<BartForConditionalGeneration, RobertaVocab, Robert
}
}
fn encode_prompt_text(
fn encode_prompt_text<'a, S>(
&self,
prompt_text: Vec<&str>,
prompt_text: S,
max_len: i64,
pad_token_id: Option<i64>,
) -> Tensor {
) -> Tensor
where
S: AsRef<[&'a str]>,
{
let tokens = self.get_tokenizer().encode_list(
prompt_text,
prompt_text.as_ref(),
max_len as usize,
&TruncationStrategy::LongestFirst,
0,
@ -1041,14 +1044,17 @@ impl PrivateLanguageGenerator<MarianForConditionalGeneration, MarianVocab, Maria
}
}
fn encode_prompt_text(
fn encode_prompt_text<'a, T>(
&self,
prompt_text: Vec<&str>,
prompt_text: T,
max_len: i64,
pad_token_id: Option<i64>,
) -> Tensor {
) -> Tensor
where
T: AsRef<[&'a str]>,
{
let tokens = self.get_tokenizer().encode_list(
prompt_text,
prompt_text.as_ref(),
max_len as usize,
&TruncationStrategy::LongestFirst,
0,
@ -1270,14 +1276,17 @@ impl PrivateLanguageGenerator<T5ForConditionalGeneration, T5Vocab, T5Tokenizer>
}
}
fn encode_prompt_text(
fn encode_prompt_text<'a, S>(
&self,
prompt_text: Vec<&str>,
prompt_text: S,
max_len: i64,
pad_token_id: Option<i64>,
) -> Tensor {
) -> Tensor
where
S: AsRef<[&'a str]>,
{
let tokens = self.get_tokenizer().encode_list(
prompt_text,
prompt_text.as_ref(),
max_len as usize,
&TruncationStrategy::LongestFirst,
0,
@ -1581,11 +1590,14 @@ impl PrivateLanguageGenerator<XLNetLMHeadModel, XLNetVocab, XLNetTokenizer> for
}
impl LanguageGenerator<XLNetLMHeadModel, XLNetVocab, XLNetTokenizer> for XLNetGenerator {
fn generate(
fn generate<'a, S>(
&self,
prompt_texts: Option<Vec<&str>>,
prompt_texts: Option<S>,
attention_mask: Option<Tensor>,
) -> Vec<String> {
) -> Vec<String>
where
S: AsRef<[&'a str]>,
{
let eos_token_ids = PrivateLanguageGenerator::get_eos_ids(self).clone();
let config = PrivateLanguageGenerator::get_config(self);
@ -1618,6 +1630,7 @@ with people, even a bishop, begging for his blessing. <eod> </s> <eos>";
let input_ids = match prompt_texts {
Some(texts) => {
let texts = texts
.as_ref()
.iter()
.map(|text| format!("{} {}", prefix, text))
.collect::<Vec<String>>();
@ -1663,6 +1676,7 @@ pub(crate) mod private_generation_utils {
use crate::pipelines::generation::{BeamHypotheses, Cache, GenerateConfig, LMHeadModel};
use rust_tokenizers::tokenizer::{truncate_sequences, Tokenizer, TruncationStrategy};
use rust_tokenizers::vocab::Vocab;
use rust_tokenizers::TokenIdsWithOffsets;
use std::cmp::{max, min};
use std::collections::HashMap;
use tch::kind::Kind::{Bool, Float, Int64};
@ -1725,13 +1739,16 @@ pub(crate) mod private_generation_utils {
(Some(input_ids), Some(attention_mask), None, None, past)
}
fn encode_prompt_text(
fn encode_prompt_text<'a, S>(
&self,
prompt_text: Vec<&str>,
prompt_text: S,
max_len: i64,
pad_token_id: Option<i64>,
) -> Tensor {
let tokens = self.get_tokenizer().tokenize_list(prompt_text);
) -> Tensor
where
S: AsRef<[&'a str]>,
{
let tokens = self.get_tokenizer().tokenize_list(prompt_text.as_ref());
let token_ids = tokens
.into_iter()
.map(|prompt_tokens| self.get_tokenizer().convert_tokens_to_ids(&prompt_tokens))
@ -1753,13 +1770,12 @@ pub(crate) mod private_generation_utils {
.zip(num_truncated_tokens)
.map(|(tokens, num_truncated_tokens)| {
truncate_sequences(
tokens,
None,
vec![],
None,
vec![],
None,
vec![],
TokenIdsWithOffsets {
ids: tokens,
offsets: vec![],
reference_offsets: vec![],
masks: vec![],
},
None,
num_truncated_tokens,
&TruncationStrategy::LongestFirst,
@ -1767,6 +1783,7 @@ pub(crate) mod private_generation_utils {
)
.unwrap()
.0
.ids
})
.collect::<Vec<Vec<i64>>>();
@ -2495,11 +2512,14 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
/// ]
/// # ;
/// ```
fn generate(
fn generate<'a, S>(
&self,
prompt_texts: Option<Vec<&str>>,
prompt_texts: Option<S>,
attention_mask: Option<Tensor>,
) -> Vec<String> {
) -> Vec<String>
where
S: AsRef<[&'a str]>,
{
let eos_token_ids = PrivateLanguageGenerator::get_eos_ids(self).clone();
let config = PrivateLanguageGenerator::get_config(self);

View File

@ -55,7 +55,7 @@ use crate::pipelines::common::{ConfigOption, ModelType, TokenizerOption};
use crate::roberta::RobertaForQuestionAnswering;
use crate::xlnet::XLNetForQuestionAnswering;
use rust_tokenizers::tokenizer::{truncate_sequences, TruncationStrategy};
use rust_tokenizers::{Mask, TokenizedInput};
use rust_tokenizers::{Mask, TokenIdsWithOffsets, TokenizedInput};
use std::borrow::Borrow;
use std::cmp::min;
use std::collections::HashMap;
@ -697,13 +697,12 @@ impl QuestionAnsweringModel {
TokenizerOption::Roberta(_) => {
self.tokenizer
.build_input_with_special_tokens(
vec![],
None,
vec![],
None,
vec![],
None,
vec![],
TokenIdsWithOffsets {
ids: vec![],
offsets: vec![],
reference_offsets: vec![],
masks: vec![],
},
None,
)
.token_ids
@ -713,13 +712,12 @@ impl QuestionAnsweringModel {
_ => self
.tokenizer
.build_input_with_special_tokens(
vec![],
None,
vec![],
None,
vec![],
None,
vec![],
TokenIdsWithOffsets {
ids: vec![],
offsets: vec![],
reference_offsets: vec![],
masks: vec![],
},
None,
)
.token_ids
@ -729,14 +727,18 @@ impl QuestionAnsweringModel {
let sequence_pair_added_tokens = self
.tokenizer
.build_input_with_special_tokens(
vec![],
Some(vec![]),
vec![],
Some(vec![]),
vec![],
Some(vec![]),
vec![],
Some(vec![]),
TokenIdsWithOffsets {
ids: vec![],
offsets: vec![],
reference_offsets: vec![],
masks: vec![],
},
Some(TokenIdsWithOffsets {
ids: vec![],
offsets: vec![],
reference_offsets: vec![],
masks: vec![],
}),
)
.token_ids
.len();
@ -795,20 +797,21 @@ impl QuestionAnsweringModel {
} else {
0
};
let (truncated_query, _, _, _, _, _, _, _, _, _) = truncate_sequences(
truncated_query,
None,
vec![],
None,
vec![],
None,
vec![],
let truncated_query = truncate_sequences(
TokenIdsWithOffsets {
ids: truncated_query,
offsets: vec![],
reference_offsets: vec![],
masks: vec![],
},
None,
num_query_tokens_to_remove,
&TruncationStrategy::OnlyFirst,
0,
)
.unwrap();
.unwrap()
.0
.ids;
truncated_query
}
@ -829,32 +832,28 @@ impl QuestionAnsweringModel {
0
};
let (truncated_query, truncated_context, _, _, _, _, _, _, overflowing_tokens, _) =
truncate_sequences(
truncated_query.into(),
Some(spans_token_ids.into()),
vec![],
None,
vec![],
None,
vec![],
None,
num_truncated_tokens,
&TruncationStrategy::OnlySecond,
max_seq_length - doc_stride - len_1 - sequence_pair_added_tokens,
)
.unwrap();
let (truncated_query, truncated_context, overflowing_tokens, _) = truncate_sequences(
TokenIdsWithOffsets {
ids: truncated_query.into(),
offsets: vec![],
reference_offsets: vec![],
masks: vec![],
},
Some(TokenIdsWithOffsets {
ids: spans_token_ids.into(),
offsets: vec![],
reference_offsets: vec![],
masks: vec![],
}),
num_truncated_tokens,
&TruncationStrategy::OnlySecond,
max_seq_length - doc_stride - len_1 - sequence_pair_added_tokens,
)
.unwrap();
let mut tokenized_input = self.tokenizer.build_input_with_special_tokens(
truncated_query,
truncated_context,
vec![],
None,
vec![],
None,
vec![],
None,
);
let mut tokenized_input = self
.tokenizer
.build_input_with_special_tokens(truncated_query, truncated_context);
let mut attention_mask = vec![1; tokenized_input.token_ids.len()];
if tokenized_input.token_ids.len() < max_seq_length {
tokenized_input.token_ids.append(&mut vec![

View File

@ -133,7 +133,10 @@ impl SentimentModel {
/// # Ok(())
/// # }
/// ```
pub fn predict(&self, input: &[&str]) -> Vec<Sentiment> {
pub fn predict<'a, S>(&self, input: S) -> Vec<Sentiment>
where
S: AsRef<[&'a str]>,
{
let labels = self.sequence_classification_model.predict(input);
let mut sentiments = Vec::with_capacity(labels.len());
for label in labels {

View File

@ -435,10 +435,13 @@ impl SequenceClassificationModel {
})
}
fn prepare_for_model(&self, input: Vec<&str>) -> Tensor {
fn prepare_for_model<'a, S>(&self, input: S) -> Tensor
where
S: AsRef<[&'a str]>,
{
let tokenized_input: Vec<TokenizedInput> =
self.tokenizer
.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
.encode_list(input.as_ref(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
@ -487,8 +490,11 @@ impl SequenceClassificationModel {
/// # Ok(())
/// # }
/// ```
pub fn predict(&self, input: &[&str]) -> Vec<Label> {
let input_tensor = self.prepare_for_model(input.to_vec());
pub fn predict<'a, S>(&self, input: S) -> Vec<Label>
where
S: AsRef<[&'a str]>,
{
let input_tensor = self.prepare_for_model(input.as_ref());
let output = no_grad(|| {
let output = self.sequence_classifier.forward_t(
Some(input_tensor.copy()),

View File

@ -236,7 +236,10 @@ impl SummarizationModel {
/// # }
/// ```
/// (New sample credits: [WikiNews](https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b))
pub fn summarize(&self, texts: &[&str]) -> Vec<String> {
self.model.generate(Some(texts.to_vec()), None)
pub fn summarize<'a, S>(&self, texts: S) -> Vec<String>
where
S: AsRef<[&'a str]>,
{
self.model.generate(Some(texts.as_ref()), None)
}
}

View File

@ -546,10 +546,13 @@ impl TokenClassificationModel {
})
}
fn prepare_for_model(&self, input: Vec<&str>) -> (Vec<TokenizedInput>, Tensor) {
fn prepare_for_model<'a, S>(&self, input: S) -> (Vec<TokenizedInput>, Tensor)
where
S: AsRef<[&'a str]>,
{
let tokenized_input: Vec<TokenizedInput> =
self.tokenizer
.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
.encode_list(input.as_ref(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
@ -602,13 +605,16 @@ impl TokenClassificationModel {
/// # Ok(())
/// # }
/// ```
pub fn predict(
pub fn predict<'a, S>(
&self,
input: &[&str],
input: S,
consolidate_sub_tokens: bool,
return_special: bool,
) -> Vec<Token> {
let (tokenized_input, input_tensor) = self.prepare_for_model(input.to_vec());
) -> Vec<Token>
where
S: AsRef<[&'a str]>,
{
let (tokenized_input, input_tensor) = self.prepare_for_model(input.as_ref());
let output = no_grad(|| {
self.token_sequence_classifier.forward_t(
Some(input_tensor.copy()),
@ -626,7 +632,7 @@ impl TokenClassificationModel {
for sentence_idx in 0..labels_idx.size()[0] {
let labels = labels_idx.get(sentence_idx);
let sentence_tokens = &tokenized_input[sentence_idx as usize];
let original_chars = input[sentence_idx as usize].chars().collect_vec();
let original_chars = input.as_ref()[sentence_idx as usize].chars().collect_vec();
let mut word_idx: u16 = 0;
for position_idx in 0..sentence_tokens.token_ids.len() {
let mask = sentence_tokens.mask[position_idx];

View File

@ -540,11 +540,14 @@ impl TranslationOption {
}
/// Interface method to generate() of the particular models.
pub fn generate(
pub fn generate<'a, S>(
&self,
prompt_texts: Option<Vec<&str>>,
prompt_texts: Option<S>,
attention_mask: Option<Tensor>,
) -> Vec<String> {
) -> Vec<String>
where
S: AsRef<[&'a str]>,
{
match *self {
Self::Marian(ref model) => model.generate(prompt_texts, attention_mask),
Self::T5(ref model) => model.generate(prompt_texts, attention_mask),
@ -612,17 +615,23 @@ impl TranslationModel {
/// # Ok(())
/// # }
/// ```
pub fn translate(&self, texts: &[&str]) -> Vec<String> {
pub fn translate<'a, S>(&self, texts: S) -> Vec<String>
where
S: AsRef<[&'a str]>,
{
match &self.prefix {
Some(value) => {
let texts = texts
.as_ref()
.iter()
.map(|&v| format!("{}{}", value, v))
.collect::<Vec<String>>();
self.model
.generate(Some(texts.iter().map(AsRef::as_ref).collect()), None)
self.model.generate(
Some(texts.iter().map(AsRef::as_ref).collect::<Vec<&str>>()),
None,
)
}
None => self.model.generate(Some(texts.to_vec()), None),
None => self.model.generate(Some(texts), None),
}
}
}

View File

@ -484,10 +484,10 @@ impl ZeroShotClassificationModel {
.iter()
.cartesian_product(label_sentences.iter())
.map(|(&s, label)| (s, label.as_str()))
.collect();
.collect::<Vec<(&str, &str)>>();
let tokenized_input: Vec<TokenizedInput> = self.tokenizer.encode_pair_list(
text_pair_list,
text_pair_list.as_ref(),
max_len,
&TruncationStrategy::LongestFirst,
0,

View File

@ -43,8 +43,7 @@ fn albert_masked_lm() -> anyhow::Result<()> {
"Looks like one [MASK] is missing",
"It\'s like comparing [MASK] to apples",
];
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let tokenized_input = tokenizer.encode_list(&input, 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
@ -117,8 +116,7 @@ fn albert_for_sequence_classification() -> anyhow::Result<()> {
"Looks like one thing is missing",
"It\'s like comparing oranges to apples",
];
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let tokenized_input = tokenizer.encode_list(&input, 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
@ -179,8 +177,7 @@ fn albert_for_multiple_choice() -> anyhow::Result<()> {
"Looks like one thing is missing",
"It\'s like comparing oranges to apples",
];
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let tokenized_input = tokenizer.encode_list(&input, 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
@ -252,8 +249,7 @@ fn albert_for_token_classification() -> anyhow::Result<()> {
"Looks like one thing is missing",
"It\'s like comparing oranges to apples",
];
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let tokenized_input = tokenizer.encode_list(&input, 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
@ -314,8 +310,7 @@ fn albert_for_question_answering() -> anyhow::Result<()> {
"Looks like one thing is missing",
"It\'s like comparing oranges to apples",
];
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let tokenized_input = tokenizer.encode_list(&input, 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())

View File

@ -43,8 +43,7 @@ fn bart_lm_model() -> anyhow::Result<()> {
// Define input
let input = ["One two three four"];
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let tokenized_input = tokenizer.encode_list(&input, 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())

View File

@ -45,8 +45,7 @@ fn bert_masked_lm() -> anyhow::Result<()> {
"Looks like one thing is missing",
"It\'s like comparing oranges to apples",
];
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let tokenized_input = tokenizer.encode_list(&input, 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
@ -134,8 +133,7 @@ fn bert_for_sequence_classification() -> anyhow::Result<()> {
"Looks like one thing is missing",
"It\'s like comparing oranges to apples",
];
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let tokenized_input = tokenizer.encode_list(&input, 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
@ -194,8 +192,7 @@ fn bert_for_multiple_choice() -> anyhow::Result<()> {
"Looks like one thing is missing",
"It\'s like comparing oranges to apples",
];
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let tokenized_input = tokenizer.encode_list(&input, 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
@ -261,8 +258,7 @@ fn bert_for_token_classification() -> anyhow::Result<()> {
"Looks like one thing is missing",
"It\'s like comparing oranges to apples",
];
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let tokenized_input = tokenizer.encode_list(&input, 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
@ -321,8 +317,7 @@ fn bert_for_question_answering() -> anyhow::Result<()> {
"Looks like one thing is missing",
"It\'s like comparing oranges to apples",
];
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let tokenized_input = tokenizer.encode_list(&input, 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())

View File

@ -69,8 +69,7 @@ fn distilbert_masked_lm() -> anyhow::Result<()> {
"Looks like one thing is missing",
"It\'s like comparing oranges to apples",
];
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let tokenized_input = tokenizer.encode_list(&input, 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
@ -148,8 +147,7 @@ fn distilbert_for_question_answering() -> anyhow::Result<()> {
"Looks like one thing is missing",
"It\'s like comparing oranges to apples",
];
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let tokenized_input = tokenizer.encode_list(&input, 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
@ -220,8 +218,7 @@ fn distilbert_for_token_classification() -> anyhow::Result<()> {
"Looks like one thing is missing",
"It\'s like comparing oranges to apples",
];
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let tokenized_input = tokenizer.encode_list(&input, 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())

View File

@ -42,8 +42,7 @@ fn distilgpt2_lm_model() -> anyhow::Result<()> {
// Define input
let input = ["One two three four five six seven eight nine ten eleven"];
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let tokenized_input = tokenizer.encode_list(&input, 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())

View File

@ -40,8 +40,7 @@ fn electra_masked_lm() -> anyhow::Result<()> {
"Looks like one [MASK] is missing",
"It was a very nice and [MASK] day",
];
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let tokenized_input = tokenizer.encode_list(&input, 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
@ -120,8 +119,7 @@ fn electra_discriminator() -> anyhow::Result<()> {
// Define input
let input = ["One Two Three Ten Five Six Seven Eight"];
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let tokenized_input = tokenizer.encode_list(&input, 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())

View File

@ -43,8 +43,7 @@ fn gpt2_lm_model() -> anyhow::Result<()> {
// Define input
let input = ["One two three four"];
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let tokenized_input = tokenizer.encode_list(&input, 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())

View File

@ -45,8 +45,7 @@ fn openai_gpt_lm_model() -> anyhow::Result<()> {
// Define input
let input = ["Wondering what the next word will"];
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let tokenized_input = tokenizer.encode_list(&input, 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())

View File

@ -55,8 +55,7 @@ fn roberta_masked_lm() -> anyhow::Result<()> {
"<pad> Looks like one thing is missing",
"It\'s like comparing oranges to apples",
];
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let tokenized_input = tokenizer.encode_list(&input, 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
@ -154,8 +153,7 @@ fn roberta_for_sequence_classification() -> anyhow::Result<()> {
"Looks like one thing is missing",
"It\'s like comparing oranges to apples",
];
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let tokenized_input = tokenizer.encode_list(&input, 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
@ -224,8 +222,7 @@ fn roberta_for_multiple_choice() -> anyhow::Result<()> {
"Looks like one thing is missing",
"It\'s like comparing oranges to apples",
];
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let tokenized_input = tokenizer.encode_list(&input, 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
@ -301,8 +298,7 @@ fn roberta_for_token_classification() -> anyhow::Result<()> {
"Looks like one thing is missing",
"It\'s like comparing oranges to apples",
];
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let tokenized_input = tokenizer.encode_list(&input, 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())

View File

@ -40,8 +40,7 @@ fn xlnet_base_model() -> anyhow::Result<()> {
// Define input
let input = ["One two three four"];
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let tokenized_input = tokenizer.encode_list(&input, 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
@ -146,8 +145,7 @@ fn xlnet_lm_model() -> anyhow::Result<()> {
// Define input
let input = ["One two three four"];
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let tokenized_input = tokenizer.encode_list(&input, 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
@ -264,8 +262,7 @@ fn xlnet_for_sequence_classification() -> anyhow::Result<()> {
// Define input
let input = ["Very positive sentence", "Second sentence input"];
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let tokenized_input = tokenizer.encode_list(&input, 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
@ -332,7 +329,10 @@ fn xlnet_for_multiple_choice() -> anyhow::Result<()> {
let prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced.";
let inputs = ["Very positive sentence", "Second sentence input"];
let tokenized_input = tokenizer.encode_pair_list(
inputs.iter().map(|&inp| (prompt, inp)).collect(),
inputs
.iter()
.map(|&inp| (prompt, inp))
.collect::<Vec<(&str, &str)>>(),
128,
&TruncationStrategy::LongestFirst,
0,
@ -401,8 +401,7 @@ fn xlnet_for_token_classification() -> anyhow::Result<()> {
// Define input
let inputs = ["Where's Paris?", "In Kentucky, United States"];
let tokenized_input =
tokenizer.encode_list(inputs.into(), 128, &TruncationStrategy::LongestFirst, 0);
let tokenized_input = tokenizer.encode_list(&inputs, 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
@ -460,7 +459,7 @@ fn xlnet_for_question_answering() -> anyhow::Result<()> {
// Define input
let inputs = ["Where's Paris?", "Paris is in In Kentucky, United States"];
let tokenized_input = tokenizer.encode_pair_list(
vec![(inputs[0], inputs[1])],
&[(inputs[0], inputs[1])],
128,
&TruncationStrategy::LongestFirst,
0,