diff --git a/CHANGELOG.md b/CHANGELOG.md index 1826fa7..b0a649a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,10 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). ## [Unreleased] +## Changed +- Updated to `tch` 0.6.1 (libtorch 1.10) +- (BREAKING) Simplified the generics for multiple library traits, which now take `&[AsRef<str>]` or `&str` as inputs as a rule (owned types such as `Vec` and `String` are no longer accepted) + ## Added - (BREAKING) Support for `bad_word_ids` generation, allowing a set of word ids to be banned for all models supporting text generation - Support for half-precision mode for all models (reducing memory footprint). A model can be converted to half-precision by calling the `half()` method on the `VarStore` it is currently stored in. Half-precision Torch kernels are not available for CPU (limited to CUDA devices) diff --git a/Cargo.toml b/Cargo.toml index b6322a1..2df498c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -57,21 +57,21 @@ all-tests = [] features = ["doc-only"] [dependencies] -rust_tokenizers = "~6.2.4" -tch = { version = "0.5.0", path = "E:/Coding/tch-rs" } -serde_json = "1.0.66" -serde = { version = "1.0.129", features = ["derive"] } -dirs = "3.0.2" -ordered-float = "2.7.0" +rust_tokenizers = "~7.0.0" +tch = "~0.6.1" +serde_json = "1.0.68" +serde = { version = "1.0.130", features = ["derive"] } +dirs = "4.0.0" +ordered-float = "2.8.0" cached-path = "0.5.1" lazy_static = "1.4.0" uuid = { version = "0.8.2", features = ["v4"] } -thiserror = "1.0.26" +thiserror = "1.0.30" half = "1.7.1" [dev-dependencies] -anyhow = "1.0.43" +anyhow = "1.0.44" csv = "1.1.6" criterion = "0.3.5" -torch-sys = { version = "0.5.0", path = "E:/Coding/tch-rs/torch-sys" } +torch-sys = "~0.6.1" tempfile = "3.2.0" diff --git a/README.md b/README.md index ef16df1..a51c294 100644 --- a/README.md +++ b/README.md @@ -71,8 +71,8 @@ This cache location defaults to `~/.cache/.rustbert`, but can be changed by sett ### Manual installation (recommended) -1. Download `libtorch` from https://pytorch.org/get-started/locally/. This package requires `v1.9.0`: if this version is no longer available on the "get started" page, -the file should be accessible by modifying the target link, for example `https://download.pytorch.org/libtorch/cu111/libtorch-shared-with-deps-1.9.0%2Bcu111.zip` for a Linux version with CUDA11. +1. Download `libtorch` from https://pytorch.org/get-started/locally/. This package requires `v1.10.0`: if this version is no longer available on the "get started" page, +the file should be accessible by modifying the target link, for example `https://download.pytorch.org/libtorch/cu111/libtorch-shared-with-deps-1.10.0%2Bcu111.zip` for a Linux version with CUDA11. 2. Extract the library to a location of your choice 3.
Set the following environment variables ##### Linux: diff --git a/examples/translation_m2m100.rs b/examples/translation_m2m100.rs index 74c0a65..032a762 100644 --- a/examples/translation_m2m100.rs +++ b/examples/translation_m2m100.rs @@ -53,9 +53,9 @@ fn main() -> anyhow::Result<()> { let source_sentence = "This sentence will be translated in multiple languages."; let mut outputs = Vec::new(); - outputs.extend(model.translate([source_sentence], Language::English, Language::French)?); - outputs.extend(model.translate([source_sentence], Language::English, Language::Spanish)?); - outputs.extend(model.translate([source_sentence], Language::English, Language::Hindi)?); + outputs.extend(model.translate(&[source_sentence], Language::English, Language::French)?); + outputs.extend(model.translate(&[source_sentence], Language::English, Language::Spanish)?); + outputs.extend(model.translate(&[source_sentence], Language::English, Language::Hindi)?); for sentence in outputs { println!("{}", sentence); diff --git a/examples/translation_mbart.rs b/examples/translation_mbart.rs index aea5a11..73960e0 100644 --- a/examples/translation_mbart.rs +++ b/examples/translation_mbart.rs @@ -53,9 +53,9 @@ fn main() -> anyhow::Result<()> { let source_sentence = "This sentence will be translated in multiple languages."; let mut outputs = Vec::new(); - outputs.extend(model.translate([source_sentence], Language::English, Language::French)?); - outputs.extend(model.translate([source_sentence], Language::English, Language::Spanish)?); - outputs.extend(model.translate([source_sentence], Language::English, Language::Hindi)?); + outputs.extend(model.translate(&[source_sentence], Language::English, Language::French)?); + outputs.extend(model.translate(&[source_sentence], Language::English, Language::Spanish)?); + outputs.extend(model.translate(&[source_sentence], Language::English, Language::Hindi)?); for sentence in outputs { println!("{}", sentence); diff --git a/examples/translation_t5.rs b/examples/translation_t5.rs index b2efeb8..083412c 100644 --- a/examples/translation_t5.rs +++ b/examples/translation_t5.rs @@ -56,9 +56,9 @@ fn main() -> anyhow::Result<()> { let source_sentence = "This sentence will be translated in multiple languages."; let mut outputs = Vec::new(); - outputs.extend(model.translate([source_sentence], Language::English, Language::French)?); - outputs.extend(model.translate([source_sentence], Language::English, Language::German)?); - outputs.extend(model.translate([source_sentence], Language::English, Language::Romanian)?); + outputs.extend(model.translate(&[source_sentence], Language::English, Language::French)?); + outputs.extend(model.translate(&[source_sentence], Language::English, Language::German)?); + outputs.extend(model.translate(&[source_sentence], Language::English, Language::Romanian)?); for sentence in outputs { println!("{}", sentence); diff --git a/src/bart/bart_model.rs b/src/bart/bart_model.rs index e79bc3b..b9fd2a9 100644 --- a/src/bart/bart_model.rs +++ b/src/bart/bart_model.rs @@ -1205,17 +1205,17 @@ impl PrivateLanguageGenerator - fn encode_prompt_text<'a, S>( + fn encode_prompt_text<S>( &self, - prompt_text: S, + prompt_text: &[S], max_len: i64, pad_token_id: Option<i64>, ) -> Tensor where - S: AsRef<[&'a str]>, + S: AsRef<str> + Sync, { let tokens = self._get_tokenizer().encode_list( - prompt_text.as_ref(), + prompt_text, max_len as usize, &TruncationStrategy::LongestFirst, 0, diff --git a/src/common/activations.rs b/src/common/activations.rs index ab7f29a..3deedf6 100644 --- a/src/common/activations.rs +++ 
b/src/common/activations.rs @@ -19,7 +19,7 @@ pub fn _mish(x: &Tensor) -> Tensor { } pub fn _gelu_new(x: &Tensor) -> Tensor { - x * 0.5 * (((x.pow(3.0f64) * 0.044715 + x) * ((2f64 / PI).sqrt())).tanh() + 1) + x * 0.5 * (((x.pow_tensor_scalar(3.0f64) * 0.044715 + x) * ((2f64 / PI).sqrt())).tanh() + 1) } pub fn _tanh(x: &Tensor) -> Tensor { diff --git a/src/lib.rs b/src/lib.rs index 52fd534..3605ddd 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -81,8 +81,8 @@ //! //! ### Manual installation (recommended) //! -//! 1. Download `libtorch` from <https://pytorch.org/get-started/locally/>. This package requires `v1.8.1`: if this version is no longer available on the "get started" page, -//! the file should be accessible by modifying the target link, for example `https://download.pytorch.org/libtorch/cu111/libtorch-shared-with-deps-1.8.1%2Bcu111.zip` for a Linux version with CUDA11. +//! 1. Download `libtorch` from <https://pytorch.org/get-started/locally/>. This package requires `v1.10.0`: if this version is no longer available on the "get started" page, +//! the file should be accessible by modifying the target link, for example `https://download.pytorch.org/libtorch/cu111/libtorch-shared-with-deps-1.10.0%2Bcu111.zip` for a Linux version with CUDA11. //! 2. Extract the library to a location of your choice //! 3. Set the following environment variables //! ##### Linux: diff --git a/src/m2m_100/m2m_100_model.rs b/src/m2m_100/m2m_100_model.rs index 7d876d0..1077f6d 100644 --- a/src/m2m_100/m2m_100_model.rs +++ b/src/m2m_100/m2m_100_model.rs @@ -801,17 +801,17 @@ impl PrivateLanguageGenerator - fn encode_prompt_text<'a, S>( + fn encode_prompt_text<S>( &self, - prompt_text: S, + prompt_text: &[S], max_len: i64, pad_token_id: Option<i64>, ) -> Tensor where - S: AsRef<[&'a str]>, + S: AsRef<str> + Sync, { let tokens = self._get_tokenizer().encode_list( - prompt_text.as_ref(), + prompt_text, max_len as usize, &TruncationStrategy::LongestFirst, 0, diff --git a/src/marian/marian_model.rs b/src/marian/marian_model.rs index d506b9f..c273d66 100644 --- a/src/marian/marian_model.rs +++ b/src/marian/marian_model.rs @@ -979,17 +979,17 @@ impl PrivateLanguageGenerator - fn encode_prompt_text<'a, T>( + fn encode_prompt_text<S>( &self, - prompt_text: T, + prompt_text: &[S], max_len: i64, pad_token_id: Option<i64>, ) -> Tensor where - T: AsRef<[&'a str]>, + S: AsRef<str> + Sync, { let tokens = self._get_tokenizer().encode_list( - prompt_text.as_ref(), + prompt_text, max_len as usize, &TruncationStrategy::LongestFirst, 0, diff --git a/src/mbart/mbart_model.rs b/src/mbart/mbart_model.rs index a9e40ac..5a91a79 100644 --- a/src/mbart/mbart_model.rs +++ b/src/mbart/mbart_model.rs @@ -1011,17 +1011,17 @@ impl PrivateLanguageGenerator - fn encode_prompt_text<'a, S>( + fn encode_prompt_text<S>( &self, - prompt_text: S, + prompt_text: &[S], max_len: i64, pad_token_id: Option<i64>, ) -> Tensor where - S: AsRef<[&'a str]>, + S: AsRef<str> + Sync, { let tokens = self._get_tokenizer().encode_list( - prompt_text.as_ref(), + prompt_text, max_len as usize, &TruncationStrategy::LongestFirst, 0, diff --git a/src/pegasus/pegasus_model.rs b/src/pegasus/pegasus_model.rs index 5c6bfea..e9eb873 100644 --- a/src/pegasus/pegasus_model.rs +++ b/src/pegasus/pegasus_model.rs @@ -774,17 +774,17 @@ impl PrivateLanguageGenerator - fn encode_prompt_text<'a, S>( + fn encode_prompt_text<S>( &self, - prompt_text: S, + prompt_text: &[S], max_len: i64, pad_token_id: Option<i64>, ) -> Tensor where - S: AsRef<[&'a str]>, + S: AsRef<str> + Sync, { let tokens = self._get_tokenizer().encode_list( - prompt_text.as_ref(), + prompt_text, max_len as usize, &TruncationStrategy::LongestFirst, 0, diff --git a/src/pipelines/common.rs b/src/pipelines/common.rs index fe70201..b2dd0c8 100644 --- a/src/pipelines/common.rs +++ 
b/src/pipelines/common.rs @@ -494,13 +494,16 @@ impl TokenizerOption { } /// Interface method - pub fn encode_list( + pub fn encode_list<S>( &self, - text_list: &[&str], + text_list: &[S], max_len: usize, truncation_strategy: &TruncationStrategy, stride: usize, - ) -> Vec<TokenizedInput> { + ) -> Vec<TokenizedInput> + where + S: AsRef<str> + Sync, + { match *self { Self::Bert(ref tokenizer) => MultiThreadedTokenizer::encode_list( tokenizer, @@ -809,7 +812,10 @@ impl TokenizerOption { } /// Interface method to tokenization - pub fn tokenize_list(&self, text: &[&str]) -> Vec<Vec<String>> { + pub fn tokenize_list<S>(&self, text: &[S]) -> Vec<Vec<String>> + where + S: AsRef<str> + Sync, + { match *self { Self::Bert(ref tokenizer) => MultiThreadedTokenizer::tokenize_list(tokenizer, text), Self::Roberta(ref tokenizer) => MultiThreadedTokenizer::tokenize_list(tokenizer, text), @@ -837,7 +843,7 @@ impl TokenizerOption { /// Interface method to decoding pub fn decode( &self, - token_ids: Vec<i64>, + token_ids: &[i64], skip_special_tokens: bool, clean_up_tokenization_spaces: bool, ) -> String { @@ -964,10 +970,9 @@ impl TokenizerOption { } /// Interface method to convert tokens to ids - pub fn convert_tokens_to_ids<S, ST>(&self, tokens: S) -> Vec<i64> + pub fn convert_tokens_to_ids<S>(&self, tokens: &[S]) -> Vec<i64> where - S: AsRef<[ST]>, - ST: AsRef<str>, + S: AsRef<str>, { match *self { Self::Bert(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens), diff --git a/src/pipelines/conversation.rs b/src/pipelines/conversation.rs index 0a4050b..77b5f42 100644 --- a/src/pipelines/conversation.rs +++ b/src/pipelines/conversation.rs @@ -416,22 +416,20 @@ impl Conversation { /// let _ = conversation_manager /// .get(&conversation_1_id) /// .unwrap() - /// .load_from_history(history, encoded_history); + /// .load_from_history(&history, &encoded_history); /// # Ok(()) /// # } /// ``` - pub fn load_from_history<ST, SI, STR, SIN>(&mut self, texts: ST, ids: SI) + pub fn load_from_history<S, SI>(&mut self, texts: &[S], ids: &[SI]) where - ST: AsRef<[STR]>, - SI: AsRef<[SIN]>, - STR: AsRef<str>, - SIN: AsRef<[i64]>, + S: AsRef<str>, + SI: AsRef<[i64]>, { - for (round_text, round_ids) in texts.as_ref().iter().zip(ids.as_ref().iter()) { + for (round_text, round_ids) in texts.iter().zip(ids.iter()) { self.append(round_text.as_ref(), round_ids.as_ref()); } - if texts.as_ref().len() / 2 == 1 { + if texts.len() / 2 == 1 { self.history.pop(); } } @@ -840,11 +838,11 @@ impl ConversationModel { let generated_response = &generated_sequence[input_length - removed_padding.0..]; conversation .generated_responses - .push(self.model.get_tokenizer().decode( - generated_response.to_vec(), - true, - true, - )); + .push( + self.model + .get_tokenizer() + .decode(generated_response, true, true), + ); conversation.history.push(conversation_promp_ids); conversation.history.push(generated_response.to_vec()); conversation.mark_processed(); diff --git a/src/pipelines/generation_utils.rs b/src/pipelines/generation_utils.rs index 0650da6..f1e73ac 100644 --- a/src/pipelines/generation_utils.rs +++ b/src/pipelines/generation_utils.rs @@ -41,7 +41,7 @@ //! let input_context = "The dog"; //! let second_input_context = "The cat was"; //! let output = gpt2_generator.generate( -//! Some(vec![input_context, second_input_context]), +//! Some(&[input_context, second_input_context]), //! None, //! min_length, //! 
max_length, @@ -321,16 +321,16 @@ pub(crate) mod private_generation_utils { } } - fn encode_prompt_text<'a, S>( + fn encode_prompt_text<S>( &self, - prompt_text: S, + prompt_text: &[S], max_len: i64, pad_token_id: Option<i64>, ) -> Tensor where - S: AsRef<[&'a str]>, + S: AsRef<str> + Sync, { - let tokens = self._get_tokenizer().tokenize_list(prompt_text.as_ref()); + let tokens = self._get_tokenizer().tokenize_list(prompt_text); let token_ids = tokens .into_iter() .map(|prompt_tokens| self._get_tokenizer().convert_tokens_to_ids(&prompt_tokens)) .collect::<Vec<Vec<i64>>>(); @@ -934,7 +934,7 @@ pub(crate) mod private_generation_utils { current_length += 1; } let scores_output = scores_output.map(|scores_tensor| { - (scores_tensor / sentence_lengths.pow(gen_opt.length_penalty)) + (scores_tensor / sentence_lengths.pow_tensor_scalar(gen_opt.length_penalty)) .iter::<f64>() .unwrap() .collect::<Vec<f64>>() @@ -1504,7 +1504,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>: /// } /// /// let output = gpt2_generator.generate( - /// Some(vec![input_context, second_input_context]), + /// Some(&[input_context, second_input_context]), /// attention_mask, /// min_length, /// max_length, @@ -1530,9 +1530,9 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>: /// ] /// # ; /// ``` - fn generate<'a, S>( + fn generate<S>( &self, - prompt_texts: Option<S>, + prompt_texts: Option<&[S]>, attention_mask: Option<Tensor>, min_length: impl Into<Option<i64>>, max_length: impl Into<Option<i64>>, @@ -1543,7 +1543,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>: output_scores: bool, ) -> Vec<GeneratedTextOutput> where - S: AsRef<[&'a str]>, + S: AsRef<str> + Sync, { let indices_outputs = self.generate_indices( prompt_texts, @@ -1561,7 +1561,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>: output.push(GeneratedTextOutput { text: self ._get_tokenizer() - .decode(generated_sequence.indices, true, true), + .decode(&generated_sequence.indices, true, true), score: generated_sequence.score, }); } @@ -1636,7 +1636,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>: /// } /// /// let output = gpt2_generator.generate_indices( - /// Some(vec![input_context, second_input_context]), + /// Some(&[input_context, second_input_context]), /// attention_mask, /// min_length, /// max_length, @@ -1649,9 +1649,9 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>: /// # Ok(()) /// # } /// ``` - fn generate_indices<'a, S>( + fn generate_indices<S>( &self, - prompt_texts: Option<S>, + prompt_texts: Option<&[S]>, attention_mask: Option<Tensor>, min_length: impl Into<Option<i64>>, max_length: impl Into<Option<i64>>, @@ -1662,7 +1662,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>: output_scores: bool, ) -> Vec<GeneratedIndicesOutput> where - S: AsRef<[&'a str]>, + S: AsRef<str> + Sync, { let eos_token_ids = PrivateLanguageGenerator::get_eos_ids(self).clone(); @@ -1771,7 +1771,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>: /// } /// /// let output = gpt2_generator.generate_indices( - /// Some(vec![input_context, second_input_context]), + /// Some(&[input_context, second_input_context]), /// attention_mask, /// min_length, /// max_length, diff --git a/src/pipelines/ner.rs b/src/pipelines/ner.rs index e3f52a5..6e07c9f 100644 --- a/src/pipelines/ner.rs +++ b/src/pipelines/ner.rs @@ -191,9 +191,9 @@ impl NERModel { /// # Ok(()) /// # } /// ``` - pub fn predict<'a, S>(&self, input: S) -> Vec<Vec<Entity>> + pub fn predict<S>(&self, input: &[S]) -> Vec<Vec<Entity>> where - S: AsRef<[&'a str]>, + S: AsRef<str>, { self.token_classification_model .predict(input, true, false) diff --git a/src/pipelines/pos_tagging.rs b/src/pipelines/pos_tagging.rs index 339b4ef..6d9fb22 100644 --- a/src/pipelines/pos_tagging.rs +++ b/src/pipelines/pos_tagging.rs @@ -195,9 +195,9 @@ impl POSModel { /// # Ok(()) /// # } /// ``` - pub fn predict<'a, S>(&self, input: S) -> Vec<Vec<POSTag>> + pub fn predict<S>(&self, input: &[S]) 
-> Vec<Vec<POSTag>> where - S: AsRef<[&'a str]>, + S: AsRef<str>, { self.token_classification_model .predict(input, true, false) diff --git a/src/pipelines/summarization.rs b/src/pipelines/summarization.rs index 4f5114e..460c49b 100644 --- a/src/pipelines/summarization.rs +++ b/src/pipelines/summarization.rs @@ -254,13 +254,13 @@ impl SummarizationOption { } /// Interface method to generate() of the particular models. - pub fn generate<'a, S>( + pub fn generate<S>( &self, - prompt_texts: Option<S>, + prompt_texts: Option<&[S]>, attention_mask: Option<Tensor>, ) -> Vec<String> where - S: AsRef<[&'a str]>, + S: AsRef<str> + Sync, { match *self { Self::Bart(ref model) => model @@ -406,22 +406,18 @@ impl SummarizationModel { /// # } /// ``` /// (New sample credits: [WikiNews](https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b)) - pub fn summarize<'a, S>(&self, texts: S) -> Vec<String> + pub fn summarize<S>(&self, texts: &[S]) -> Vec<String> where - S: AsRef<[&'a str]>, + S: AsRef<str> + Sync, { match &self.prefix { None => self.model.generate(Some(texts), None), Some(prefix) => { let texts = texts - .as_ref() .iter() - .map(|text| format!("{}{}", prefix, text)) + .map(|text| format!("{}{}", prefix, text.as_ref())) .collect::<Vec<String>>(); - self.model.generate( - Some(texts.iter().map(|x| &**x).collect::<Vec<&str>>()), - None, - ) + self.model.generate(Some(&texts), None) } } } diff --git a/src/pipelines/text_generation.rs b/src/pipelines/text_generation.rs index 83d6338..f87af0f 100644 --- a/src/pipelines/text_generation.rs +++ b/src/pipelines/text_generation.rs @@ -238,15 +238,15 @@ impl TextGenerationOption { } /// Interface method to generate() of the particular models. - pub fn generate_indices<'a, S>( + pub fn generate_indices<S>( &self, - prompt_texts: Option<S>, + prompt_texts: Option<&[S]>, attention_mask: Option<Tensor>, min_length: Option<i64>, max_length: Option<i64>, ) -> Vec<Vec<i64>> where - S: AsRef<[&'a str]>, + S: AsRef<str> + Sync, { match *self { Self::GPT(ref model) => model @@ -460,9 +460,9 @@ with people, even a bishop, begging for his blessing. " /// # Ok(()) /// # } /// ``` - pub fn generate<'a, S>(&self, texts: S, prefix: impl Into<Option<&'a str>>) -> Vec<String> + pub fn generate<'a, S>(&self, texts: &[S], prefix: impl Into<Option<&'a str>>) -> Vec<String> where - S: AsRef<[&'a str]>, + S: AsRef<str> + Sync, { let (prefix, prefix_length) = match (prefix.into(), &self.prefix) { (Some(query_prefix), _) => ( @@ -478,10 +478,10 @@ with people, even a bishop, begging for his blessing. " let texts = texts .as_ref() .iter() - .map(|text| format!("{} {}", prefix, text)) + .map(|text| format!("{} {}", prefix, text.as_ref())) .collect::<Vec<String>>(); self.model.generate_indices( - Some(texts.iter().map(|x| &**x).collect::<Vec<&str>>()), + Some(&texts), None, Some(self.min_length + prefix_length), Some(self.max_length + prefix_length), @@ -493,14 +493,7 @@ with people, even a bishop, begging for his blessing. 
" let mut output = Vec::with_capacity(generated_indices.len()); for generated_sequence in generated_indices { output.push(self.model.get_tokenizer().decode( - if prefix_length.is_some() { - generated_sequence - .into_iter() - .skip(prefix_length.unwrap_or(0) as usize) - .collect::>() - } else { - generated_sequence - }, + &generated_sequence[prefix_length.unwrap_or(0) as usize..], true, true, )); diff --git a/src/pipelines/token_classification.rs b/src/pipelines/token_classification.rs index f712f9c..1cf0fcd 100644 --- a/src/pipelines/token_classification.rs +++ b/src/pipelines/token_classification.rs @@ -776,17 +776,16 @@ impl TokenClassificationModel { /// # Ok(()) /// # } /// ``` - pub fn predict<'a, S>( + pub fn predict( &self, - input: S, + input: &[S], consolidate_sub_tokens: bool, return_special: bool, ) -> Vec> where - S: AsRef<[&'a str]>, + S: AsRef, { let mut features: Vec = input - .as_ref() .iter() .enumerate() .map(|(example_index, example)| self.generate_features(example, example_index)) @@ -794,7 +793,7 @@ impl TokenClassificationModel { .collect(); let mut example_tokens_map: HashMap> = HashMap::new(); - for example_idx in 0..input.as_ref().len() { + for example_idx in 0..input.len() { example_tokens_map.insert(example_idx, Vec::new()); } let mut start = 0usize; @@ -820,7 +819,8 @@ impl TokenClassificationModel { let labels = label_indices.get(sentence_idx); let feature = &features[sentence_idx as usize]; let sentence_reference_flag = &feature.reference_feature; - let original_chars = input.as_ref()[feature.example_index] + let original_chars = input[feature.example_index] + .as_ref() .chars() .collect::>(); let mut word_idx: u16 = 0; @@ -935,19 +935,19 @@ impl TokenClassificationModel { let text = match offsets { None => match self.tokenizer { TokenizerOption::Bert(ref tokenizer) => { - Tokenizer::decode(tokenizer, vec![token_id], false, false) + Tokenizer::decode(tokenizer, &[token_id], false, false) } TokenizerOption::Roberta(ref tokenizer) => { - Tokenizer::decode(tokenizer, vec![token_id], false, false) + Tokenizer::decode(tokenizer, &[token_id], false, false) } TokenizerOption::XLMRoberta(ref tokenizer) => { - Tokenizer::decode(tokenizer, vec![token_id], false, false) + Tokenizer::decode(tokenizer, &[token_id], false, false) } TokenizerOption::Albert(ref tokenizer) => { - Tokenizer::decode(tokenizer, vec![token_id], false, false) + Tokenizer::decode(tokenizer, &[token_id], false, false) } TokenizerOption::XLNet(ref tokenizer) => { - Tokenizer::decode(tokenizer, vec![token_id], false, false) + Tokenizer::decode(tokenizer, &[token_id], false, false) } _ => panic!( "Token classification not implemented for {:?}!", diff --git a/src/pipelines/translation/mod.rs b/src/pipelines/translation/mod.rs index 41a0732..7d648e1 100644 --- a/src/pipelines/translation/mod.rs +++ b/src/pipelines/translation/mod.rs @@ -55,9 +55,13 @@ //! let source_sentence = "This sentence will be translated in multiple languages."; //! //! let mut outputs = Vec::new(); -//! outputs.extend(model.translate([source_sentence], Language::English, Language::French)?); -//! outputs.extend(model.translate([source_sentence], Language::English, Language::Spanish)?); -//! outputs.extend(model.translate([source_sentence], Language::English, Language::Hindi)?); +//! outputs.extend(model.translate(&[source_sentence], Language::English, Language::French)?); +//! outputs.extend(model.translate( +//! &[source_sentence], +//! Language::English, +//! Language::Spanish, +//! )?); +//! 
outputs.extend(model.translate(&[source_sentence], Language::English, Language::Hindi)?); //! //! for sentence in outputs { //! println!("{}", sentence); diff --git a/src/pipelines/translation/translation_builder.rs b/src/pipelines/translation/translation_builder.rs index f0cf4b7..359a2b6 100644 --- a/src/pipelines/translation/translation_builder.rs +++ b/src/pipelines/translation/translation_builder.rs @@ -364,7 +364,7 @@ impl TranslationModelBuilder { { match (source_languages.as_slice(), target_languages.as_slice()) { ([Language::English], [Language::German]) => { - get_marian_resources!(ENGLISH2RUSSIAN) + get_marian_resources!(ENGLISH2GERMAN) } ([Language::English], [Language::Russian]) => { get_marian_resources!(ENGLISH2RUSSIAN) diff --git a/src/pipelines/translation/translation_pipeline.rs b/src/pipelines/translation/translation_pipeline.rs index 42741c1..9e50f29 100644 --- a/src/pipelines/translation/translation_pipeline.rs +++ b/src/pipelines/translation/translation_pipeline.rs @@ -586,8 +586,7 @@ impl TranslationOption { if !supported_source_languages.contains(source_language) { return Err(RustBertError::ValueError(format!( "{} not in list of supported languages: {:?}", - source_language.to_string(), - supported_source_languages + source_language, supported_source_languages ))); } } @@ -596,8 +595,7 @@ impl TranslationOption { if !supported_target_languages.contains(target_language) { return Err(RustBertError::ValueError(format!( "{} not in list of supported languages: {:?}", - target_language.to_string(), - supported_target_languages + target_language, supported_target_languages ))); } } @@ -630,7 +628,7 @@ impl TranslationOption { Some(format!( "translate {} to {}:", match source_language { - Some(value) => value.to_string(), + Some(value) => value, None => { return Err(RustBertError::ValueError( "Missing source language for T5".to_string(), @@ -638,7 +636,7 @@ impl TranslationOption { } }, match target_language { - Some(value) => value.to_string(), + Some(value) => value, None => { return Err(RustBertError::ValueError( "Missing target language for T5".to_string(), @@ -665,7 +663,7 @@ impl TranslationOption { )), if let Some(target_language) = target_language { Some( - model._get_tokenizer().convert_tokens_to_ids([format!( + model._get_tokenizer().convert_tokens_to_ids(&[format!( ">>{}<<", target_language.get_iso_639_1_code() )])[0], @@ -705,9 +703,8 @@ impl TranslationOption { if let Some(target_language) = target_language { let language_code = target_language.get_iso_639_1_code(); Some( - model - ._get_tokenizer() - .convert_tokens_to_ids([match language_code.len() { + model._get_tokenizer().convert_tokens_to_ids(&[ + match language_code.len() { 2 => format!(">>{}.<<", language_code), 3 => format!(">>{}<<", language_code), _ => { @@ -715,7 +712,8 @@ impl TranslationOption { "Invalid ISO 639-3 code".to_string(), )); } - }])[0], + }, + ])[0], ) } else { return Err(RustBertError::ValueError(format!( @@ -730,14 +728,14 @@ impl TranslationOption { } /// Interface method to generate() of the particular models. 
- pub fn generate<'a, S>( + pub fn generate<S>( &self, - prompt_texts: Option<S>, + prompt_texts: Option<&[S]>, attention_mask: Option<Tensor>, forced_bos_token_id: Option<i64>, ) -> Vec<String> where - S: AsRef<[&'a str]>, + S: AsRef<str> + Sync, { match *self { Self::Marian(ref model) => model @@ -927,14 +925,14 @@ impl TranslationModel { /// # Ok(()) /// # } /// ``` - pub fn translate<'a, S>( + pub fn translate<S>( &self, - texts: S, + texts: &[S], source_language: impl Into<Option<Language>>, target_language: impl Into<Option<Language>>, ) -> Result<Vec<String>, RustBertError> where - S: AsRef<[&'a str]>, + S: AsRef<str> + Sync, { let (prefix, forced_bos_token_id) = self.model.validate_and_get_prefix_and_forced_bos_id( source_language.into().as_ref(), @@ -946,15 +944,10 @@ impl TranslationModel { Ok(match prefix { Some(value) => { let texts = texts - .as_ref() .iter() - .map(|&v| format!("{}{}", value, v)) + .map(|v| format!("{}{}", value, v.as_ref())) .collect::<Vec<String>>(); - self.model.generate( - Some(texts.iter().map(AsRef::as_ref).collect::<Vec<&str>>()), - None, - forced_bos_token_id, - ) + self.model.generate(Some(&texts), None, forced_bos_token_id) } None => self.model.generate(Some(texts), None, forced_bos_token_id), }) diff --git a/src/prophetnet/prophetnet_model.rs b/src/prophetnet/prophetnet_model.rs index 02a26d2..16e58e1 100644 --- a/src/prophetnet/prophetnet_model.rs +++ b/src/prophetnet/prophetnet_model.rs @@ -1063,17 +1063,17 @@ impl } } - fn encode_prompt_text<'a, S>( + fn encode_prompt_text<S>( &self, - prompt_text: S, + prompt_text: &[S], max_len: i64, pad_token_id: Option<i64>, ) -> Tensor where - S: AsRef<[&'a str]>, + S: AsRef<str> + Sync, { let tokens = self._get_tokenizer().encode_list( - prompt_text.as_ref(), + prompt_text, max_len as usize, &TruncationStrategy::LongestFirst, 0, diff --git a/src/t5/layer_norm.rs b/src/t5/layer_norm.rs index 9c4c902..7f82f9b 100644 --- a/src/t5/layer_norm.rs +++ b/src/t5/layer_norm.rs @@ -33,10 +33,10 @@ impl T5LayerNorm { impl Module for T5LayerNorm { fn forward(&self, x: &Tensor) -> Tensor { let input_type = x.kind(); - let variance = x - .to_kind(Kind::Float) - .pow(2.0_f64) - .mean_dim(&[-1], true, Kind::Float); + let variance = + x.to_kind(Kind::Float) .pow_tensor_scalar(2.0_f64) .mean_dim(&[-1], true, Kind::Float); let x = x * (variance + self.epsilon).rsqrt(); if input_type != Kind::Float { (&self.weight * x).to_kind(input_type) diff --git a/src/t5/t5_model.rs b/src/t5/t5_model.rs index 068e477..3f28254 100644 --- a/src/t5/t5_model.rs +++ b/src/t5/t5_model.rs @@ -858,17 +858,17 @@ impl PrivateLanguageGenerator } } - fn encode_prompt_text<'a, S>( + fn encode_prompt_text<S>( &self, - prompt_text: S, + prompt_text: &[S], max_len: i64, pad_token_id: Option<i64>, ) -> Tensor where - S: AsRef<[&'a str]>, + S: AsRef<str> + Sync, { let tokens = self._get_tokenizer().encode_list( - prompt_text.as_ref(), + prompt_text, max_len as usize, &TruncationStrategy::LongestFirst, 0, diff --git a/src/xlnet/attention.rs b/src/xlnet/attention.rs index 438713a..dde3536 100644 --- a/src/xlnet/attention.rs +++ b/src/xlnet/attention.rs @@ -238,37 +238,36 @@ impl XLNetRelativeAttention { target_mapping: Option<&Tensor>, train: bool, ) -> (Tensor, Option<Tensor>, Option<Tensor>, Option<Tensor>) { - if let Some(g) = g { - let cat_value = if let Some(mems) = &layer_state { - if mems.prev_content.size().len() > 1 { - Some(Tensor::cat(&[&mems.prev_content, h], 0)) - } else { - None - } + let cat_value = if let Some(mems) = &layer_state { + if mems.prev_content.size().len() > 1 { + Some(Tensor::cat(&[&mems.prev_content, h], 0)) } else { None - }; - let cat = match &cat_value { - Some(value) => value, 
- None => h, - }; + } + } else { + None + }; + let cat = match &cat_value { + Some(value) => value, + None => h, + }; + let q_head_h = Tensor::einsum("ibh,hnd->ibnd", &[h, &self.query]); + let k_head_h = Tensor::einsum("ibh,hnd->ibnd", &[cat, &self.key]); + let v_head_h = Tensor::einsum("ibh,hnd->ibnd", &[cat, &self.value]); + let k_head_r = Tensor::einsum("ibh,hnd->ibnd", &[r, &self.pos]); - let q_head_h = Tensor::einsum("ibh,hnd->ibnd", &[h, &self.query]); - let k_head_h = Tensor::einsum("ibh,hnd->ibnd", &[cat, &self.key]); - let v_head_h = Tensor::einsum("ibh,hnd->ibnd", &[cat, &self.value]); - let k_head_r = Tensor::einsum("ibh,hnd->ibnd", &[r, &self.pos]); + let (attention_vec_h, attention_probas_h) = self.rel_attention_core( + &q_head_h, + &k_head_h, + &v_head_h, + &k_head_r, + seg_mat, + attn_mask_h, + train, + ); + let output_h = self.post_attention(h, &attention_vec_h, true, train); - let (attention_vec_h, attention_probas_h) = self.rel_attention_core( - &q_head_h, - &k_head_h, - &v_head_h, - &k_head_r, - seg_mat, - attn_mask_h, - train, - ); - - let output_h = self.post_attention(h, &attention_vec_h, true, train); + let (output_g, attention_probas_g) = if let Some(g) = g { let q_head_g = Tensor::einsum("ibh,hnd->ibnd", &[g, &self.query]); let (attention_vec_g, attention_probas_g) = match target_mapping { @@ -299,44 +298,10 @@ impl XLNetRelativeAttention { }; let output_g = self.post_attention(g, &attention_vec_g, true, train); - ( - output_h, - Some(output_g), - attention_probas_h, - attention_probas_g, - ) + (Some(output_g), attention_probas_g) } else { - let cat_value = if let Some(mems) = &layer_state { - if mems.prev_content.size().len() > 1 { - Some(Tensor::cat(&[&mems.prev_content, h], 0)) - } else { - None - } - } else { - None - }; - let cat = match &cat_value { - Some(value) => value, - None => h, - }; - - let q_head_h = Tensor::einsum("ibh,hnd->ibnd", &[h, &self.query]); - let k_head_h = Tensor::einsum("ibh,hnd->ibnd", &[cat, &self.key]); - let v_head_h = Tensor::einsum("ibh,hnd->ibnd", &[cat, &self.value]); - let k_head_r = Tensor::einsum("ibh,hnd->ibnd", &[r, &self.pos]); - - let (attention_vec, attention_probas) = self.rel_attention_core( - &q_head_h, - &k_head_h, - &v_head_h, - &k_head_r, - seg_mat, - attn_mask_h, - train, - ); - - let output_h = self.post_attention(h, &attention_vec, true, train); - (output_h, None, attention_probas, None) - } + (None, None) + }; + (output_h, output_g, attention_probas_h, attention_probas_g) } } diff --git a/tests/distilgpt2.rs b/tests/distilgpt2.rs index 0e60df4..6d3f166 100644 --- a/tests/distilgpt2.rs +++ b/tests/distilgpt2.rs @@ -80,7 +80,7 @@ fn distilgpt2_lm_model() -> anyhow::Result<()> { .get(-1) .argmax(-1, true) .int64_value(&[0]); - let next_word = tokenizer.decode(vec![next_word_id], true, true); + let next_word = tokenizer.decode(&[next_word_id], true, true); assert_eq!(model_output.lm_logits.size(), vec!(1, 11, 50257)); match model_output.cache { diff --git a/tests/gpt2.rs b/tests/gpt2.rs index da6f74c..08861d5 100644 --- a/tests/gpt2.rs +++ b/tests/gpt2.rs @@ -82,7 +82,7 @@ fn gpt2_lm_model() -> anyhow::Result<()> { .get(-1) .argmax(-1, true) .int64_value(&[0]); - let next_word = tokenizer.decode(vec![next_word_id], true, true); + let next_word = tokenizer.decode(&[next_word_id], true, true); assert_eq!(model_output.lm_logits.size(), vec!(1, 4, 50257)); match model_output.cache { diff --git a/tests/gpt_neo.rs b/tests/gpt_neo.rs index f712622..362a394 100644 --- a/tests/gpt_neo.rs +++ b/tests/gpt_neo.rs @@ -72,7 +72,7 
@@ fn gpt_neo_lm() -> anyhow::Result<()> { .get(-1) .argmax(-1, true) .int64_value(&[0]); - let next_word = tokenizer.decode(vec![next_word_id], true, true); + let next_word = tokenizer.decode(&[next_word_id], true, true); let next_score = model_output .lm_logits .get(0) diff --git a/tests/longformer.rs b/tests/longformer.rs index 32c4203..eb18d08 100644 --- a/tests/longformer.rs +++ b/tests/longformer.rs @@ -274,7 +274,7 @@ fn longformer_for_multiple_choice() -> anyhow::Result<()> { let prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."; let inputs = ["Very positive sentence", "Second sentence input"]; let tokenized_input = tokenizer.encode_pair_list( - inputs + &inputs .iter() .map(|&inp| (prompt, inp)) .collect::<Vec<(&str, &str)>>(), diff --git a/tests/m2m100.rs b/tests/m2m100.rs index db01b6d..d047715 100644 --- a/tests/m2m100.rs +++ b/tests/m2m100.rs @@ -107,9 +107,9 @@ fn m2m100_translation() -> anyhow::Result<()> { let source_sentence = "This sentence will be translated in multiple languages."; let mut outputs = Vec::new(); - outputs.extend(model.translate([source_sentence], Language::English, Language::French)?); - outputs.extend(model.translate([source_sentence], Language::English, Language::Spanish)?); - outputs.extend(model.translate([source_sentence], Language::English, Language::Hindi)?); + outputs.extend(model.translate(&[source_sentence], Language::English, Language::French)?); + outputs.extend(model.translate(&[source_sentence], Language::English, Language::Spanish)?); + outputs.extend(model.translate(&[source_sentence], Language::English, Language::Hindi)?); assert_eq!(outputs.len(), 3); assert_eq!( diff --git a/tests/mbart.rs b/tests/mbart.rs index f851f7e..4b6a809 100644 --- a/tests/mbart.rs +++ b/tests/mbart.rs @@ -73,9 +73,9 @@ fn mbart_translation() -> anyhow::Result<()> { let source_sentence = "This sentence will be translated in multiple languages."; let mut outputs = Vec::new(); - outputs.extend(model.translate([source_sentence], Language::English, Language::French)?); - outputs.extend(model.translate([source_sentence], Language::English, Language::Spanish)?); - outputs.extend(model.translate([source_sentence], Language::English, Language::Hindi)?); + outputs.extend(model.translate(&[source_sentence], Language::English, Language::French)?); + outputs.extend(model.translate(&[source_sentence], Language::English, Language::Spanish)?); + outputs.extend(model.translate(&[source_sentence], Language::English, Language::Hindi)?); assert_eq!(outputs.len(), 3); assert_eq!( diff --git a/tests/mobilebert.rs b/tests/mobilebert.rs index c7370fb..0fbc758 100644 --- a/tests/mobilebert.rs +++ b/tests/mobilebert.rs @@ -182,7 +182,7 @@ fn mobilebert_for_multiple_choice() -> anyhow::Result<()> { let prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."; let inputs = ["Very positive sentence", "Second sentence input"]; let tokenized_input = tokenizer.encode_pair_list( - inputs + &inputs .iter() .map(|&inp| (prompt, inp)) .collect::<Vec<(&str, &str)>>(), diff --git a/tests/openai_gpt.rs b/tests/openai_gpt.rs index 5d69010..1d1ec31 100644 --- a/tests/openai_gpt.rs +++ b/tests/openai_gpt.rs @@ -83,7 +83,7 @@ fn openai_gpt_lm_model() -> anyhow::Result<()> { .get(-1) .argmax(-1, true) .int64_value(&[0]); - let next_word = tokenizer.decode(vec![next_word_id], true, true); + let next_word = tokenizer.decode(&[next_word_id], true, true); assert_eq!(model_output.lm_logits.size(), vec!(1, 6, 40478)); assert!( diff --git 
a/tests/t5.rs b/tests/t5.rs index 0acedfd..b851371 100644 --- a/tests/t5.rs +++ b/tests/t5.rs @@ -44,9 +44,9 @@ fn test_translation_t5() -> anyhow::Result<()> { let source_sentence = "This sentence will be translated in multiple languages."; let mut outputs = Vec::new(); - outputs.extend(model.translate([source_sentence], Language::English, Language::French)?); - outputs.extend(model.translate([source_sentence], Language::English, Language::German)?); - outputs.extend(model.translate([source_sentence], Language::English, Language::Romanian)?); + outputs.extend(model.translate(&[source_sentence], Language::English, Language::French)?); + outputs.extend(model.translate(&[source_sentence], Language::English, Language::German)?); + outputs.extend(model.translate(&[source_sentence], Language::English, Language::Romanian)?); assert_eq!(outputs.len(), 3); assert_eq!( diff --git a/tests/xlnet.rs b/tests/xlnet.rs index b434a0d..eaafb81 100644 --- a/tests/xlnet.rs +++ b/tests/xlnet.rs @@ -331,7 +331,7 @@ fn xlnet_for_multiple_choice() -> anyhow::Result<()> { let prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."; let inputs = ["Very positive sentence", "Second sentence input"]; let tokenized_input = tokenizer.encode_pair_list( - inputs + &inputs .iter() .map(|&inp| (prompt, inp)) .collect::<Vec<(&str, &str)>>(),
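For reference, a minimal usage sketch of the simplified input generics described in the changelog entry above. It is illustrative only and not part of the patch: the `TranslationModel` value is assumed to be already constructed (as in the examples changed above), and any setup producing a `TranslationModel` works here.

```rust
use rust_bert::pipelines::translation::{Language, TranslationModel};

// Inputs are now passed as a slice of `AsRef<str>` items, so both `&[&str]`
// and `&[String]` are accepted; owned `Vec`/`String` arguments no longer are.
fn translate_examples(model: &TranslationModel) -> anyhow::Result<()> {
    let borrowed = ["This sentence will be translated in multiple languages."];
    let owned: Vec<String> = borrowed.iter().map(|s| s.to_string()).collect();

    let mut outputs = Vec::new();
    // &[&str; 1] coerces to &[&str]
    outputs.extend(model.translate(&borrowed, Language::English, Language::French)?);
    // &Vec<String> coerces to &[String]
    outputs.extend(model.translate(&owned, Language::English, Language::Spanish)?);

    for sentence in outputs {
        println!("{}", sentence);
    }
    Ok(())
}
```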