diff --git a/CHANGELOG.md b/CHANGELOG.md index 2d31c0a..ea7938f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,7 +3,9 @@ All notable changes to this project will be documented in this file. The format ## [Unreleased] ## Added -- (BREAKING) Support for `prefix_allowed_tokens_fn` arguments for generation, allowing users to control the generation via custom functions +- (BREAKING) Support for `prefix_allowed_tokens_fn` argument for generation, allowing users to control the generation via custom functions +- (BREAKING) Support for `forced_bos_token_id` argument for generation, allowing users to force a given BOS token for generation (useful for MBart/M2M-class models) +- Addition of the MBart Language model and support for text generation / direct translation between 50 language ## [0.15.1] - 2021-06-01 ### Fixed diff --git a/README.md b/README.md index 70fb6ce..a04c929 100644 --- a/README.md +++ b/README.md @@ -46,8 +46,9 @@ RoBERTa|✅|✅|✅| | | |✅| GPT| | | |✅ | | | | GPT2| | | |✅ | | | | GPT-Neo| | | |✅ | | | | -BART|✅| | |✅ |✅| | | +BART|✅| | |✅ |✅| | | Marian| | | | | |✅| | +MBart|✅| | |✅ | | | | Electra | |✅| | | | |✅| ALBERT |✅|✅|✅| | | |✅| T5 | | | |✅ |✅|✅| | diff --git a/examples/albert.rs b/examples/albert.rs deleted file mode 100644 index 7321398..0000000 --- a/examples/albert.rs +++ /dev/null @@ -1,97 +0,0 @@ -// Copyright 2018 Google AI and Google Brain team. -// Copyright 2020-present, the HuggingFace Inc. team. -// Copyright 2020 Guillaume Becquin -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// http://www.apache.org/licenses/LICENSE-2.0 -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -extern crate anyhow; - -use rust_bert::albert::{ - AlbertConfig, AlbertConfigResources, AlbertForMaskedLM, AlbertModelResources, - AlbertVocabResources, -}; -use rust_bert::resources::{RemoteResource, Resource}; -use rust_bert::Config; -use rust_tokenizers::tokenizer::{AlbertTokenizer, Tokenizer, TruncationStrategy}; -use rust_tokenizers::vocab::Vocab; -use tch::{nn, no_grad, Device, Tensor}; - -fn main() -> anyhow::Result<()> { - // Resources paths - let config_resource = Resource::Remote(RemoteResource::from_pretrained( - AlbertConfigResources::ALBERT_BASE_V2, - )); - let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( - AlbertVocabResources::ALBERT_BASE_V2, - )); - let weights_resource = Resource::Remote(RemoteResource::from_pretrained( - AlbertModelResources::ALBERT_BASE_V2, - )); - let config_path = config_resource.get_local_path()?; - let vocab_path = vocab_resource.get_local_path()?; - let weights_path = weights_resource.get_local_path()?; - - // Set-up masked LM model - let device = Device::Cpu; - let mut vs = nn::VarStore::new(device); - let tokenizer: AlbertTokenizer = - AlbertTokenizer::from_file(vocab_path.to_str().unwrap(), true, false)?; - let config = AlbertConfig::from_file(config_path); - let albert_model = AlbertForMaskedLM::new(&vs.root(), &config); - vs.load(weights_path)?; - - // Define input - let input = [ - "Looks like one [MASK] is missing", - "It was a very nice and [MASK] day", - ]; - let tokenized_input = tokenizer.encode_list(&input, 128, &TruncationStrategy::LongestFirst, 0); - let max_len = tokenized_input - .iter() - .map(|input| input.token_ids.len()) - .max() - .unwrap(); - let tokenized_input = tokenized_input - .iter() - .map(|input| input.token_ids.clone()) - .map(|mut input| { - input.extend(vec![0; max_len - input.len()]); - input - }) - .map(|input| Tensor::of_slice(&(input))) - .collect::>(); - let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device); - - // Forward pass - let model_output = - no_grad(|| albert_model.forward_t(Some(input_tensor), None, None, None, None, false)); - println!( - "{:?}", - model_output.prediction_scores.double_value(&[0, 0, 0]) - ); - // Print masked tokens - let index_1 = model_output - .prediction_scores - .get(0) - .get(4) - .argmax(0, false); - let index_2 = model_output - .prediction_scores - .get(1) - .get(7) - .argmax(0, false); - let word_1 = tokenizer.vocab().id_to_token(&index_1.int64_value(&[])); - let word_2 = tokenizer.vocab().id_to_token(&index_2.int64_value(&[])); - - println!("{} - {}", &index_1.int64_value(&[]), word_1); // Outputs "_them" : "Looks like one [them] is missing" - println!("{} - {}", &index_2.int64_value(&[]), word_2); // Outputs "_enjoyable" : "It was a very nice and [enjoyable] day" - - Ok(()) -} diff --git a/examples/bart.rs b/examples/bart.rs deleted file mode 100644 index 5a89cbb..0000000 --- a/examples/bart.rs +++ /dev/null @@ -1,81 +0,0 @@ -// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc. -// Copyright 2019 Guillaume Becquin -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// http://www.apache.org/licenses/LICENSE-2.0 -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -extern crate anyhow; - -use rust_bert::bart::{ - BartConfig, BartConfigResources, BartMergesResources, BartModel, BartModelResources, - BartVocabResources, -}; -use rust_bert::resources::{RemoteResource, Resource}; -use rust_bert::Config; -use rust_tokenizers::tokenizer::{RobertaTokenizer, Tokenizer, TruncationStrategy}; -use tch::{nn, no_grad, Device, Tensor}; - -fn main() -> anyhow::Result<()> { - // Resources paths - let config_resource = - Resource::Remote(RemoteResource::from_pretrained(BartConfigResources::BART)); - let vocab_resource = - Resource::Remote(RemoteResource::from_pretrained(BartVocabResources::BART)); - let merges_resource = - Resource::Remote(RemoteResource::from_pretrained(BartMergesResources::BART)); - let weights_resource = - Resource::Remote(RemoteResource::from_pretrained(BartModelResources::BART)); - let config_path = config_resource.get_local_path()?; - let vocab_path = vocab_resource.get_local_path()?; - let merges_path = merges_resource.get_local_path()?; - let weights_path = weights_resource.get_local_path()?; - - // Set-up masked LM model - let device = Device::cuda_if_available(); - let mut vs = nn::VarStore::new(device); - let tokenizer = RobertaTokenizer::from_file( - vocab_path.to_str().unwrap(), - merges_path.to_str().unwrap(), - false, - false, - )?; - let config = BartConfig::from_file(config_path); - let bart_model = BartModel::new(&vs.root(), &config); - vs.load(weights_path)?; - - // Define input - let input = ["One two three four"]; - - let tokenized_input = tokenizer.encode_list(input, 1024, &TruncationStrategy::LongestFirst, 0); - let max_len = tokenized_input - .iter() - .map(|input| input.token_ids.len()) - .max() - .unwrap(); - let tokenized_input = tokenized_input - .iter() - .map(|input| input.token_ids.clone()) - .map(|mut input| { - input.extend(vec![0; max_len - input.len()]); - input - }) - .map(|input| Tensor::of_slice(&(input))) - .collect::>(); - let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device); - - // Forward pass - let model_output = - no_grad(|| bart_model.forward_t(Some(&input_tensor), None, None, None, None, None, false)); - - // Print masked tokens - println!("{:?}", model_output.encoder_hidden_state); - println!("{:?}", model_output.decoder_output); - println!("{:?}", model_output.decoder_output.double_value(&[0, 0, 0])); - Ok(()) -} diff --git a/examples/distilbert_masked_lm.rs b/examples/distilbert_masked_lm.rs deleted file mode 100644 index 2de332f..0000000 --- a/examples/distilbert_masked_lm.rs +++ /dev/null @@ -1,102 +0,0 @@ -// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc. -// Copyright 2019 Guillaume Becquin -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// http://www.apache.org/licenses/LICENSE-2.0 -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -extern crate anyhow; - -use rust_bert::distilbert::{ - DistilBertConfig, DistilBertConfigResources, DistilBertModelMaskedLM, DistilBertModelResources, - DistilBertVocabResources, -}; -use rust_bert::resources::{RemoteResource, Resource}; -use rust_bert::Config; -use rust_tokenizers::tokenizer::{BertTokenizer, MultiThreadedTokenizer, TruncationStrategy}; -use rust_tokenizers::vocab::Vocab; -use tch::{nn, no_grad, Device, Tensor}; - -fn main() -> anyhow::Result<()> { - // Resources paths - let config_resource = Resource::Remote(RemoteResource::from_pretrained( - DistilBertConfigResources::DISTIL_BERT, - )); - let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( - DistilBertVocabResources::DISTIL_BERT, - )); - let weights_resource = Resource::Remote(RemoteResource::from_pretrained( - DistilBertModelResources::DISTIL_BERT, - )); - let config_path = config_resource.get_local_path()?; - let vocab_path = vocab_resource.get_local_path()?; - let weights_path = weights_resource.get_local_path()?; - - // Set-up masked LM model - let device = Device::Cpu; - let mut vs = nn::VarStore::new(device); - let tokenizer: BertTokenizer = - BertTokenizer::from_file(vocab_path.to_str().unwrap(), true, true)?; - let config = DistilBertConfig::from_file(config_path); - let distil_bert_model = DistilBertModelMaskedLM::new(&vs.root(), &config); - vs.load(weights_path)?; - - // Define input - let input = [ - "Looks like one thing is missing", - "It\'s like comparing oranges to apples", - ]; - let tokenized_input = tokenizer.encode_list(input, 128, &TruncationStrategy::LongestFirst, 0); - let max_len = tokenized_input - .iter() - .map(|input| input.token_ids.len()) - .max() - .unwrap(); - let mut tokenized_input = tokenized_input - .iter() - .map(|input| input.token_ids.clone()) - .map(|mut input| { - input.extend(vec![0; max_len - input.len()]); - input - }) - .collect::>(); - - // Masking the token [thing] of sentence 1 and [oranges] of sentence 2 - tokenized_input[0][4] = 103; - tokenized_input[1][6] = 103; - let tokenized_input = tokenized_input - .iter() - .map(|input| Tensor::of_slice(&(input))) - .collect::>(); - let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device); - - // Forward pass - let model_output = no_grad(|| { - distil_bert_model - .forward_t(Some(input_tensor), None, None, false) - .unwrap() - }); - - // Print masked tokens - let index_1 = model_output - .prediction_scores - .get(0) - .get(4) - .argmax(0, false); - let index_2 = model_output - .prediction_scores - .get(1) - .get(6) - .argmax(0, false); - let word_1 = tokenizer.vocab().id_to_token(&index_1.int64_value(&[])); - let word_2 = tokenizer.vocab().id_to_token(&index_2.int64_value(&[])); - - println!("{}", word_1); // Outputs "person" : "Looks like one [person] is missing" - println!("{}", word_2); // Outputs "pear" : "It\'s like comparing [pear] to apples" - - Ok(()) -} diff --git a/examples/download_all_dependencies.rs b/examples/download_all_dependencies.rs deleted file mode 100644 index 1854fd2..0000000 --- a/examples/download_all_dependencies.rs +++ /dev/null @@ -1,406 +0,0 @@ -extern crate anyhow; - -use rust_bert::albert::{AlbertConfigResources, AlbertModelResources, AlbertVocabResources}; -use rust_bert::bart::{ - BartConfigResources, BartMergesResources, BartModelResources, BartVocabResources, -}; -use rust_bert::bert::{BertConfigResources, BertModelResources, BertVocabResources}; -use rust_bert::distilbert::{ - DistilBertConfigResources, DistilBertModelResources, DistilBertVocabResources, -}; -use rust_bert::electra::{ElectraConfigResources, ElectraModelResources, ElectraVocabResources}; -use rust_bert::gpt2::{ - Gpt2ConfigResources, Gpt2MergesResources, Gpt2ModelResources, Gpt2VocabResources, -}; -use rust_bert::openai_gpt::{ - OpenAiGptConfigResources, OpenAiGptMergesResources, OpenAiGptModelResources, - OpenAiGptVocabResources, -}; -use rust_bert::resources::{RemoteResource, Resource}; -use rust_bert::roberta::{ - RobertaConfigResources, RobertaMergesResources, RobertaModelResources, RobertaVocabResources, -}; -use rust_bert::t5::{T5ConfigResources, T5ModelResources, T5VocabResources}; -use rust_bert::xlnet::{XLNetConfigResources, XLNetModelResources, XLNetVocabResources}; - -/// This example downloads and caches all dependencies used in model tests. This allows for safe -/// multi threaded testing (two test using the same resource would otherwise download the file to -/// the same location). - -fn download_distil_gpt2() -> anyhow::Result<()> { - // Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format. - let config_resource = Resource::Remote(RemoteResource::from_pretrained( - Gpt2ConfigResources::DISTIL_GPT2, - )); - let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( - Gpt2VocabResources::DISTIL_GPT2, - )); - let merges_resource = Resource::Remote(RemoteResource::from_pretrained( - Gpt2MergesResources::DISTIL_GPT2, - )); - let weights_resource = Resource::Remote(RemoteResource::from_pretrained( - Gpt2ModelResources::DISTIL_GPT2, - )); - let _ = config_resource.get_local_path()?; - let _ = vocab_resource.get_local_path()?; - let _ = merges_resource.get_local_path()?; - let _ = weights_resource.get_local_path()?; - Ok(()) -} - -fn download_distilbert_sst2() -> anyhow::Result<()> { - // Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format. - let weights_resource = Resource::Remote(RemoteResource::from_pretrained( - DistilBertModelResources::DISTIL_BERT_SST2, - )); - let config_resource = Resource::Remote(RemoteResource::from_pretrained( - DistilBertConfigResources::DISTIL_BERT_SST2, - )); - let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( - DistilBertVocabResources::DISTIL_BERT_SST2, - )); - let _ = config_resource.get_local_path()?; - let _ = vocab_resource.get_local_path()?; - let _ = weights_resource.get_local_path()?; - Ok(()) -} - -fn download_distilbert_qa() -> anyhow::Result<()> { - // Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format. - let weights_resource = Resource::Remote(RemoteResource::from_pretrained( - DistilBertModelResources::DISTIL_BERT_SQUAD, - )); - let config_resource = Resource::Remote(RemoteResource::from_pretrained( - DistilBertConfigResources::DISTIL_BERT_SQUAD, - )); - let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( - DistilBertVocabResources::DISTIL_BERT_SQUAD, - )); - let _ = config_resource.get_local_path()?; - let _ = vocab_resource.get_local_path()?; - let _ = weights_resource.get_local_path()?; - Ok(()) -} - -fn download_distilbert() -> anyhow::Result<()> { - // Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format. - let weights_resource = Resource::Remote(RemoteResource::from_pretrained( - DistilBertModelResources::DISTIL_BERT, - )); - let config_resource = Resource::Remote(RemoteResource::from_pretrained( - DistilBertConfigResources::DISTIL_BERT, - )); - let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( - DistilBertVocabResources::DISTIL_BERT, - )); - let _ = config_resource.get_local_path()?; - let _ = vocab_resource.get_local_path()?; - let _ = weights_resource.get_local_path()?; - Ok(()) -} - -fn download_gpt2() -> anyhow::Result<()> { - // Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2. Modified with conversion to C-array format. - let config_resource = - Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2)); - let vocab_resource = - Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2)); - let merges_resource = - Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2)); - let weights_resource = - Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2)); - let _ = config_resource.get_local_path()?; - let _ = vocab_resource.get_local_path()?; - let _ = merges_resource.get_local_path()?; - let _ = weights_resource.get_local_path()?; - Ok(()) -} - -fn download_gpt() -> anyhow::Result<()> { - // Shared under MIT license by the OpenAI team at https://github.com/openai/finetune-transformer-lm. Modified with conversion to C-array format. - let config_resource = Resource::Remote(RemoteResource::from_pretrained( - OpenAiGptConfigResources::GPT, - )); - let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( - OpenAiGptVocabResources::GPT, - )); - let merges_resource = Resource::Remote(RemoteResource::from_pretrained( - OpenAiGptMergesResources::GPT, - )); - let weights_resource = Resource::Remote(RemoteResource::from_pretrained( - OpenAiGptModelResources::GPT, - )); - let _ = config_resource.get_local_path()?; - let _ = vocab_resource.get_local_path()?; - let _ = merges_resource.get_local_path()?; - let _ = weights_resource.get_local_path()?; - Ok(()) -} - -fn download_roberta() -> anyhow::Result<()> { - // Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format. - let config_resource = Resource::Remote(RemoteResource::from_pretrained( - RobertaConfigResources::ROBERTA, - )); - let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( - RobertaVocabResources::ROBERTA, - )); - let merges_resource = Resource::Remote(RemoteResource::from_pretrained( - RobertaMergesResources::ROBERTA, - )); - let weights_resource = Resource::Remote(RemoteResource::from_pretrained( - RobertaModelResources::ROBERTA, - )); - let _ = config_resource.get_local_path()?; - let _ = vocab_resource.get_local_path()?; - let _ = merges_resource.get_local_path()?; - let _ = weights_resource.get_local_path()?; - Ok(()) -} - -fn download_bert() -> anyhow::Result<()> { - // Shared under Apache 2.0 license by the Google team at https://github.com/google-research/bert. Modified with conversion to C-array format. - let config_resource = - Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT)); - let vocab_resource = - Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT)); - let weights_resource = - Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT)); - let _ = config_resource.get_local_path()?; - let _ = vocab_resource.get_local_path()?; - let _ = weights_resource.get_local_path()?; - Ok(()) -} - -fn download_bert_ner() -> anyhow::Result<()> { - // Shared under MIT license by the MDZ Digital Library team at the Bavarian State Library at https://github.com/dbmdz/berts. Modified with conversion to C-array format. - let config_resource = Resource::Remote(RemoteResource::from_pretrained( - BertConfigResources::BERT_NER, - )); - let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( - BertVocabResources::BERT_NER, - )); - let weights_resource = Resource::Remote(RemoteResource::from_pretrained( - BertModelResources::BERT_NER, - )); - let _ = config_resource.get_local_path()?; - let _ = vocab_resource.get_local_path()?; - let _ = weights_resource.get_local_path()?; - Ok(()) -} - -fn download_bart() -> anyhow::Result<()> { - // Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format. - let config_resource = - Resource::Remote(RemoteResource::from_pretrained(BartConfigResources::BART)); - let vocab_resource = - Resource::Remote(RemoteResource::from_pretrained(BartVocabResources::BART)); - let merges_resource = - Resource::Remote(RemoteResource::from_pretrained(BartMergesResources::BART)); - let weights_resource = - Resource::Remote(RemoteResource::from_pretrained(BartModelResources::BART)); - let _ = config_resource.get_local_path()?; - let _ = vocab_resource.get_local_path()?; - let _ = merges_resource.get_local_path()?; - let _ = weights_resource.get_local_path()?; - Ok(()) -} - -fn download_bart_cnn() -> anyhow::Result<()> { - // Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format. - let config_resource = Resource::Remote(RemoteResource::from_pretrained( - BartConfigResources::BART_CNN, - )); - let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( - BartVocabResources::BART_CNN, - )); - let merges_resource = Resource::Remote(RemoteResource::from_pretrained( - BartMergesResources::BART_CNN, - )); - let weights_resource = Resource::Remote(RemoteResource::from_pretrained( - BartModelResources::BART_CNN, - )); - let _ = config_resource.get_local_path()?; - let _ = vocab_resource.get_local_path()?; - let _ = merges_resource.get_local_path()?; - let _ = weights_resource.get_local_path()?; - Ok(()) -} - -fn download_electra_generator() -> anyhow::Result<()> { - // Shared under Apache 2.0 license by the Google team at https://github.com/google-research/electra. Modified with conversion to C-array format. - let config_resource = Resource::Remote(RemoteResource::from_pretrained( - ElectraConfigResources::BASE_GENERATOR, - )); - let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( - ElectraVocabResources::BASE_GENERATOR, - )); - let weights_resource = Resource::Remote(RemoteResource::from_pretrained( - ElectraModelResources::BASE_GENERATOR, - )); - let _ = config_resource.get_local_path()?; - let _ = vocab_resource.get_local_path()?; - let _ = weights_resource.get_local_path()?; - Ok(()) -} - -fn download_electra_discriminator() -> anyhow::Result<()> { - // Shared under Apache 2.0 license by the Google team at https://github.com/google-research/electra. Modified with conversion to C-array format. - let config_resource = Resource::Remote(RemoteResource::from_pretrained( - ElectraConfigResources::BASE_DISCRIMINATOR, - )); - let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( - ElectraVocabResources::BASE_DISCRIMINATOR, - )); - let weights_resource = Resource::Remote(RemoteResource::from_pretrained( - ElectraModelResources::BASE_DISCRIMINATOR, - )); - let _ = config_resource.get_local_path()?; - let _ = vocab_resource.get_local_path()?; - let _ = weights_resource.get_local_path()?; - Ok(()) -} - -fn download_albert_base_v2() -> anyhow::Result<()> { - // Shared under Apache 2.0 license by the Google team at https://github.com/google-research/ALBERT. Modified with conversion to C-array format. - let config_resource = Resource::Remote(RemoteResource::from_pretrained( - AlbertConfigResources::ALBERT_BASE_V2, - )); - let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( - AlbertVocabResources::ALBERT_BASE_V2, - )); - let weights_resource = Resource::Remote(RemoteResource::from_pretrained( - AlbertModelResources::ALBERT_BASE_V2, - )); - let _ = config_resource.get_local_path()?; - let _ = vocab_resource.get_local_path()?; - let _ = weights_resource.get_local_path()?; - Ok(()) -} - -fn _download_dialogpt() -> anyhow::Result<()> { - // Shared under MIT license by the Microsoft team at https://huggingface.co/microsoft/DialoGPT-medium. Modified with conversion to C-array format. - let config_resource = Resource::Remote(RemoteResource::from_pretrained( - Gpt2ConfigResources::DIALOGPT_MEDIUM, - )); - let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( - Gpt2VocabResources::DIALOGPT_MEDIUM, - )); - let merges_resource = Resource::Remote(RemoteResource::from_pretrained( - Gpt2MergesResources::DIALOGPT_MEDIUM, - )); - let weights_resource = Resource::Remote(RemoteResource::from_pretrained( - Gpt2ModelResources::DIALOGPT_MEDIUM, - )); - let _ = config_resource.get_local_path()?; - let _ = vocab_resource.get_local_path()?; - let _ = merges_resource.get_local_path()?; - let _ = weights_resource.get_local_path()?; - Ok(()) -} - -fn download_t5_small() -> anyhow::Result<()> { - // Shared under Apache 2.0 license by the Google team at https://github.com/google-research/text-to-text-transfer-transformer. - let config_resource = - Resource::Remote(RemoteResource::from_pretrained(T5ConfigResources::T5_SMALL)); - let vocab_resource = - Resource::Remote(RemoteResource::from_pretrained(T5VocabResources::T5_SMALL)); - let weights_resource = - Resource::Remote(RemoteResource::from_pretrained(T5ModelResources::T5_SMALL)); - let _ = config_resource.get_local_path()?; - let _ = vocab_resource.get_local_path()?; - let _ = weights_resource.get_local_path()?; - Ok(()) -} - -fn download_roberta_qa() -> anyhow::Result<()> { - // Shared under Apache 2.0 license by [deepset](https://deepset.ai) at https://huggingface.co/deepset/roberta-base-squad2. - let config_resource = Resource::Remote(RemoteResource::from_pretrained( - RobertaConfigResources::ROBERTA_QA, - )); - let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( - RobertaVocabResources::ROBERTA_QA, - )); - let weights_resource = Resource::Remote(RemoteResource::from_pretrained( - RobertaModelResources::ROBERTA_QA, - )); - let merges_resource = Resource::Remote(RemoteResource::from_pretrained( - RobertaMergesResources::ROBERTA_QA, - )); - let _ = config_resource.get_local_path()?; - let _ = vocab_resource.get_local_path()?; - let _ = merges_resource.get_local_path()?; - let _ = weights_resource.get_local_path()?; - Ok(()) -} - -fn download_bert_qa() -> anyhow::Result<()> { - // Shared under Apache 2.0 license by [deepset](https://deepset.ai) at https://huggingface.co/deepset/roberta-base-squad2. - let config_resource = Resource::Remote(RemoteResource::from_pretrained( - BertConfigResources::BERT_QA, - )); - let vocab_resource = - Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT_QA)); - let weights_resource = - Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT_QA)); - let _ = config_resource.get_local_path()?; - let _ = vocab_resource.get_local_path()?; - let _ = weights_resource.get_local_path()?; - Ok(()) -} - -fn download_xlm_roberta_ner_german() -> anyhow::Result<()> { - // Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. - let config_resource = Resource::Remote(RemoteResource::from_pretrained( - RobertaConfigResources::XLM_ROBERTA_NER_DE, - )); - let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( - RobertaVocabResources::XLM_ROBERTA_NER_DE, - )); - let weights_resource = Resource::Remote(RemoteResource::from_pretrained( - RobertaModelResources::XLM_ROBERTA_NER_DE, - )); - let _ = config_resource.get_local_path()?; - let _ = vocab_resource.get_local_path()?; - let _ = weights_resource.get_local_path()?; - Ok(()) -} - -fn download_xlnet_base_cased() -> anyhow::Result<()> { - // Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. - let config_resource = Resource::Remote(RemoteResource::from_pretrained( - XLNetConfigResources::XLNET_BASE_CASED, - )); - let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( - XLNetVocabResources::XLNET_BASE_CASED, - )); - let weights_resource = Resource::Remote(RemoteResource::from_pretrained( - XLNetModelResources::XLNET_BASE_CASED, - )); - let _ = config_resource.get_local_path()?; - let _ = vocab_resource.get_local_path()?; - let _ = weights_resource.get_local_path()?; - Ok(()) -} - -fn main() { - let _ = download_distil_gpt2(); - let _ = download_distilbert_sst2(); - let _ = download_distilbert_qa(); - let _ = download_distilbert(); - let _ = download_gpt2(); - let _ = download_gpt(); - let _ = download_roberta(); - let _ = download_bert(); - let _ = download_bert_ner(); - let _ = download_bart(); - let _ = download_bart_cnn(); - let _ = download_electra_generator(); - let _ = download_electra_discriminator(); - let _ = download_albert_base_v2(); - let _ = download_t5_small(); - let _ = download_roberta_qa(); - let _ = download_bert_qa(); - let _ = download_xlm_roberta_ner_german(); - let _ = download_xlnet_base_cased(); -} diff --git a/examples/electra_discriminator.rs b/examples/electra_discriminator.rs deleted file mode 100644 index 169a7aa..0000000 --- a/examples/electra_discriminator.rs +++ /dev/null @@ -1,96 +0,0 @@ -// Copyright 2020 The Google Research Authors. -// Copyright 2019-present, the HuggingFace Inc. team -// Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. -// Copyright 2019 Guillaume Becquin -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// http://www.apache.org/licenses/LICENSE-2.0 -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use rust_bert::electra::{ - ElectraConfig, ElectraConfigResources, ElectraDiscriminator, ElectraModelResources, - ElectraVocabResources, -}; -use rust_bert::resources::{RemoteResource, Resource}; -use rust_bert::Config; -use rust_tokenizers::tokenizer::{ - BertTokenizer, MultiThreadedTokenizer, Tokenizer, TruncationStrategy, -}; -use tch::{nn, no_grad, Device, Tensor}; - -fn main() -> anyhow::Result<()> { - // Resources paths - let config_resource = Resource::Remote(RemoteResource::from_pretrained( - ElectraConfigResources::BASE_DISCRIMINATOR, - )); - let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( - ElectraVocabResources::BASE_DISCRIMINATOR, - )); - let weights_resource = Resource::Remote(RemoteResource::from_pretrained( - ElectraModelResources::BASE_DISCRIMINATOR, - )); - let config_path = config_resource.get_local_path()?; - let vocab_path = vocab_resource.get_local_path()?; - let weights_path = weights_resource.get_local_path()?; - - // Set-up masked LM model - let device = Device::Cpu; - let mut vs = nn::VarStore::new(device); - let tokenizer: BertTokenizer = - BertTokenizer::from_file(vocab_path.to_str().unwrap(), true, true)?; - let config = ElectraConfig::from_file(config_path); - let electra_model = ElectraDiscriminator::new(&vs.root(), &config); - vs.load(weights_path)?; - - // Define input - let input = ["One Two Three Ten Five Six Seven Eight"]; - let tokenized_input = MultiThreadedTokenizer::encode_list( - &tokenizer, - &input, - 128, - &TruncationStrategy::LongestFirst, - 0, - ); - let max_len = tokenized_input - .iter() - .map(|input| input.token_ids.len()) - .max() - .unwrap(); - let encoded_input = tokenized_input - .iter() - .map(|input| input.token_ids.clone()) - .map(|mut input| { - input.extend(vec![0; max_len - input.len()]); - input - }) - .map(|input| Tensor::of_slice(&(input))) - .collect::>(); - let input_tensor = Tensor::stack(encoded_input.as_slice(), 0).to(device); - - // Forward pass - let model_output = - no_grad(|| electra_model.forward_t(Some(input_tensor), None, None, None, None, false)); - - // Print model predictions - for (position, token) in tokenized_input[0].token_ids.iter().enumerate() { - let probability = model_output.probabilities.double_value(&[position as i64]); - let generated = if probability > 0.5 { - "generated" - } else { - "original" - }; - println!( - "{:?}: {} ({:.1}%)", - tokenizer.decode([*token].to_vec(), false, false), - generated, - 100f64 * probability - ) - } - - Ok(()) -} diff --git a/examples/electra_masked_lm.rs b/examples/electra_masked_lm.rs deleted file mode 100644 index 4a9f3d2..0000000 --- a/examples/electra_masked_lm.rs +++ /dev/null @@ -1,93 +0,0 @@ -// Copyright 2020 The Google Research Authors. -// Copyright 2019-present, the HuggingFace Inc. team -// Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. -// Copyright 2019 Guillaume Becquin -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// http://www.apache.org/licenses/LICENSE-2.0 -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use rust_bert::electra::{ - ElectraConfig, ElectraConfigResources, ElectraForMaskedLM, ElectraModelResources, - ElectraVocabResources, -}; -use rust_bert::resources::{RemoteResource, Resource}; -use rust_bert::Config; -use rust_tokenizers::tokenizer::{BertTokenizer, MultiThreadedTokenizer, TruncationStrategy}; -use rust_tokenizers::vocab::Vocab; -use tch::{nn, no_grad, Device, Tensor}; - -fn main() -> anyhow::Result<()> { - // Resources paths - let config_resource = Resource::Remote(RemoteResource::from_pretrained( - ElectraConfigResources::BASE_GENERATOR, - )); - let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( - ElectraVocabResources::BASE_GENERATOR, - )); - let weights_resource = Resource::Remote(RemoteResource::from_pretrained( - ElectraModelResources::BASE_GENERATOR, - )); - let config_path = config_resource.get_local_path()?; - let vocab_path = vocab_resource.get_local_path()?; - let weights_path = weights_resource.get_local_path()?; - - // Set-up masked LM model - let device = Device::Cpu; - let mut vs = nn::VarStore::new(device); - let tokenizer: BertTokenizer = - BertTokenizer::from_file(vocab_path.to_str().unwrap(), true, true)?; - let config = ElectraConfig::from_file(config_path); - let electra_model = ElectraForMaskedLM::new(&vs.root(), &config); - vs.load(weights_path)?; - - // Define input - let input = [ - "Looks like one [MASK] is missing", - "It was a very nice and [MASK] day", - ]; - let tokenized_input = tokenizer.encode_list(&input, 128, &TruncationStrategy::LongestFirst, 0); - let max_len = tokenized_input - .iter() - .map(|input| input.token_ids.len()) - .max() - .unwrap(); - let tokenized_input = tokenized_input - .iter() - .map(|input| input.token_ids.clone()) - .map(|mut input| { - input.extend(vec![0; max_len - input.len()]); - input - }) - .map(|input| Tensor::of_slice(&(input))) - .collect::>(); - let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device); - - // Forward pass - let model_output = - no_grad(|| electra_model.forward_t(Some(input_tensor), None, None, None, None, false)); - - // Print masked tokens - let index_1 = model_output - .prediction_scores - .get(0) - .get(4) - .argmax(0, false); - let index_2 = model_output - .prediction_scores - .get(1) - .get(7) - .argmax(0, false); - let word_1 = tokenizer.vocab().id_to_token(&index_1.int64_value(&[])); - let word_2 = tokenizer.vocab().id_to_token(&index_2.int64_value(&[])); - - println!("{}", word_1); // Outputs "thing" : "Looks like one [thing] is missing" - println!("{}", word_2); // Outputs "sunny" : "It was a very nice and [sunny] day" - - Ok(()) -} diff --git a/examples/generation.rs b/examples/generation_gpt2.rs similarity index 97% rename from examples/generation.rs rename to examples/generation_gpt2.rs index 2e04018..c1c74b1 100644 --- a/examples/generation.rs +++ b/examples/generation_gpt2.rs @@ -16,7 +16,7 @@ use rust_bert::pipelines::common::ModelType; use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel}; fn main() -> anyhow::Result<()> { - // Set-up masked LM model + // Set-up model let generate_config = TextGenerationConfig { model_type: ModelType::GPT2, max_length: 30, diff --git a/examples/generation_reformer.rs b/examples/generation_reformer.rs index c8e57f4..e43b131 100644 --- a/examples/generation_reformer.rs +++ b/examples/generation_reformer.rs @@ -22,7 +22,7 @@ use rust_bert::reformer::{ use rust_bert::resources::{RemoteResource, Resource}; fn main() -> anyhow::Result<()> { - // Set-up masked LM model + // Set-up model // Resources paths let config_resource = Resource::Remote(RemoteResource::from_pretrained( ReformerConfigResources::CRIME_AND_PUNISHMENT, diff --git a/examples/generation_xlnet.rs b/examples/generation_xlnet.rs index b2c57ef..7cf98a3 100644 --- a/examples/generation_xlnet.rs +++ b/examples/generation_xlnet.rs @@ -20,7 +20,7 @@ use rust_bert::resources::{RemoteResource, Resource}; use rust_bert::xlnet::{XLNetConfigResources, XLNetModelResources, XLNetVocabResources}; fn main() -> anyhow::Result<()> { - // Set-up masked LM model + // Set-up model // Resources paths let config_resource = Resource::Remote(RemoteResource::from_pretrained( XLNetConfigResources::XLNET_BASE_CASED, diff --git a/examples/gpt2.rs b/examples/gpt2.rs deleted file mode 100644 index d363466..0000000 --- a/examples/gpt2.rs +++ /dev/null @@ -1,97 +0,0 @@ -// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc. -// Copyright 2019 Guillaume Becquin -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// http://www.apache.org/licenses/LICENSE-2.0 -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -extern crate anyhow; - -use rust_bert::gpt2::{ - GPT2LMHeadModel, Gpt2Config, Gpt2ConfigResources, Gpt2MergesResources, Gpt2ModelResources, - Gpt2VocabResources, -}; -use rust_bert::pipelines::generation_utils::{Cache, LMHeadModel}; -use rust_bert::resources::{RemoteResource, Resource}; -use rust_bert::Config; -use rust_tokenizers::tokenizer::{Gpt2Tokenizer, Tokenizer, TruncationStrategy}; -use tch::{nn, Device, Tensor}; - -fn main() -> anyhow::Result<()> { - // Resources set-up - let config_resource = - Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2)); - let vocab_resource = - Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2)); - let merges_resource = - Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2)); - let weights_resource = - Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2)); - let config_path = config_resource.get_local_path()?; - let vocab_path = vocab_resource.get_local_path()?; - let merges_path = merges_resource.get_local_path()?; - let weights_path = weights_resource.get_local_path()?; - - // Set-up masked LM model - let device = Device::Cpu; - let mut vs = nn::VarStore::new(device); - let tokenizer: Gpt2Tokenizer = Gpt2Tokenizer::from_file( - vocab_path.to_str().unwrap(), - merges_path.to_str().unwrap(), - false, - )?; - let config = Gpt2Config::from_file(config_path); - let gpt2_model = GPT2LMHeadModel::new(&vs.root(), &config); - vs.load(weights_path)?; - - // Define input - let input = ["One two three four five six seven eight nine ten eleven"]; - let tokenized_input = tokenizer.encode_list(&input, 128, &TruncationStrategy::LongestFirst, 0); - let max_len = tokenized_input - .iter() - .map(|input| input.token_ids.len()) - .max() - .unwrap(); - let tokenized_input = tokenized_input - .iter() - .map(|input| input.token_ids.clone()) - .map(|mut input| { - input.extend(vec![0; max_len - input.len()]); - input - }) - .map(|input| Tensor::of_slice(&(input))) - .collect::>(); - let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device); - - // Forward pass - let model_output = gpt2_model - .forward_t( - &Some(input_tensor), - Cache::None, - &None, - &None, - &None, - &None, - None, - &None, - false, - ) - .unwrap(); - - let next_word_id = model_output - .lm_logits - .get(0) - .get(-1) - .argmax(-1, true) - .int64_value(&[0]); - let next_word = tokenizer.decode(vec![next_word_id], true, true); - println!("Provided input: {}", input[0]); - println!("Next word: {}", next_word); - - Ok(()) -} diff --git a/examples/bert.rs b/examples/masked_language_model_bert.rs similarity index 100% rename from examples/bert.rs rename to examples/masked_language_model_bert.rs diff --git a/examples/mobilebert_masked_lm.rs b/examples/mobilebert_masked_lm.rs deleted file mode 100644 index 73d8c3e..0000000 --- a/examples/mobilebert_masked_lm.rs +++ /dev/null @@ -1,101 +0,0 @@ -// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc. -// Copyright 2019 Guillaume Becquin -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// http://www.apache.org/licenses/LICENSE-2.0 -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -extern crate anyhow; - -use rust_bert::mobilebert::{ - MobileBertConfig, MobileBertConfigResources, MobileBertForMaskedLM, MobileBertModelResources, - MobileBertVocabResources, -}; -use rust_bert::resources::{RemoteResource, Resource}; -use rust_bert::Config; -use rust_tokenizers::tokenizer::{BertTokenizer, MultiThreadedTokenizer, TruncationStrategy}; -use rust_tokenizers::vocab::Vocab; -use tch::{nn, no_grad, Device, Tensor}; - -fn main() -> anyhow::Result<()> { - // Resources paths - let config_resource = Resource::Remote(RemoteResource::from_pretrained( - MobileBertConfigResources::MOBILEBERT_UNCASED, - )); - let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( - MobileBertVocabResources::MOBILEBERT_UNCASED, - )); - let weights_resource = Resource::Remote(RemoteResource::from_pretrained( - MobileBertModelResources::MOBILEBERT_UNCASED, - )); - let config_path = config_resource.get_local_path()?; - let vocab_path = vocab_resource.get_local_path()?; - let weights_path = weights_resource.get_local_path()?; - - // Set-up masked LM model - let device = Device::Cpu; - let mut vs = nn::VarStore::new(device); - let tokenizer: BertTokenizer = - BertTokenizer::from_file(vocab_path.to_str().unwrap(), true, true)?; - let config = MobileBertConfig::from_file(config_path); - let mobilebert_model = MobileBertForMaskedLM::new(&vs.root(), &config); - vs.load(weights_path)?; - - // Define input - let input = [ - "Looks like one [MASK] is missing", - "It was a very nice and [MASK] day", - ]; - let tokenized_input = tokenizer.encode_list(&input, 128, &TruncationStrategy::LongestFirst, 0); - let max_len = tokenized_input - .iter() - .map(|input| input.token_ids.len()) - .max() - .unwrap(); - let tokenized_input = tokenized_input - .iter() - .map(|input| input.token_ids.clone()) - .map(|mut input| { - input.extend(vec![0; max_len - input.len()]); - input - }) - .map(|input| Tensor::of_slice(&(input))) - .collect::>(); - let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device); - - // Forward pass - let model_output = - no_grad(|| mobilebert_model.forward_t(Some(&input_tensor), None, None, None, None, false))?; - - // Print masked tokens - let index_1 = model_output.logits.get(0).get(4).argmax(0, false); - let index_2 = model_output.logits.get(1).get(7).argmax(0, false); - let word_1 = tokenizer.vocab().id_to_token(&index_1.int64_value(&[])); - let word_2 = tokenizer.vocab().id_to_token(&index_2.int64_value(&[])); - - println!("{}", word_1); // Outputs "thing" : "Looks like one [thing] is missing" - println!( - "score: {}", - model_output - .logits - .get(0) - .get(4) - .double_value(&[i64::from(&index_1)]) - ); // 10.0558 - - println!("{}", word_2); // Outputs "sunny" : "It was a very nice and [sunny] day" - println!( - "score: {}", - model_output - .logits - .get(1) - .get(7) - .double_value(&[i64::from(&index_2)]) - ); // 14.2708 - Ok(()) -} diff --git a/examples/ner.rs b/examples/named_entities_recognition.rs similarity index 100% rename from examples/ner.rs rename to examples/named_entities_recognition.rs diff --git a/examples/openai_gpt.rs b/examples/openai_gpt.rs deleted file mode 100644 index 90f2343..0000000 --- a/examples/openai_gpt.rs +++ /dev/null @@ -1,102 +0,0 @@ -// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc. -// Copyright 2019 Guillaume Becquin -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// http://www.apache.org/licenses/LICENSE-2.0 -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -extern crate anyhow; - -use rust_bert::gpt2::Gpt2Config; -use rust_bert::openai_gpt::{ - OpenAIGPTLMHeadModel, OpenAiGptConfigResources, OpenAiGptMergesResources, - OpenAiGptModelResources, OpenAiGptVocabResources, -}; -use rust_bert::pipelines::generation_utils::{Cache, LMHeadModel}; -use rust_bert::resources::{RemoteResource, Resource}; -use rust_bert::Config; -use rust_tokenizers::tokenizer::{OpenAiGptTokenizer, Tokenizer, TruncationStrategy}; -use tch::{nn, Device, Tensor}; - -fn main() -> anyhow::Result<()> { - // Resources paths - let config_resource = Resource::Remote(RemoteResource::from_pretrained( - OpenAiGptConfigResources::GPT, - )); - let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( - OpenAiGptVocabResources::GPT, - )); - let merges_resource = Resource::Remote(RemoteResource::from_pretrained( - OpenAiGptMergesResources::GPT, - )); - let weights_resource = Resource::Remote(RemoteResource::from_pretrained( - OpenAiGptModelResources::GPT, - )); - let config_path = config_resource.get_local_path()?; - let vocab_path = vocab_resource.get_local_path()?; - let merges_path = merges_resource.get_local_path()?; - let weights_path = weights_resource.get_local_path()?; - - // Set-up masked LM model - let device = Device::Cpu; - let mut vs = nn::VarStore::new(device); - let tokenizer = OpenAiGptTokenizer::from_file( - vocab_path.to_str().unwrap(), - merges_path.to_str().unwrap(), - true, - )?; - let config = Gpt2Config::from_file(config_path); - let openai_gpt = OpenAIGPTLMHeadModel::new(&vs.root(), &config); - vs.load(weights_path)?; - - // Define input - let input = ["Wondering what the next word will"]; - let tokenized_input = tokenizer.encode_list(&input, 128, &TruncationStrategy::LongestFirst, 0); - let max_len = tokenized_input - .iter() - .map(|input| input.token_ids.len()) - .max() - .unwrap(); - let tokenized_input = tokenized_input - .iter() - .map(|input| input.token_ids.clone()) - .map(|mut input| { - input.extend(vec![0; max_len - input.len()]); - input - }) - .map(|input| Tensor::of_slice(&(input))) - .collect::>(); - let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device); - - // Forward pass - let model_output = openai_gpt - .forward_t( - &Some(input_tensor), - Cache::None, - &None, - &None, - &None, - &None, - None, - &None, - false, - ) - .unwrap(); - - let next_word_id = model_output - .lm_logits - .get(0) - .get(-1) - .argmax(-1, true) - .int64_value(&[0]); - let next_word = tokenizer.decode(vec![next_word_id], true, true); - println!("Provided input: {}", input[0]); - println!("Next word: {}", next_word); - - Ok(()) -} diff --git a/examples/squad.rs b/examples/question_answering_squad.rs similarity index 100% rename from examples/squad.rs rename to examples/question_answering_squad.rs diff --git a/examples/reformer.rs b/examples/reformer.rs deleted file mode 100644 index 1085e24..0000000 --- a/examples/reformer.rs +++ /dev/null @@ -1,74 +0,0 @@ -// Copyright 2018 Google AI and Google Brain team. -// Copyright 2018 Carnegie Mellon University Authors. -// Copyright 2020-present, the HuggingFace Inc. team. -// Copyright 2020 Guillaume Becquin -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// http://www.apache.org/licenses/LICENSE-2.0 -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -extern crate anyhow; - -use rust_bert::reformer::{ - ReformerConfig, ReformerConfigResources, ReformerModelResources, ReformerModelWithLMHead, - ReformerVocabResources, -}; -use rust_bert::resources::{RemoteResource, Resource}; -use rust_bert::Config; -use rust_tokenizers::tokenizer::{MultiThreadedTokenizer, ReformerTokenizer, TruncationStrategy}; -use tch::{nn, Device, Tensor}; - -fn main() -> anyhow::Result<()> { - // Resources paths - let config_resource = Resource::Remote(RemoteResource::from_pretrained( - ReformerConfigResources::CRIME_AND_PUNISHMENT, - )); - let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( - ReformerVocabResources::CRIME_AND_PUNISHMENT, - )); - let weights_resource = Resource::Remote(RemoteResource::from_pretrained( - ReformerModelResources::CRIME_AND_PUNISHMENT, - )); - let config_path = config_resource.get_local_path()?; - let vocab_path = vocab_resource.get_local_path()?; - let weights_path = weights_resource.get_local_path()?; - - // Set-up masked LM model - let device = Device::cuda_if_available(); - let mut vs = nn::VarStore::new(device); - let tokenizer = ReformerTokenizer::from_file(vocab_path.to_str().unwrap(), false)?; - let config = ReformerConfig::from_file(config_path); - let reformer_model = ReformerModelWithLMHead::new(&vs.root(), &config)?; - vs.load(weights_path)?; - - // Define input - let input = ["One two three four five six seven eight nine ten eleven One two three four five six seven eight nine ten eleven One two three four five six seven eight nine ten eleven"]; - let tokenized_input = tokenizer.encode_list(&input, 128, &TruncationStrategy::LongestFirst, 0); - let max_len = tokenized_input - .iter() - .map(|input| input.token_ids.len()) - .max() - .unwrap(); - let tokenized_input = tokenized_input - .iter() - .map(|input| input.token_ids.clone()) - .map(|mut input| { - input.extend(vec![0; max_len - input.len()]); - input - }) - .map(|input| Tensor::of_slice(&(input))) - .collect::>(); - let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device); - - // Forward pass - let _model_output = - reformer_model.forward_t(Some(&input_tensor), None, None, None, None, None, false)?; - - _model_output.logits.print(); - Ok(()) -} diff --git a/examples/roberta.rs b/examples/roberta.rs deleted file mode 100644 index 679ba68..0000000 --- a/examples/roberta.rs +++ /dev/null @@ -1,119 +0,0 @@ -// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc. -// Copyright 2019 Guillaume Becquin -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// http://www.apache.org/licenses/LICENSE-2.0 -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -extern crate anyhow; - -use rust_bert::bert::BertConfig; -use rust_bert::resources::{RemoteResource, Resource}; -use rust_bert::roberta::{ - RobertaConfigResources, RobertaForMaskedLM, RobertaMergesResources, RobertaModelResources, - RobertaVocabResources, -}; -use rust_bert::Config; -use rust_tokenizers::tokenizer::{RobertaTokenizer, Tokenizer, TruncationStrategy}; -use rust_tokenizers::vocab::Vocab; -use tch::{nn, no_grad, Device, Tensor}; - -fn main() -> anyhow::Result<()> { - // Resources paths - let config_resource = Resource::Remote(RemoteResource::from_pretrained( - RobertaConfigResources::ROBERTA, - )); - let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( - RobertaVocabResources::ROBERTA, - )); - let merges_resource = Resource::Remote(RemoteResource::from_pretrained( - RobertaMergesResources::ROBERTA, - )); - let weights_resource = Resource::Remote(RemoteResource::from_pretrained( - RobertaModelResources::ROBERTA, - )); - let config_path = config_resource.get_local_path()?; - let vocab_path = vocab_resource.get_local_path()?; - let merges_path = merges_resource.get_local_path()?; - let weights_path = weights_resource.get_local_path()?; - - // Set-up masked LM model - let device = Device::Cpu; - let mut vs = nn::VarStore::new(device); - let tokenizer: RobertaTokenizer = RobertaTokenizer::from_file( - vocab_path.to_str().unwrap(), - merges_path.to_str().unwrap(), - true, - false, - )?; - let config = BertConfig::from_file(config_path); - let bert_model = RobertaForMaskedLM::new(&vs.root(), &config); - vs.load(weights_path)?; - - // Define input - let input = [ - " Looks like one thing is missing", - "It\'s like comparing oranges to apples", - ]; - let tokenized_input = tokenizer.encode_list(&input, 128, &TruncationStrategy::LongestFirst, 0); - let max_len = tokenized_input - .iter() - .map(|input| input.token_ids.len()) - .max() - .unwrap(); - let mut tokenized_input = tokenized_input - .iter() - .map(|input| input.token_ids.clone()) - .map(|mut input| { - input.extend(vec![0; max_len - input.len()]); - input - }) - .collect::>(); - - // Masking the token [thing] of sentence 1 and [oranges] of sentence 2 - tokenized_input[0][4] = 103; - tokenized_input[1][5] = 103; - let tokenized_input = tokenized_input - .iter() - .map(|input| Tensor::of_slice(&(input))) - .collect::>(); - let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device); - - // Forward pass - let model_output = no_grad(|| { - bert_model.forward_t( - Some(input_tensor), - None, - None, - None, - None, - &None, - &None, - false, - ) - }); - - // Print masked tokens - let index_1 = model_output - .prediction_scores - .get(0) - .get(4) - .argmax(0, false); - let index_2 = model_output - .prediction_scores - .get(1) - .get(5) - .argmax(0, false); - let word_1 = tokenizer.vocab().id_to_token(&index_1.int64_value(&[])); - let word_2 = tokenizer.vocab().id_to_token(&index_2.int64_value(&[])); - - println!("{}", word_1); // Outputs "some" : "Looks like [some] thing is missing" - println!("{}", word_2); // Outputs "apple" : "It\'s like comparing [apple] to apples" - - Ok(()) -} diff --git a/examples/sentiment.rs b/examples/sentiment_analysis.rs similarity index 100% rename from examples/sentiment.rs rename to examples/sentiment_analysis.rs diff --git a/examples/sst2.rs b/examples/sentiment_analysis_sst2.rs similarity index 100% rename from examples/sst2.rs rename to examples/sentiment_analysis_sst2.rs diff --git a/examples/summarization.rs b/examples/summarization_bart.rs similarity index 100% rename from examples/summarization.rs rename to examples/summarization_bart.rs diff --git a/examples/translation.rs b/examples/translation_marian.rs similarity index 100% rename from examples/translation.rs rename to examples/translation_marian.rs diff --git a/examples/translation_mbart.rs b/examples/translation_mbart.rs new file mode 100644 index 0000000..a88374b --- /dev/null +++ b/examples/translation_mbart.rs @@ -0,0 +1,59 @@ +// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc. +// Copyright 2019 Guillaume Becquin +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +extern crate anyhow; + +use rust_bert::mbart::{ + MBartConfigResources, MBartGenerator, MBartModelResources, MBartVocabResources, +}; +use rust_bert::pipelines::generation_utils::{GenerateConfig, LanguageGenerator}; +use rust_bert::resources::{RemoteResource, Resource}; + +fn main() -> anyhow::Result<()> { + let generate_config = GenerateConfig { + max_length: 56, + model_resource: Resource::Remote(RemoteResource::from_pretrained( + MBartModelResources::MBART50_MANY_TO_MANY, + )), + config_resource: Resource::Remote(RemoteResource::from_pretrained( + MBartConfigResources::MBART50_MANY_TO_MANY, + )), + vocab_resource: Resource::Remote(RemoteResource::from_pretrained( + MBartVocabResources::MBART50_MANY_TO_MANY, + )), + merges_resource: Resource::Remote(RemoteResource::from_pretrained( + MBartVocabResources::MBART50_MANY_TO_MANY, + )), + do_sample: false, + num_beams: 1, + ..Default::default() + }; + let model = MBartGenerator::new(generate_config)?; + + let input_context_1 = "en_XX The quick brown fox jumps over the lazy dog."; + let target_language = model.get_tokenizer().convert_tokens_to_ids(["de_DE"])[0]; + + let output = model.generate( + Some(&[input_context_1]), + None, + None, + None, + None, + target_language, + None, + ); + + for sentence in output { + println!("{:?}", sentence); + } + Ok(()) +} diff --git a/examples/t5.rs b/examples/translation_t5.rs similarity index 96% rename from examples/t5.rs rename to examples/translation_t5.rs index 7c4eee9..18cf08e 100644 --- a/examples/t5.rs +++ b/examples/translation_t5.rs @@ -35,13 +35,13 @@ fn main() -> anyhow::Result<()> { ..Default::default() }; - // Set-up masked LM model + // Set-up model let t5_model = T5Generator::new(generate_config)?; // Define input let input = ["translate English to German: This sentence will get translated to German"]; - let output = t5_model.generate(Some(input.to_vec()), None, None, None, None, None); + let output = t5_model.generate(Some(input.to_vec()), None, None, None, None, None, None); println!("{:?}", output); Ok(()) diff --git a/examples/xlnet.rs b/examples/xlnet.rs deleted file mode 100644 index 53704e9..0000000 --- a/examples/xlnet.rs +++ /dev/null @@ -1,99 +0,0 @@ -// Copyright 2018 Google AI and Google Brain team. -// Copyright 2018 Carnegie Mellon University Authors. -// Copyright 2020-present, the HuggingFace Inc. team. -// Copyright 2020 Guillaume Becquin -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// http://www.apache.org/licenses/LICENSE-2.0 -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -extern crate anyhow; - -use rust_bert::resources::{RemoteResource, Resource}; -use rust_bert::xlnet::{ - XLNetConfig, XLNetConfigResources, XLNetLMHeadModel, XLNetModelResources, XLNetVocabResources, -}; -use rust_bert::Config; -use rust_tokenizers::tokenizer::{MultiThreadedTokenizer, TruncationStrategy, XLNetTokenizer}; -use rust_tokenizers::vocab::Vocab; -use tch::{nn, no_grad, Device, Kind, Tensor}; - -fn main() -> anyhow::Result<()> { - // Resources paths - let config_resource = Resource::Remote(RemoteResource::from_pretrained( - XLNetConfigResources::XLNET_BASE_CASED, - )); - let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( - XLNetVocabResources::XLNET_BASE_CASED, - )); - let weights_resource = Resource::Remote(RemoteResource::from_pretrained( - XLNetModelResources::XLNET_BASE_CASED, - )); - let config_path = config_resource.get_local_path()?; - let vocab_path = vocab_resource.get_local_path()?; - let weights_path = weights_resource.get_local_path()?; - - // Set-up masked LM model - let device = Device::cuda_if_available(); - let mut vs = nn::VarStore::new(device); - let tokenizer: XLNetTokenizer = - XLNetTokenizer::from_file(vocab_path.to_str().unwrap(), false, true)?; - let config = XLNetConfig::from_file(config_path); - let xlnet_model = XLNetLMHeadModel::new(&vs.root(), &config); - vs.load(weights_path)?; - - // Define input - let input = ["One two three four"]; - let tokenized_input = tokenizer.encode_list(&input, 128, &TruncationStrategy::LongestFirst, 0); - let max_len = tokenized_input - .iter() - .map(|input| input.token_ids.len()) - .max() - .unwrap(); - let tokenized_input = tokenized_input - .iter() - .map(|input| input.token_ids.clone()) - .map(|mut input| { - input.extend(vec![0; max_len - input.len()]); - input - }) - .map(|input| Tensor::of_slice(&(input[..input.len() - 2]))) - .collect::>(); - let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device); - - // Forward pass - let perm_mask = Tensor::zeros(&[1, 4, 4], (Kind::Float, device)); - let _ = perm_mask.narrow(2, 3, 1).fill_(1.0); - - let target_mapping = Tensor::zeros(&[1, 1, 4], (Kind::Float, device)); - let _ = target_mapping.narrow(2, 3, 1).fill_(1.0); - let model_output = no_grad(|| { - xlnet_model - .forward_t( - Some(&input_tensor), - None, - None, - Some(perm_mask.as_ref()), - Some(target_mapping.as_ref()), - None, - None, - false, - ) - .unwrap() - }); - - let index_1 = model_output - .lm_logits - .get(0) - .argmax(1, false) - .int64_value(&[]); - let score_1 = model_output.lm_logits.double_value(&[0, 0, index_1]); - let word_1 = tokenizer.vocab().id_to_token(&index_1); - println!("{}, {}, {}", index_1, score_1, word_1); - Ok(()) -} diff --git a/src/albert/mod.rs b/src/albert/mod.rs index c4c82c8..45bad07 100644 --- a/src/albert/mod.rs +++ b/src/albert/mod.rs @@ -11,7 +11,6 @@ //! //! # Model set-up and pre-trained weights loading //! -//! A full working example is provided in `examples/albert`, run with `cargo run --example albert`. //! The example below illustrate a Masked language model example, the structure is similar for other models. //! All models expect the following resources: //! - Configuration file expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers) diff --git a/src/bart/bart_model.rs b/src/bart/bart_model.rs index fcaf6f7..3876af5 100644 --- a/src/bart/bart_model.rs +++ b/src/bart/bart_model.rs @@ -1112,7 +1112,7 @@ impl PrivateLanguageGenerator &BartForConditionalGeneration { &self.model } - fn get_tokenizer(&self) -> &TokenizerOption { + fn _get_tokenizer(&self) -> &TokenizerOption { &self.tokenizer } fn get_var_store(&self) -> &nn::VarStore { @@ -1201,7 +1201,7 @@ impl PrivateLanguageGenerator, { - let tokens = self.get_tokenizer().encode_list( + let tokens = self._get_tokenizer().encode_list( prompt_text.as_ref(), max_len as usize, &TruncationStrategy::LongestFirst, @@ -1217,7 +1217,7 @@ impl PrivateLanguageGenerator value, None => self - .get_tokenizer() + ._get_tokenizer() .convert_tokens_to_ids(&[RobertaVocab::unknown_value()])[0], }; diff --git a/src/bart/mod.rs b/src/bart/mod.rs index 3706776..0d0397a 100644 --- a/src/bart/mod.rs +++ b/src/bart/mod.rs @@ -6,8 +6,7 @@ //! //! # Model set-up and pre-trained weights loading //! -//! A full working example is provided in `examples/bart`, run with `cargo run --example bart`. -//! Alternatively, the summarization capabilities are illustrated in `examples/summarization.rs`, run with `cargo run --example summarization`. +//! The summarization capabilities are illustrated in `examples/summarization_bart`, run with `cargo run --example summarization_bart`. //! All models expect the following resources: //! - Configuration file expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers) //! - Model weights are expected to have a structure and parameter names following the [Transformers library](https://github.com/huggingface/transformers). A conversion using the Python utility scripts is required to convert the `.bin` weights to the `.ot` format. diff --git a/src/bert/bert_model.rs b/src/bert/bert_model.rs index e7a7de7..652d313 100644 --- a/src/bert/bert_model.rs +++ b/src/bert/bert_model.rs @@ -1240,7 +1240,7 @@ mod test { #[test] #[ignore] // compilation is enough, no need to run - fn bart_model_send() { + fn bert_model_send() { let config_resource = Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT)); let config_path = config_resource.get_local_path().expect(""); diff --git a/src/bert/mod.rs b/src/bert/mod.rs index 3bc352d..4c49fc0 100644 --- a/src/bert/mod.rs +++ b/src/bert/mod.rs @@ -10,7 +10,7 @@ //! //! # Model set-up and pre-trained weights loading //! -//! A full working example is provided in `examples/bert`, run with `cargo run --example bert`. +//! A full working example is provided in `examples/masked_language_model_bert`, run with `cargo run --example masked_language_model_bert`. //! The example below illustrate a Masked language model example, the structure is similar for other models. //! All models expect the following resources: //! - Configuration file expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers) diff --git a/src/distilbert/mod.rs b/src/distilbert/mod.rs index 29b7cf0..a976446 100644 --- a/src/distilbert/mod.rs +++ b/src/distilbert/mod.rs @@ -9,7 +9,6 @@ //! //! # Model set-up and pre-trained weights loading //! -//! A full working example is provided in `examples/distilbert_masked_lm.rs`, run with `cargo run --example distilbert_masked_lm`. //! The example below illustrate a DistilBERT Masked language model example, the structure is similar for other models. //! All models expect the following resources: //! - Configuration file expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers) diff --git a/src/electra/mod.rs b/src/electra/mod.rs index 03eb043..b0d042e 100644 --- a/src/electra/mod.rs +++ b/src/electra/mod.rs @@ -14,7 +14,6 @@ //! //! # Model set-up and pre-trained weights loading //! -//! A full working example is provided in `examples/electra_masked_lm.rs`, run with `cargo run --example electra_masked_lm`. //! The example below illustrate a Masked language model example, the structure is similar for other models (e.g. discriminator). //! All models expect the following resources: //! - Configuration file expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers) diff --git a/src/gpt2/gpt2_model.rs b/src/gpt2/gpt2_model.rs index 25f5a2e..cd9c812 100644 --- a/src/gpt2/gpt2_model.rs +++ b/src/gpt2/gpt2_model.rs @@ -752,7 +752,7 @@ impl PrivateLanguageGenerator for GPT fn get_model(&self) -> &GPT2LMHeadModel { &self.model } - fn get_tokenizer(&self) -> &TokenizerOption { + fn _get_tokenizer(&self) -> &TokenizerOption { &self.tokenizer } fn get_var_store(&self) -> &nn::VarStore { diff --git a/src/gpt2/mod.rs b/src/gpt2/mod.rs index 77bb0b6..0620cb3 100644 --- a/src/gpt2/mod.rs +++ b/src/gpt2/mod.rs @@ -6,7 +6,7 @@ //! //! # Model set-up and pre-trained weights loading //! -//! A full working example is provided in `examples/generation.rs`, run with `cargo run --example generation`. +//! A full working example is provided in `examples/generation_gpt2`, run with `cargo run --example generation_gpt2`. //! All models expect the following resources: //! - Configuration file expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers) //! - Model weights are expected to have a structure and parameter names following the [Transformers library](https://github.com/huggingface/transformers). A conversion using the Python utility scripts is required to convert the `.bin` weights to the `.ot` format. diff --git a/src/gpt_neo/gpt_neo_model.rs b/src/gpt_neo/gpt_neo_model.rs index 631246e..6db9013 100644 --- a/src/gpt_neo/gpt_neo_model.rs +++ b/src/gpt_neo/gpt_neo_model.rs @@ -729,7 +729,7 @@ impl PrivateLanguageGenerator for G fn get_model(&self) -> &GptNeoForCausalLM { &self.model } - fn get_tokenizer(&self) -> &TokenizerOption { + fn _get_tokenizer(&self) -> &TokenizerOption { &self.tokenizer } fn get_var_store(&self) -> &nn::VarStore { diff --git a/src/gpt_neo/mod.rs b/src/gpt_neo/mod.rs index c6bdfc2..b9195e1 100644 --- a/src/gpt_neo/mod.rs +++ b/src/gpt_neo/mod.rs @@ -5,6 +5,7 @@ //! //! # Model set-up and pre-trained weights loading //! +//! A full working example is provided in `examples/generation_gpt_neo`, run with `cargo run --example generation_gpt_neo`. //! All models expect the following resources: //! - Configuration file expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers) //! - Model weights are expected to have a structure and parameter names following the [Transformers library](https://github.com/huggingface/transformers). A conversion using the Python utility scripts is required to convert the `.bin` weights to the `.ot` format. diff --git a/src/lib.rs b/src/lib.rs index 72e1411..01f6dad 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -58,6 +58,7 @@ //! GPT-Neo| | | |✅ | | | | //! BART|✅| | |✅ |✅| | | //! Marian| | | | | |✅| | +//! MBart|✅| | |✅ | | | | //! Electra | |✅| | | | |✅| //! ALBERT |✅|✅|✅| | | |✅| //! T5 | | | |✅ |✅|✅| | diff --git a/src/longformer/mod.rs b/src/longformer/mod.rs index 7ad6b9e..5bbc8ad 100644 --- a/src/longformer/mod.rs +++ b/src/longformer/mod.rs @@ -10,7 +10,7 @@ //! //! # Model set-up and pre-trained weights loading //! -//! A full working example (generation) is provided in `examples/question_answering_longformer`, run with `cargo run --example question_answering_longformer`. +//! A full working example (question answering) is provided in `examples/question_answering_longformer`, run with `cargo run --example question_answering_longformer`. //! All models expect the following resources: //! - Configuration file expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers) //! - Model weights are expected to have a structure and parameter names following the [Transformers library](https://github.com/huggingface/transformers). A conversion using the Python utility scripts is required to convert the `.bin` weights to the `.ot` format. diff --git a/src/marian/marian_model.rs b/src/marian/marian_model.rs index 43aa6f3..4a72b42 100644 --- a/src/marian/marian_model.rs +++ b/src/marian/marian_model.rs @@ -859,7 +859,7 @@ impl PrivateLanguageGenerator &MarianForConditionalGeneration { &self.model } - fn get_tokenizer(&self) -> &TokenizerOption { + fn _get_tokenizer(&self) -> &TokenizerOption { &self.tokenizer } fn get_var_store(&self) -> &nn::VarStore { @@ -950,7 +950,7 @@ impl PrivateLanguageGenerator, { - let tokens = self.get_tokenizer().encode_list( + let tokens = self._get_tokenizer().encode_list( prompt_text.as_ref(), max_len as usize, &TruncationStrategy::LongestFirst, @@ -965,7 +965,7 @@ impl PrivateLanguageGenerator value, - None => self.get_tokenizer().get_unk_id(), + None => self._get_tokenizer().get_unk_id(), }; let token_ids = token_ids diff --git a/src/marian/mod.rs b/src/marian/mod.rs index e99758c..b385d7f 100644 --- a/src/marian/mod.rs +++ b/src/marian/mod.rs @@ -6,7 +6,7 @@ //! //! # Model set-up and pre-trained weights loading //! -//! A full working example is provided in `examples/translation.rs`, run with `cargo run --example translation`. +//! A full working example is provided in `examples/translation_marian`, run with `cargo run --example translation_marian`. //! All models expect the following resources: //! - Configuration file expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers) //! - Model weights are expected to have a structure and parameter names following the [Transformers library](https://github.com/huggingface/transformers). A conversion using the Python utility scripts is required to convert the `.bin` weights to the `.ot` format. diff --git a/src/mbart/mbart_model.rs b/src/mbart/mbart_model.rs index df5b068..eb47e96 100644 --- a/src/mbart/mbart_model.rs +++ b/src/mbart/mbart_model.rs @@ -698,7 +698,7 @@ impl LMHeadModel for MBartForConditionalGeneration { /// # let device = Device::Cpu; /// # let vs = nn::VarStore::new(device); /// # let config = MBartConfig::from_file(config_path); - /// # let bart_model: MBartForConditionalGeneration = MBartForConditionalGeneration::new(&vs.root(), &config); + /// # let mbart_model: MBartForConditionalGeneration = MBartForConditionalGeneration::new(&vs.root(), &config); /// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56); /// let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device)); /// let target_tensor = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device)); @@ -915,7 +915,7 @@ impl PrivateLanguageGenerator &MBartForConditionalGeneration { &self.model } - fn get_tokenizer(&self) -> &TokenizerOption { + fn _get_tokenizer(&self) -> &TokenizerOption { &self.tokenizer } fn get_var_store(&self) -> &nn::VarStore { @@ -1002,7 +1002,7 @@ impl PrivateLanguageGenerator, { - let tokens = self.get_tokenizer().encode_list( + let tokens = self._get_tokenizer().encode_list( prompt_text.as_ref(), max_len as usize, &TruncationStrategy::LongestFirst, @@ -1018,7 +1018,7 @@ impl PrivateLanguageGenerator value, None => self - .get_tokenizer() + ._get_tokenizer() .convert_tokens_to_ids(&[MBart50Vocab::unknown_value()])[0], }; diff --git a/src/mbart/mod.rs b/src/mbart/mod.rs index eef2dbc..7567cc9 100644 --- a/src/mbart/mod.rs +++ b/src/mbart/mod.rs @@ -1,3 +1,55 @@ +//! # MBart (Liu et al.) +//! +//! Implementation of the MBart language model ([Multilingual Denoising Pre-training for Neural Machine Translation](https://arxiv.org/abs/2001.08210) Liu, Gu, Goyal, Li, Edunov, Ghazvininejad, Lewis, Zettlemoyer, 2020). +//! The base model is implemented in the `mbart_model::MBartModel` struct. The model also includes a language model head: `mbart_model::MBartForConditionalGeneration` +//! implementing the common `generation_utils::LMHeadModel` trait shared between the models used for generation (see `pipelines` for more information). +//! +//! # Model set-up and pre-trained weights loading +//! +//! The summarization capabilities are illustrated in `examples/translation_mbart`, run with `cargo run --example translation_mbart`. +//! All models expect the following resources: +//! - Configuration file expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers) +//! - Model weights are expected to have a structure and parameter names following the [Transformers library](https://github.com/huggingface/transformers). A conversion using the Python utility scripts is required to convert the `.bin` weights to the `.ot` format. +//! - `MBart50Tokenizer` using a `spiece.model` SentencePiece model +//! Pretrained models are available and can be downloaded using RemoteResources. +//! +//! ```no_run +//! # fn main() -> anyhow::Result<()> { +//! # +//! use tch::{nn, Device}; +//! # use std::path::PathBuf; +//! use rust_bert::resources::{LocalResource, Resource}; +//! use rust_bert::Config; +//! use rust_tokenizers::tokenizer::MBart50Tokenizer; +//! use rust_bert::mbart::{MBartConfig, MBartModel}; +//! +//! let config_resource = Resource::Local(LocalResource { +//! local_path: PathBuf::from("path/to/config.json"), +//! }); +//! let vocab_resource = Resource::Local(LocalResource { +//! local_path: PathBuf::from("path/to/vocab.txt"), +//! }); +//! let weights_resource = Resource::Local(LocalResource { +//! local_path: PathBuf::from("path/to/model.ot"), +//! }); +//! let config_path = config_resource.get_local_path()?; +//! let vocab_path = vocab_resource.get_local_path()?; +//! let weights_path = weights_resource.get_local_path()?; +//! +//! let device = Device::cuda_if_available(); +//! let mut vs = nn::VarStore::new(device); +//! let tokenizer: MBart50Tokenizer = MBart50Tokenizer::from_file( +//! vocab_path.to_str().unwrap(), +//! false, +//! )?; +//! let config = MBartConfig::from_file(config_path); +//! let bart_model = MBartModel::new(&vs.root(), &config); +//! vs.load(weights_path)?; +//! +//! # Ok(()) +//! # } +//! ``` + mod attention; mod decoder; mod embeddings; diff --git a/src/mobilebert/mod.rs b/src/mobilebert/mod.rs index 087393e..fb8821f 100644 --- a/src/mobilebert/mod.rs +++ b/src/mobilebert/mod.rs @@ -9,7 +9,6 @@ //! //! # Model set-up and pre-trained weights loading //! -//! A full working example (generation) is provided in `examples/mobilebert_masked_lm`, run with `cargo run --example mobilebert_masked_lm`. //! All models expect the following resources: //! - Configuration file expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers) //! - Model weights are expected to have a structure and parameter names following the [Transformers library](https://github.com/huggingface/transformers). A conversion using the Python utility scripts is required to convert the `.bin` weights to the `.ot` format. diff --git a/src/openai_gpt/mod.rs b/src/openai_gpt/mod.rs index edba067..1071789 100644 --- a/src/openai_gpt/mod.rs +++ b/src/openai_gpt/mod.rs @@ -6,7 +6,6 @@ //! //! # Model set-up and pre-trained weights loading //! -//! A full working example is provided in `examples/openai_gpt`, run with `cargo run --example openai_gpt`. //! All models expect the following resources: //! - Configuration file expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers) //! - Model weights are expected to have a structure and parameter names following the [Transformers library](https://github.com/huggingface/transformers). A conversion using the Python utility scripts is required to convert the `.bin` weights to the `.ot` format. diff --git a/src/openai_gpt/openai_gpt_model.rs b/src/openai_gpt/openai_gpt_model.rs index 2839e15..bb97132 100644 --- a/src/openai_gpt/openai_gpt_model.rs +++ b/src/openai_gpt/openai_gpt_model.rs @@ -576,7 +576,7 @@ impl PrivateLanguageGenerator &OpenAIGPTLMHeadModel { &self.model } - fn get_tokenizer(&self) -> &TokenizerOption { + fn _get_tokenizer(&self) -> &TokenizerOption { &self.tokenizer } fn get_var_store(&self) -> &nn::VarStore { diff --git a/src/pegasus/pegasus_model.rs b/src/pegasus/pegasus_model.rs index acb22b1..8e0e4e1 100644 --- a/src/pegasus/pegasus_model.rs +++ b/src/pegasus/pegasus_model.rs @@ -691,7 +691,7 @@ impl PrivateLanguageGenerator &PegasusForConditionalGeneration { &self.model } - fn get_tokenizer(&self) -> &TokenizerOption { + fn _get_tokenizer(&self) -> &TokenizerOption { &self.tokenizer } fn get_var_store(&self) -> &nn::VarStore { @@ -775,7 +775,7 @@ impl PrivateLanguageGenerator, { - let tokens = self.get_tokenizer().encode_list( + let tokens = self._get_tokenizer().encode_list( prompt_text.as_ref(), max_len as usize, &TruncationStrategy::LongestFirst, @@ -791,7 +791,7 @@ impl PrivateLanguageGenerator value, None => self - .get_tokenizer() + ._get_tokenizer() .convert_tokens_to_ids(&[PegasusVocab::pad_value()])[0], }; diff --git a/src/pipelines/conversation.rs b/src/pipelines/conversation.rs index 47d7126..a3cad2e 100644 --- a/src/pipelines/conversation.rs +++ b/src/pipelines/conversation.rs @@ -698,7 +698,7 @@ impl ConversationOption { pub fn get_tokenizer(&self) -> &TokenizerOption { match self { - Self::GPT2(model_ref) => model_ref.get_tokenizer(), + Self::GPT2(model_ref) => model_ref._get_tokenizer(), } } diff --git a/src/pipelines/generation_utils.rs b/src/pipelines/generation_utils.rs index 7c81d92..5426cb5 100644 --- a/src/pipelines/generation_utils.rs +++ b/src/pipelines/generation_utils.rs @@ -36,6 +36,7 @@ //! let min_length = Some(32); //! let max_length = Some(128); //! let decoder_start_id = None; +//! let forced_bos_token_id = None; //! //! let input_context = "The dog"; //! let second_input_context = "The cat was"; @@ -45,6 +46,7 @@ //! min_length, //! max_length, //! decoder_start_id, +//! forced_bos_token_id, //! None, //! ); //! # Ok(()) @@ -86,6 +88,7 @@ use crate::t5::LayerState as T5LayerState; use crate::xlnet::LayerState as XLNetLayerState; use self::ordered_float::OrderedFloat; +use crate::pipelines::common::TokenizerOption; extern crate ordered_float; @@ -272,7 +275,7 @@ pub(crate) mod private_generation_utils { pub trait PrivateLanguageGenerator> { fn get_model(&self) -> &T; - fn get_tokenizer(&self) -> &TokenizerOption; + fn _get_tokenizer(&self) -> &TokenizerOption; fn get_var_store(&self) -> &nn::VarStore; fn get_config(&self) -> &GenerateConfig; fn get_bos_id(&self) -> &Option; @@ -322,10 +325,10 @@ pub(crate) mod private_generation_utils { where S: AsRef<[&'a str]>, { - let tokens = self.get_tokenizer().tokenize_list(prompt_text.as_ref()); + let tokens = self._get_tokenizer().tokenize_list(prompt_text.as_ref()); let token_ids = tokens .into_iter() - .map(|prompt_tokens| self.get_tokenizer().convert_tokens_to_ids(&prompt_tokens)) + .map(|prompt_tokens| self._get_tokenizer().convert_tokens_to_ids(&prompt_tokens)) .collect::>>(); let num_truncated_tokens = token_ids @@ -365,7 +368,7 @@ pub(crate) mod private_generation_utils { let pad_token = match pad_token_id { Some(value) => value, - None => self.get_tokenizer().get_unk_id(), + None => self._get_tokenizer().get_unk_id(), }; let token_ids = token_ids @@ -1219,7 +1222,7 @@ pub trait LanguageGenerator>: /// num_return_sequences: 3, /// ..Default::default() /// }; - /// let mut gpt2_generator = GPT2Generator::new(generate_config)?; + /// let gpt2_generator = GPT2Generator::new(generate_config)?; /// let input_context = "The dog"; /// let second_input_context = "The cat was"; /// @@ -1227,6 +1230,7 @@ pub trait LanguageGenerator>: /// let min_length = 32; /// let max_length = 128; /// let decoder_start_token_id = None; + /// let forced_bos_token_id = None; /// /// //Example custom function for fine-grained generation control /// fn force_one_paragraph(_batch_id: i64, previous_token_ids: &Tensor) -> Vec { @@ -1251,6 +1255,7 @@ pub trait LanguageGenerator>: /// min_length, /// max_length, /// decoder_start_token_id, + /// forced_bos_token_id, /// Some(&force_one_paragraph) /// ); /// # Ok(()) @@ -1293,7 +1298,7 @@ pub trait LanguageGenerator>: ); let mut output = Vec::with_capacity(generated.len()); for generated_sequence in generated { - output.push(self.get_tokenizer().decode(generated_sequence, true, true)); + output.push(self._get_tokenizer().decode(generated_sequence, true, true)); } output } @@ -1337,13 +1342,14 @@ pub trait LanguageGenerator>: /// num_return_sequences: 3, /// ..Default::default() /// }; - /// let mut gpt2_generator = GPT2Generator::new(generate_config)?; + /// let gpt2_generator = GPT2Generator::new(generate_config)?; /// let input_context = "The dog"; /// let second_input_context = "The cat was"; /// let attention_mask = None; /// let min_length = 32; /// let max_length = 128; /// let decoder_start_token_id = None; + /// let forced_bos_token_id = None; /// /// //Example custom function for fine-grained generation control /// fn force_one_paragraph(_batch_id: i64, previous_token_ids: &Tensor) -> Vec { @@ -1368,6 +1374,7 @@ pub trait LanguageGenerator>: /// min_length, /// max_length, /// decoder_start_token_id, + /// forced_bos_token_id, /// Some(&force_one_paragraph), /// ); /// # Ok(()) @@ -1462,13 +1469,14 @@ pub trait LanguageGenerator>: /// num_return_sequences: 3, /// ..Default::default() /// }; - /// let mut gpt2_generator = GPT2Generator::new(generate_config)?; + /// let gpt2_generator = GPT2Generator::new(generate_config)?; /// let input_context = "The dog"; /// let second_input_context = "The cat was"; /// let attention_mask = None; /// let min_length = 32; /// let max_length = 128; /// let decoder_start_token_id = None; + /// let forced_bos_token_id = None; /// /// //Example custom function for fine-grained generation control /// fn force_one_paragraph(_batch_id: i64, previous_token_ids: &Tensor) -> Vec { @@ -1493,6 +1501,7 @@ pub trait LanguageGenerator>: /// min_length, /// max_length, /// decoder_start_token_id, + /// forced_bos_token_id, /// Some(&force_one_paragraph), /// ); /// # Ok(()) @@ -1674,6 +1683,46 @@ pub trait LanguageGenerator>: } output_ids } + + /// Returns a reference to the text generator's tokenizer + /// + /// # Returns + /// * `&TokenizerOption` Reference to the generator's tokenizer. + /// + /// # Example + /// + /// ```no_run + /// # use std::path::PathBuf; + /// # use tch::Device; + /// # fn main() -> anyhow::Result<()> { + /// use rust_bert::gpt2::GPT2Generator; + /// use rust_bert::pipelines::generation_utils::{GenerateConfig, LanguageGenerator}; + /// use tch::Tensor; + /// # let mut home: PathBuf = dirs::home_dir().unwrap(); + /// # home.push("rustbert"); + /// # home.push("gpt2"); + /// # let config_path = &home.as_path().join("config.json"); + /// # let vocab_path = &home.as_path().join("vocab.txt"); + /// # let merges_path = &home.as_path().join("merges.txt"); + /// # let weights_path = &home.as_path().join("model.ot"); + /// let device = Device::cuda_if_available(); + /// let generate_config = GenerateConfig { + /// max_length: 30, + /// do_sample: true, + /// num_beams: 5, + /// temperature: 1.1, + /// num_return_sequences: 3, + /// ..Default::default() + /// }; + /// let gpt2_generator = GPT2Generator::new(generate_config)?; + /// let tokenizer = gpt2_generator.get_tokenizer(); + /// tokenizer.tokenize("Hello, world!"); + /// # Ok(()) + /// # } + /// ``` + fn get_tokenizer(&self) -> &TokenizerOption { + self._get_tokenizer() + } } #[derive(Debug)] diff --git a/src/pipelines/text_generation.rs b/src/pipelines/text_generation.rs index 844e208..835d807 100644 --- a/src/pipelines/text_generation.rs +++ b/src/pipelines/text_generation.rs @@ -229,11 +229,11 @@ impl TextGenerationOption { /// Interface method to access tokenizer pub fn get_tokenizer(&self) -> &TokenizerOption { match self { - Self::GPT(model_ref) => model_ref.get_tokenizer(), - Self::GPT2(model_ref) => model_ref.get_tokenizer(), - Self::GPTNeo(model_ref) => model_ref.get_tokenizer(), - Self::XLNet(model_ref) => model_ref.get_tokenizer(), - Self::Reformer(model_ref) => model_ref.get_tokenizer(), + Self::GPT(model_ref) => model_ref._get_tokenizer(), + Self::GPT2(model_ref) => model_ref._get_tokenizer(), + Self::GPTNeo(model_ref) => model_ref._get_tokenizer(), + Self::XLNet(model_ref) => model_ref._get_tokenizer(), + Self::Reformer(model_ref) => model_ref._get_tokenizer(), } } diff --git a/src/prophetnet/mod.rs b/src/prophetnet/mod.rs index eb4f166..d130923 100644 --- a/src/prophetnet/mod.rs +++ b/src/prophetnet/mod.rs @@ -7,7 +7,7 @@ //! //! # Model set-up and pre-trained weights loading //! -//! A full working example (generation) is provided in `examples/summarization_prophetnet`, run with `cargo run --example summarization_prophetnet`. +//! A full working example (summarization) is provided in `examples/summarization_prophetnet`, run with `cargo run --example summarization_prophetnet`. //! All models expect the following resources: //! - Configuration file expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers) //! - Model weights are expected to have a structure and parameter names following the [Transformers library](https://github.com/huggingface/transformers). A conversion using the Python utility scripts is required to convert the `.bin` weights to the `.ot` format. diff --git a/src/prophetnet/prophetnet_model.rs b/src/prophetnet/prophetnet_model.rs index 911e7e9..1c52fd8 100644 --- a/src/prophetnet/prophetnet_model.rs +++ b/src/prophetnet/prophetnet_model.rs @@ -993,7 +993,7 @@ impl fn get_model(&self) -> &ProphetNetForConditionalGeneration { &self.model } - fn get_tokenizer(&self) -> &TokenizerOption { + fn _get_tokenizer(&self) -> &TokenizerOption { &self.tokenizer } fn get_var_store(&self) -> &nn::VarStore { @@ -1069,7 +1069,7 @@ impl where S: AsRef<[&'a str]>, { - let tokens = self.get_tokenizer().encode_list( + let tokens = self._get_tokenizer().encode_list( prompt_text.as_ref(), max_len as usize, &TruncationStrategy::LongestFirst, @@ -1085,7 +1085,7 @@ impl let pad_token = match pad_token_id { Some(value) => value, None => self - .get_tokenizer() + ._get_tokenizer() .convert_tokens_to_ids(&[ProphetNetVocab::unknown_value()])[0], }; diff --git a/src/reformer/reformer_model.rs b/src/reformer/reformer_model.rs index dadb611..a7d7f7c 100644 --- a/src/reformer/reformer_model.rs +++ b/src/reformer/reformer_model.rs @@ -1114,7 +1114,7 @@ impl PrivateLanguageGenerator &ReformerModelWithLMHead { &self.model } - fn get_tokenizer(&self) -> &TokenizerOption { + fn _get_tokenizer(&self) -> &TokenizerOption { &self.tokenizer } fn get_var_store(&self) -> &nn::VarStore { diff --git a/src/roberta/mod.rs b/src/roberta/mod.rs index 5dfa5a2..1742a39 100644 --- a/src/roberta/mod.rs +++ b/src/roberta/mod.rs @@ -10,7 +10,6 @@ //! //! # Model set-up and pre-trained weights loading //! -//! A full working example is provided in `examples/roberta.rs`, run with `cargo run --example roberta`. //! The example below illustrate a Masked language model example, the structure is similar for other models. //! All models expect the following resources: //! - Configuration file expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers) diff --git a/src/t5/mod.rs b/src/t5/mod.rs index cfbe5ef..99f4b46 100644 --- a/src/t5/mod.rs +++ b/src/t5/mod.rs @@ -6,7 +6,7 @@ //! //! # Model set-up and pre-trained weights loading //! -//! A full working example (translation) is provided in `examples/t5`, run with `cargo run --example t5`. +//! A full working example (summarization) is provided in `examples/summarization_t5`, run with `cargo run --example summarization_t5`. //! All models expect the following resources: //! - Configuration file expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers) //! - Model weights are expected to have a structure and parameter names following the [Transformers library](https://github.com/huggingface/transformers). A conversion using the Python utility scripts is required to convert the `.bin` weights to the `.ot` format. diff --git a/src/t5/t5_model.rs b/src/t5/t5_model.rs index e1790db..7c06807 100644 --- a/src/t5/t5_model.rs +++ b/src/t5/t5_model.rs @@ -778,7 +778,7 @@ impl PrivateLanguageGenerator fn get_model(&self) -> &T5ForConditionalGeneration { &self.model } - fn get_tokenizer(&self) -> &TokenizerOption { + fn _get_tokenizer(&self) -> &TokenizerOption { &self.tokenizer } fn get_var_store(&self) -> &nn::VarStore { @@ -850,7 +850,7 @@ impl PrivateLanguageGenerator where S: AsRef<[&'a str]>, { - let tokens = self.get_tokenizer().encode_list( + let tokens = self._get_tokenizer().encode_list( prompt_text.as_ref(), max_len as usize, &TruncationStrategy::LongestFirst, @@ -865,7 +865,7 @@ impl PrivateLanguageGenerator let pad_token = match pad_token_id { Some(value) => value, - None => self.get_tokenizer().get_unk_id(), + None => self._get_tokenizer().get_unk_id(), }; let token_ids = token_ids diff --git a/src/xlnet/xlnet_model.rs b/src/xlnet/xlnet_model.rs index 1fd9fa2..f2e7492 100644 --- a/src/xlnet/xlnet_model.rs +++ b/src/xlnet/xlnet_model.rs @@ -1615,7 +1615,7 @@ impl PrivateLanguageGenerator for fn get_model(&self) -> &XLNetLMHeadModel { &self.model } - fn get_tokenizer(&self) -> &TokenizerOption { + fn _get_tokenizer(&self) -> &TokenizerOption { &self.tokenizer } fn get_var_store(&self) -> &nn::VarStore { diff --git a/tests/bart.rs b/tests/bart.rs index 8fbfb2f..9562b00 100644 --- a/tests/bart.rs +++ b/tests/bart.rs @@ -77,7 +77,6 @@ fn bart_lm_model() -> anyhow::Result<()> { #[test] fn bart_summarization_greedy() -> anyhow::Result<()> { - // Set-up masked LM model let config_resource = Resource::Remote(RemoteResource::from_pretrained( BartConfigResources::DISTILBART_CNN_6_6, )); @@ -139,7 +138,6 @@ about exoplanets like K2-18b."]; #[test] fn bart_summarization_beam_search() -> anyhow::Result<()> { - // Set-up masked LM model let config_resource = Resource::Remote(RemoteResource::from_pretrained( BartConfigResources::DISTILBART_CNN_6_6, )); @@ -202,7 +200,7 @@ about exoplanets like K2-18b."]; #[test] #[cfg_attr(not(feature = "all-tests"), ignore)] fn bart_zero_shot_classification() -> anyhow::Result<()> { - // Set-up model model + // Set-up model let zero_shot_config = ZeroShotClassificationConfig { device: Device::Cpu, ..Default::default() @@ -235,7 +233,7 @@ fn bart_zero_shot_classification() -> anyhow::Result<()> { #[test] #[cfg_attr(not(feature = "all-tests"), ignore)] fn bart_zero_shot_classification_multilabel() -> anyhow::Result<()> { - // Set-up model model + // Set-up model let zero_shot_config = ZeroShotClassificationConfig { device: Device::Cpu, ..Default::default() diff --git a/tests/distilgpt2.rs b/tests/distilgpt2.rs index 8a50c19..e1ed2b3 100644 --- a/tests/distilgpt2.rs +++ b/tests/distilgpt2.rs @@ -28,7 +28,7 @@ fn distilgpt2_lm_model() -> anyhow::Result<()> { let merges_path = merges_resource.get_local_path()?; let weights_path = weights_resource.get_local_path()?; - // Set-up masked LM model + // Set-up model let device = Device::Cpu; let mut vs = nn::VarStore::new(device); let tokenizer: Gpt2Tokenizer = Gpt2Tokenizer::from_file( diff --git a/tests/gpt2.rs b/tests/gpt2.rs index 67d3944..3066319 100644 --- a/tests/gpt2.rs +++ b/tests/gpt2.rs @@ -426,6 +426,7 @@ fn gpt2_prefix_allowed_token_greedy() -> anyhow::Result<()> { None, None, None, + None, Some(&force_one_paragraph), ); @@ -490,6 +491,7 @@ fn gpt2_prefix_allowed_token_beam_search() -> anyhow::Result<()> { None, None, None, + None, Some(&force_one_paragraph), ); diff --git a/tests/gpt_neo.rs b/tests/gpt_neo.rs index 85252f1..8f69619 100644 --- a/tests/gpt_neo.rs +++ b/tests/gpt_neo.rs @@ -29,7 +29,7 @@ fn gpt_neo_lm() -> anyhow::Result<()> { let merges_path = merges_resource.get_local_path()?; let weights_path = weights_resource.get_local_path()?; - // Set-up masked LM model + // Set-up model let device = Device::Cpu; let mut vs = nn::VarStore::new(device); let tokenizer: Gpt2Tokenizer = Gpt2Tokenizer::from_file( @@ -122,7 +122,7 @@ fn test_generation_gpt_neo() -> anyhow::Result<()> { GptNeoModelResources::GPT_NEO_125M, )); - // Set-up translation model + // Set-up model let generation_config = TextGenerationConfig { model_type: ModelType::GPTNeo, model_resource, diff --git a/tests/mbart.rs b/tests/mbart.rs new file mode 100644 index 0000000..f9fd7ee --- /dev/null +++ b/tests/mbart.rs @@ -0,0 +1,110 @@ +use rust_bert::mbart::{ + MBartConfig, MBartConfigResources, MBartGenerator, MBartModel, MBartModelResources, + MBartVocabResources, +}; +use rust_bert::pipelines::generation_utils::{GenerateConfig, LanguageGenerator}; +use rust_bert::pipelines::summarization::{SummarizationConfig, SummarizationModel}; +use rust_bert::resources::{RemoteResource, Resource}; +use rust_bert::Config; +use rust_tokenizers::tokenizer::{MBart50Tokenizer, Tokenizer, TruncationStrategy}; +use tch::{nn, Device, Tensor}; + +#[test] +fn mbart_lm_model() -> anyhow::Result<()> { + // Resources paths + let config_resource = Resource::Remote(RemoteResource::from_pretrained( + MBartConfigResources::MBART50_MANY_TO_MANY, + )); + let vocab_resource = Resource::Remote(RemoteResource::from_pretrained( + MBartVocabResources::MBART50_MANY_TO_MANY, + )); + let weights_resource = Resource::Remote(RemoteResource::from_pretrained( + MBartModelResources::MBART50_MANY_TO_MANY, + )); + let config_path = config_resource.get_local_path()?; + let vocab_path = vocab_resource.get_local_path()?; + let weights_path = weights_resource.get_local_path()?; + + // Set-up masked LM model + let device = Device::Cpu; + let mut vs = nn::VarStore::new(device); + let tokenizer = MBart50Tokenizer::from_file(vocab_path.to_str().unwrap(), false)?; + let config = MBartConfig::from_file(config_path); + let mbart_model = MBartModel::new(&vs.root() / "model", &config); + vs.load(weights_path)?; + + // Define input + let input = ["One two three four"]; + let tokenized_input = tokenizer.encode_list(&input, 128, &TruncationStrategy::LongestFirst, 0); + let max_len = tokenized_input + .iter() + .map(|input| input.token_ids.len()) + .max() + .unwrap(); + let tokenized_input = tokenized_input + .iter() + .map(|input| input.token_ids.clone()) + .map(|mut input| { + input.extend(vec![0; max_len - input.len()]); + input + }) + .map(|input| Tensor::of_slice(&(input))) + .collect::>(); + let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device); + + // Forward pass + let model_output = + mbart_model.forward_t(Some(&input_tensor), None, None, None, None, None, false); + assert_eq!(model_output.decoder_output.size(), vec!(1, 5, 1024)); + assert_eq!( + model_output.encoder_hidden_state.unwrap().size(), + vec!(1, 5, 1024) + ); + assert!((model_output.decoder_output.double_value(&[0, 0, 0]) - -0.8936).abs() < 1e-4); + Ok(()) +} + +#[test] +fn mbart_translation() -> anyhow::Result<()> { + // Resources paths + let generate_config = GenerateConfig { + max_length: 56, + model_resource: Resource::Remote(RemoteResource::from_pretrained( + MBartModelResources::MBART50_MANY_TO_MANY, + )), + config_resource: Resource::Remote(RemoteResource::from_pretrained( + MBartConfigResources::MBART50_MANY_TO_MANY, + )), + vocab_resource: Resource::Remote(RemoteResource::from_pretrained( + MBartVocabResources::MBART50_MANY_TO_MANY, + )), + merges_resource: Resource::Remote(RemoteResource::from_pretrained( + MBartVocabResources::MBART50_MANY_TO_MANY, + )), + do_sample: false, + num_beams: 3, + ..Default::default() + }; + let model = MBartGenerator::new(generate_config)?; + + let input_context = "en_XX The quick brown fox jumps over the lazy dog."; + let target_language = model.get_tokenizer().convert_tokens_to_ids(["de_DE"])[0]; + + let output = model.generate( + Some(&[input_context]), + None, + None, + None, + None, + target_language, + None, + ); + + assert_eq!(output.len(), 1); + assert_eq!( + output[0], + "de_DE Der schnelle braune Fuchs springt über den faulen Hund." + ); + + Ok(()) +} diff --git a/tests/openai_gpt.rs b/tests/openai_gpt.rs index e26a60a..442bd2d 100644 --- a/tests/openai_gpt.rs +++ b/tests/openai_gpt.rs @@ -117,7 +117,7 @@ fn openai_gpt_generation_greedy() -> anyhow::Result<()> { OpenAiGptModelResources::GPT, )); - // Set-up masked LM model + // Set-up model let generate_config = TextGenerationConfig { model_type: ModelType::OpenAiGpt, model_resource, @@ -159,7 +159,7 @@ fn openai_gpt_generation_beam_search() -> anyhow::Result<()> { OpenAiGptModelResources::GPT, )); - // Set-up masked LM model + // Set-up model let generate_config = TextGenerationConfig { model_type: ModelType::OpenAiGpt, model_resource, @@ -211,7 +211,7 @@ fn openai_gpt_generation_beam_search_multiple_prompts_without_padding() -> anyho OpenAiGptModelResources::GPT, )); - // Set-up masked LM model + // Set-up model let generate_config = TextGenerationConfig { model_type: ModelType::OpenAiGpt, model_resource, @@ -279,7 +279,7 @@ fn openai_gpt_generation_beam_search_multiple_prompts_with_padding() -> anyhow:: OpenAiGptModelResources::GPT, )); - // Set-up masked LM model + // Set-up model let generate_config = TextGenerationConfig { model_type: ModelType::OpenAiGpt, model_resource, diff --git a/tests/pegasus.rs b/tests/pegasus.rs index 9aa5eb4..4784f7e 100644 --- a/tests/pegasus.rs +++ b/tests/pegasus.rs @@ -7,7 +7,7 @@ use tch::Device; #[test] fn pegasus_summarization_greedy() -> anyhow::Result<()> { - // Set-up masked LM model + // Set-up model let config_resource = Resource::Remote(RemoteResource::from_pretrained( PegasusConfigResources::CNN_DAILYMAIL, )); diff --git a/tests/prophetnet.rs b/tests/prophetnet.rs index c7091f4..a05573c 100644 --- a/tests/prophetnet.rs +++ b/tests/prophetnet.rs @@ -9,7 +9,7 @@ use tch::Device; #[test] fn prophetnet_summarization_greedy() -> anyhow::Result<()> { - // Set-up masked LM model + // Set-up model let config_resource = Resource::Remote(RemoteResource::from_pretrained( ProphetNetConfigResources::PROPHETNET_LARGE_CNN_DM, ));