diff --git a/Cargo.toml b/Cargo.toml index 59c66a9..d9383f3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rust-bert" -version = "0.10.0" +version = "0.11.0" authors = ["Guillaume Becquin <guillaume.becquin@gmail.com>"] edition = "2018" description = "Ready-to-use NLP pipelines and transformer-based models (BERT, DistilBERT, GPT2,...)" diff --git a/README.md b/README.md index 099aaa1..37fe2ef 100644 --- a/README.md +++ b/README.md @@ -10,17 +10,17 @@ This repository exposes the model base architecture, task-specific heads (see be The following models are currently implemented: - | |**DistilBERT**|**BERT**|**RoBERTa**|**GPT**|**GPT2**|**BART**|**Electra**|**Marian**|**ALBERT**|**T5**| -:-----:|:----:|:----:|:-----:|:----:|:-----:|:----:|:----:|:----:|:----:|:----: -Masked LM|✅ |✅ |✅ | | | |✅| |✅ | | -Sequence classification|✅ |✅ |✅| | |✅ | | |✅ | | -Token classification|✅ |✅ | ✅| | | |✅| |✅ | | -Question answering|✅ |✅ |✅| | | | | |✅ | | -Multiple choices| |✅ |✅| | | | | |✅ | | -Next token prediction| | | |✅|✅|✅| | | | | -Natural Language Generation| | | |✅|✅|✅| | | | | -Summarization | | | | | |✅| | | | | -Translation | | | | | |✅| |✅ | |✅| + | |**DistilBERT**|**BERT**|**RoBERTa**|**GPT**|**GPT2**|**BART**|**Electra**|**Marian**|**ALBERT**|**T5**|**XLNet**| +:-----:|:----:|:----:|:-----:|:----:|:-----:|:----:|:----:|:----:|:----:|:----:|:----: +Masked LM|✅ |✅ |✅ | | | |✅| |✅ | |✅ | +Sequence classification|✅ |✅ |✅| | |✅ | | |✅ | |✅ | +Token classification|✅ |✅ | ✅| | | |✅| |✅ | |✅ | +Question answering|✅ |✅ |✅| | | | | |✅ | |✅ | +Multiple choices| |✅ |✅| | | | | |✅ | |✅ | +Next token prediction| | | |✅|✅|✅| | | | |✅ | +Natural Language Generation| | | |✅|✅|✅| | | | |✅ | +Summarization | | | | | |✅| | | | | | +Translation | | | | | |✅| |✅ | |✅| | ## Ready-to-use pipelines diff --git a/examples/bart.rs b/examples/bart.rs index 14467f4..83e5dd1 100644 --- a/examples/bart.rs +++ b/examples/bart.rs @@ -54,7 +54,7 @@ fn main() -> anyhow::Result<()> { // Credits: 
WikiNews, CC BY 2.5 license (https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b) - let tokenized_input = tokenizer.encode_list(&input, 1024, &TruncationStrategy::LongestFirst, 0); + let tokenized_input = tokenizer.encode_list(input, 1024, &TruncationStrategy::LongestFirst, 0); let max_len = tokenized_input .iter() .map(|input| input.token_ids.len()) diff --git a/src/lib.rs b/src/lib.rs index 1249f36..d95f95a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -31,17 +31,17 @@ //! ``` //! - Transformer models base architectures with customized heads. These allow to load pre-trained models for customized inference in Rust //! -//! | |**DistilBERT**|**BERT**|**RoBERTa**|**GPT**|**GPT2**|**BART**|**Electra**|**Marian**|**ALBERT**|**T5** -//! :-----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----: -//! Masked LM|✅ |✅ |✅ | | | |✅| |✅ | | -//! Sequence classification|✅ |✅ |✅| | |✅| | |✅ | | -//! Token classification|✅ |✅ | ✅| | | |✅| |✅ | | -//! Question answering|✅ |✅ |✅| | | | | |✅ | | -//! Multiple choices| |✅ |✅| | | | | |✅ | | -//! Next token prediction| | | |✅|✅| | | | | | -//! Natural Language Generation| | | |✅|✅| | | | | | -//! Summarization| | | | | |✅| | | | | -//! Translation| | | | | | | |✅| |✅| +//! | |**DistilBERT**|**BERT**|**RoBERTa**|**GPT**|**GPT2**|**BART**|**Electra**|**Marian**|**ALBERT**|**T5**|**XLNet** +//! :-----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----: +//! Masked LM|✅ |✅ |✅ | | | |✅| |✅ | |✅| +//! Sequence classification|✅ |✅ |✅| | |✅| | |✅ | |✅| +//! Token classification|✅ |✅ | ✅| | | |✅| |✅ | |✅| +//! Question answering|✅ |✅ |✅| | | | | |✅ | |✅| +//! Multiple choices| |✅ |✅| | | | | |✅ | |✅| +//! Next token prediction| | | |✅|✅| | | | | |✅| +//! Natural Language Generation| | | |✅|✅| | | | | |✅| +//! Summarization| | | | | |✅| | | | | | +//! Translation| | | | | | | |✅| |✅| | //! //! # Loading pre-trained models //! 
diff --git a/src/pipelines/common.rs b/src/pipelines/common.rs index e8bb61d..a9ca198 100644 --- a/src/pipelines/common.rs +++ b/src/pipelines/common.rs @@ -448,13 +448,13 @@ impl TokenizerOption { /// Interface method to convert tokens to ids pub fn convert_tokens_to_ids(&self, tokens: &[String]) -> Vec<i64> { match *self { - Self::Bert(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens.into()), - Self::Roberta(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens.into()), - Self::Marian(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens.into()), - Self::T5(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens.into()), - Self::XLMRoberta(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens.into()), - Self::Albert(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens.into()), - Self::XLNet(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens.into()), + Self::Bert(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens), + Self::Roberta(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens), + Self::Marian(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens), + Self::T5(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens), + Self::XLMRoberta(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens), + Self::Albert(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens), + Self::XLNet(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens), } } diff --git a/src/pipelines/ner.rs b/src/pipelines/ner.rs index 1cbc696..3f121c0 100644 --- a/src/pipelines/ner.rs +++ b/src/pipelines/ner.rs @@ -188,7 +188,10 @@ impl NERModel { /// # Ok(()) /// # } /// ``` - pub fn predict(&self, input: &[&str]) -> Vec<Entity> { + pub fn predict<'a, S>(&self, input: S) -> Vec<Entity> + where + S: AsRef<[&'a str]>, + { self.token_classification_model .predict(input, true, false) .into_iter() diff --git a/src/pipelines/zero_shot_classification.rs b/src/pipelines/zero_shot_classification.rs index 1c17711..c4d5103 100644 --- a/src/pipelines/zero_shot_classification.rs 
+++ b/src/pipelines/zero_shot_classification.rs @@ -465,22 +465,32 @@ impl ZeroShotClassificationModel { }) } - fn prepare_for_model( + fn prepare_for_model<'a, S, T>( &self, - inputs: &[&str], - labels: &[&str], + inputs: S, + labels: T, template: Option<Box<dyn Fn(&str) -> String>>, max_len: usize, - ) -> (Tensor, Tensor) { + ) -> (Tensor, Tensor) + where + S: AsRef<[&'a str]>, + T: AsRef<[&'a str]>, + { let label_sentences: Vec<String> = match template { - Some(function) => labels.iter().map(|label| function(label)).collect(), + Some(function) => labels + .as_ref() + .iter() + .map(|label| function(label)) + .collect(), None => labels + .as_ref() .iter() .map(|label| format!("This example is about {}.", label)) .collect(), }; let text_pair_list = inputs + .as_ref() .iter() .cartesian_product(label_sentences.iter()) .map(|(&s, label)| (s, label.as_str())) @@ -577,15 +587,20 @@ impl ZeroShotClassificationModel { /// ] /// .to_vec(); /// ``` - pub fn predict<'a, S, T>( + pub fn predict<'a, S, T>( &self, - inputs: &[&str], - labels: &[&str], + inputs: S, + labels: T, template: Option<Box<dyn Fn(&str) -> String>>, max_length: usize, - ) -> Vec<Label>