Updated documentation, cleaned examples, added integration tests

Guillaume B 2021-06-06 13:01:33 +02:00
parent 698e7143e8
commit 5907b7d954
65 changed files with 341 additions and 1540 deletions

View File

@@ -3,7 +3,9 @@ All notable changes to this project will be documented in this file. The format
## [Unreleased]
## Added
- (BREAKING) Support for `prefix_allowed_tokens_fn` arguments for generation, allowing users to control the generation via custom functions
- (BREAKING) Support for `prefix_allowed_tokens_fn` argument for generation, allowing users to control the generation via custom functions
- (BREAKING) Support for `forced_bos_token_id` argument for generation, allowing users to force a given BOS token for generation (useful for MBart/M2M-class models)
- Addition of the MBart language model and support for text generation / direct translation between 50 languages
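For illustration, a minimal sketch of the two new arguments, modelled on the `translation_mbart` example added later in this commit; the `generate` argument order and the `prefix_allowed_tokens_fn` closure type are inferred from that example and the Transformers convention, so treat them as assumptions rather than the exact signature:

// Sketch only: assumes an MBartGenerator `model` built as in examples/translation_mbart.rs.
// Hypothetical constraint function; assumed shape: (batch id, tokens generated so far)
// -> ids allowed as the next token.
let allowed_tokens = |_batch_id: i64, _generated: &tch::Tensor| -> Vec<i64> {
    (0..1000).collect() // hypothetical: allow only the first 1000 token ids
};
// Token id of the target language, forced as the first generated (BOS) token.
let target_language = model.get_tokenizer().convert_tokens_to_ids(["de_DE"])[0];
let output = model.generate(
    Some(&["en_XX The quick brown fox jumps over the lazy dog."]),
    None,
    None,
    None,
    None,
    target_language,       // new `forced_bos_token_id` argument
    Some(&allowed_tokens), // new `prefix_allowed_tokens_fn` argument
);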
## [0.15.1] - 2021-06-01
### Fixed

View File

@@ -46,8 +46,9 @@ RoBERTa|✅|✅|✅| | | |✅|
GPT| | | |✅ | | | |
GPT2| | | |✅ | | | |
GPT-Neo| | | |✅ | | | |
BART|✅| | |✅ |✅| | |
Marian| | | | | |✅| |
MBart|✅| | |✅ | | | |
Electra | |✅| | | | |✅|
ALBERT |✅|✅|✅| | | |✅|
T5 | | | |✅ |✅|✅| |

View File

@@ -1,97 +0,0 @@
// Copyright 2018 Google AI and Google Brain team.
// Copyright 2020-present, the HuggingFace Inc. team.
// Copyright 2020 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
extern crate anyhow;
use rust_bert::albert::{
AlbertConfig, AlbertConfigResources, AlbertForMaskedLM, AlbertModelResources,
AlbertVocabResources,
};
use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::Config;
use rust_tokenizers::tokenizer::{AlbertTokenizer, Tokenizer, TruncationStrategy};
use rust_tokenizers::vocab::Vocab;
use tch::{nn, no_grad, Device, Tensor};
fn main() -> anyhow::Result<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
AlbertConfigResources::ALBERT_BASE_V2,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
AlbertVocabResources::ALBERT_BASE_V2,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
AlbertModelResources::ALBERT_BASE_V2,
));
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
let weights_path = weights_resource.get_local_path()?;
// Set-up masked LM model
let device = Device::Cpu;
let mut vs = nn::VarStore::new(device);
let tokenizer: AlbertTokenizer =
AlbertTokenizer::from_file(vocab_path.to_str().unwrap(), true, false)?;
let config = AlbertConfig::from_file(config_path);
let albert_model = AlbertForMaskedLM::new(&vs.root(), &config);
vs.load(weights_path)?;
// Define input
let input = [
"Looks like one [MASK] is missing",
"It was a very nice and [MASK] day",
];
let tokenized_input = tokenizer.encode_list(&input, 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
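// Pad every sequence with zeros up to the longest sequence in the batch,
// then convert each one to a 1-D tensor so they can be stacked below.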
let tokenized_input = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
})
.map(|input| Tensor::of_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let model_output =
no_grad(|| albert_model.forward_t(Some(input_tensor), None, None, None, None, false));
println!(
"{:?}",
model_output.prediction_scores.double_value(&[0, 0, 0])
);
// Print masked tokens
let index_1 = model_output
.prediction_scores
.get(0)
.get(4)
.argmax(0, false);
let index_2 = model_output
.prediction_scores
.get(1)
.get(7)
.argmax(0, false);
let word_1 = tokenizer.vocab().id_to_token(&index_1.int64_value(&[]));
let word_2 = tokenizer.vocab().id_to_token(&index_2.int64_value(&[]));
println!("{} - {}", &index_1.int64_value(&[]), word_1); // Outputs "_them" : "Looks like one [them] is missing"
println!("{} - {}", &index_2.int64_value(&[]), word_2); // Outputs "_enjoyable" : "It was a very nice and [enjoyable] day"
Ok(())
}

View File

@@ -1,81 +0,0 @@
// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
// Copyright 2019 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
extern crate anyhow;
use rust_bert::bart::{
BartConfig, BartConfigResources, BartMergesResources, BartModel, BartModelResources,
BartVocabResources,
};
use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::Config;
use rust_tokenizers::tokenizer::{RobertaTokenizer, Tokenizer, TruncationStrategy};
use tch::{nn, no_grad, Device, Tensor};
fn main() -> anyhow::Result<()> {
// Resources paths
let config_resource =
Resource::Remote(RemoteResource::from_pretrained(BartConfigResources::BART));
let vocab_resource =
Resource::Remote(RemoteResource::from_pretrained(BartVocabResources::BART));
let merges_resource =
Resource::Remote(RemoteResource::from_pretrained(BartMergesResources::BART));
let weights_resource =
Resource::Remote(RemoteResource::from_pretrained(BartModelResources::BART));
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
let merges_path = merges_resource.get_local_path()?;
let weights_path = weights_resource.get_local_path()?;
// Set-up masked LM model
let device = Device::cuda_if_available();
let mut vs = nn::VarStore::new(device);
let tokenizer = RobertaTokenizer::from_file(
vocab_path.to_str().unwrap(),
merges_path.to_str().unwrap(),
false,
false,
)?;
let config = BartConfig::from_file(config_path);
let bart_model = BartModel::new(&vs.root(), &config);
vs.load(weights_path)?;
// Define input
let input = ["One two three four"];
let tokenized_input = tokenizer.encode_list(input, 1024, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let tokenized_input = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
})
.map(|input| Tensor::of_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let model_output =
no_grad(|| bart_model.forward_t(Some(&input_tensor), None, None, None, None, None, false));
// Print masked tokens
println!("{:?}", model_output.encoder_hidden_state);
println!("{:?}", model_output.decoder_output);
println!("{:?}", model_output.decoder_output.double_value(&[0, 0, 0]));
Ok(())
}

View File

@@ -1,102 +0,0 @@
// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
// Copyright 2019 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
extern crate anyhow;
use rust_bert::distilbert::{
DistilBertConfig, DistilBertConfigResources, DistilBertModelMaskedLM, DistilBertModelResources,
DistilBertVocabResources,
};
use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::Config;
use rust_tokenizers::tokenizer::{BertTokenizer, MultiThreadedTokenizer, TruncationStrategy};
use rust_tokenizers::vocab::Vocab;
use tch::{nn, no_grad, Device, Tensor};
fn main() -> anyhow::Result<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
DistilBertConfigResources::DISTIL_BERT,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
DistilBertVocabResources::DISTIL_BERT,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
DistilBertModelResources::DISTIL_BERT,
));
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
let weights_path = weights_resource.get_local_path()?;
// Set-up masked LM model
let device = Device::Cpu;
let mut vs = nn::VarStore::new(device);
let tokenizer: BertTokenizer =
BertTokenizer::from_file(vocab_path.to_str().unwrap(), true, true)?;
let config = DistilBertConfig::from_file(config_path);
let distil_bert_model = DistilBertModelMaskedLM::new(&vs.root(), &config);
vs.load(weights_path)?;
// Define input
let input = [
"Looks like one thing is missing",
"It\'s like comparing oranges to apples",
];
let tokenized_input = tokenizer.encode_list(input, 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let mut tokenized_input = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
})
.collect::<Vec<_>>();
// Masking the token [thing] of sentence 1 and [oranges] of sentence 2
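// (103 is the id of the [MASK] token in the BERT uncased vocabulary.)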
tokenized_input[0][4] = 103;
tokenized_input[1][6] = 103;
let tokenized_input = tokenized_input
.iter()
.map(|input| Tensor::of_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let model_output = no_grad(|| {
distil_bert_model
.forward_t(Some(input_tensor), None, None, false)
.unwrap()
});
// Print masked tokens
let index_1 = model_output
.prediction_scores
.get(0)
.get(4)
.argmax(0, false);
let index_2 = model_output
.prediction_scores
.get(1)
.get(6)
.argmax(0, false);
let word_1 = tokenizer.vocab().id_to_token(&index_1.int64_value(&[]));
let word_2 = tokenizer.vocab().id_to_token(&index_2.int64_value(&[]));
println!("{}", word_1); // Outputs "person" : "Looks like one [person] is missing"
println!("{}", word_2); // Outputs "pear" : "It\'s like comparing [pear] to apples"
Ok(())
}

View File

@@ -1,406 +0,0 @@
extern crate anyhow;
use rust_bert::albert::{AlbertConfigResources, AlbertModelResources, AlbertVocabResources};
use rust_bert::bart::{
BartConfigResources, BartMergesResources, BartModelResources, BartVocabResources,
};
use rust_bert::bert::{BertConfigResources, BertModelResources, BertVocabResources};
use rust_bert::distilbert::{
DistilBertConfigResources, DistilBertModelResources, DistilBertVocabResources,
};
use rust_bert::electra::{ElectraConfigResources, ElectraModelResources, ElectraVocabResources};
use rust_bert::gpt2::{
Gpt2ConfigResources, Gpt2MergesResources, Gpt2ModelResources, Gpt2VocabResources,
};
use rust_bert::openai_gpt::{
OpenAiGptConfigResources, OpenAiGptMergesResources, OpenAiGptModelResources,
OpenAiGptVocabResources,
};
use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::roberta::{
RobertaConfigResources, RobertaMergesResources, RobertaModelResources, RobertaVocabResources,
};
use rust_bert::t5::{T5ConfigResources, T5ModelResources, T5VocabResources};
use rust_bert::xlnet::{XLNetConfigResources, XLNetModelResources, XLNetVocabResources};
/// This example downloads and caches all dependencies used in model tests. This allows for safe
/// multi-threaded testing (two tests using the same resource would otherwise download the file to
/// the same location).
fn download_distil_gpt2() -> anyhow::Result<()> {
// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
Gpt2ConfigResources::DISTIL_GPT2,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
Gpt2VocabResources::DISTIL_GPT2,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
Gpt2MergesResources::DISTIL_GPT2,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
Gpt2ModelResources::DISTIL_GPT2,
));
let _ = config_resource.get_local_path()?;
let _ = vocab_resource.get_local_path()?;
let _ = merges_resource.get_local_path()?;
let _ = weights_resource.get_local_path()?;
Ok(())
}
fn download_distilbert_sst2() -> anyhow::Result<()> {
// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
DistilBertModelResources::DISTIL_BERT_SST2,
));
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
DistilBertConfigResources::DISTIL_BERT_SST2,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
DistilBertVocabResources::DISTIL_BERT_SST2,
));
let _ = config_resource.get_local_path()?;
let _ = vocab_resource.get_local_path()?;
let _ = weights_resource.get_local_path()?;
Ok(())
}
fn download_distilbert_qa() -> anyhow::Result<()> {
// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
DistilBertModelResources::DISTIL_BERT_SQUAD,
));
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
DistilBertConfigResources::DISTIL_BERT_SQUAD,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
DistilBertVocabResources::DISTIL_BERT_SQUAD,
));
let _ = config_resource.get_local_path()?;
let _ = vocab_resource.get_local_path()?;
let _ = weights_resource.get_local_path()?;
Ok(())
}
fn download_distilbert() -> anyhow::Result<()> {
// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
DistilBertModelResources::DISTIL_BERT,
));
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
DistilBertConfigResources::DISTIL_BERT,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
DistilBertVocabResources::DISTIL_BERT,
));
let _ = config_resource.get_local_path()?;
let _ = vocab_resource.get_local_path()?;
let _ = weights_resource.get_local_path()?;
Ok(())
}
fn download_gpt2() -> anyhow::Result<()> {
// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2. Modified with conversion to C-array format.
let config_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
let vocab_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
let merges_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
let weights_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
let _ = config_resource.get_local_path()?;
let _ = vocab_resource.get_local_path()?;
let _ = merges_resource.get_local_path()?;
let _ = weights_resource.get_local_path()?;
Ok(())
}
fn download_gpt() -> anyhow::Result<()> {
// Shared under MIT license by the OpenAI team at https://github.com/openai/finetune-transformer-lm. Modified with conversion to C-array format.
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
OpenAiGptConfigResources::GPT,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
OpenAiGptVocabResources::GPT,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
OpenAiGptMergesResources::GPT,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
OpenAiGptModelResources::GPT,
));
let _ = config_resource.get_local_path()?;
let _ = vocab_resource.get_local_path()?;
let _ = merges_resource.get_local_path()?;
let _ = weights_resource.get_local_path()?;
Ok(())
}
fn download_roberta() -> anyhow::Result<()> {
// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaConfigResources::ROBERTA,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaVocabResources::ROBERTA,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaMergesResources::ROBERTA,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaModelResources::ROBERTA,
));
let _ = config_resource.get_local_path()?;
let _ = vocab_resource.get_local_path()?;
let _ = merges_resource.get_local_path()?;
let _ = weights_resource.get_local_path()?;
Ok(())
}
fn download_bert() -> anyhow::Result<()> {
// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/bert. Modified with conversion to C-array format.
let config_resource =
Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT));
let vocab_resource =
Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
let weights_resource =
Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT));
let _ = config_resource.get_local_path()?;
let _ = vocab_resource.get_local_path()?;
let _ = weights_resource.get_local_path()?;
Ok(())
}
fn download_bert_ner() -> anyhow::Result<()> {
// Shared under MIT license by the MDZ Digital Library team at the Bavarian State Library at https://github.com/dbmdz/berts. Modified with conversion to C-array format.
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
BertConfigResources::BERT_NER,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
BertVocabResources::BERT_NER,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
BertModelResources::BERT_NER,
));
let _ = config_resource.get_local_path()?;
let _ = vocab_resource.get_local_path()?;
let _ = weights_resource.get_local_path()?;
Ok(())
}
fn download_bart() -> anyhow::Result<()> {
// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
let config_resource =
Resource::Remote(RemoteResource::from_pretrained(BartConfigResources::BART));
let vocab_resource =
Resource::Remote(RemoteResource::from_pretrained(BartVocabResources::BART));
let merges_resource =
Resource::Remote(RemoteResource::from_pretrained(BartMergesResources::BART));
let weights_resource =
Resource::Remote(RemoteResource::from_pretrained(BartModelResources::BART));
let _ = config_resource.get_local_path()?;
let _ = vocab_resource.get_local_path()?;
let _ = merges_resource.get_local_path()?;
let _ = weights_resource.get_local_path()?;
Ok(())
}
fn download_bart_cnn() -> anyhow::Result<()> {
// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
BartConfigResources::BART_CNN,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
BartVocabResources::BART_CNN,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
BartMergesResources::BART_CNN,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
BartModelResources::BART_CNN,
));
let _ = config_resource.get_local_path()?;
let _ = vocab_resource.get_local_path()?;
let _ = merges_resource.get_local_path()?;
let _ = weights_resource.get_local_path()?;
Ok(())
}
fn download_electra_generator() -> anyhow::Result<()> {
// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/electra. Modified with conversion to C-array format.
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
ElectraConfigResources::BASE_GENERATOR,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
ElectraVocabResources::BASE_GENERATOR,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
ElectraModelResources::BASE_GENERATOR,
));
let _ = config_resource.get_local_path()?;
let _ = vocab_resource.get_local_path()?;
let _ = weights_resource.get_local_path()?;
Ok(())
}
fn download_electra_discriminator() -> anyhow::Result<()> {
// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/electra. Modified with conversion to C-array format.
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
ElectraConfigResources::BASE_DISCRIMINATOR,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
ElectraVocabResources::BASE_DISCRIMINATOR,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
ElectraModelResources::BASE_DISCRIMINATOR,
));
let _ = config_resource.get_local_path()?;
let _ = vocab_resource.get_local_path()?;
let _ = weights_resource.get_local_path()?;
Ok(())
}
fn download_albert_base_v2() -> anyhow::Result<()> {
// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/ALBERT. Modified with conversion to C-array format.
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
AlbertConfigResources::ALBERT_BASE_V2,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
AlbertVocabResources::ALBERT_BASE_V2,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
AlbertModelResources::ALBERT_BASE_V2,
));
let _ = config_resource.get_local_path()?;
let _ = vocab_resource.get_local_path()?;
let _ = weights_resource.get_local_path()?;
Ok(())
}
fn _download_dialogpt() -> anyhow::Result<()> {
// Shared under MIT license by the Microsoft team at https://huggingface.co/microsoft/DialoGPT-medium. Modified with conversion to C-array format.
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
Gpt2ConfigResources::DIALOGPT_MEDIUM,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
Gpt2VocabResources::DIALOGPT_MEDIUM,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
Gpt2MergesResources::DIALOGPT_MEDIUM,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
Gpt2ModelResources::DIALOGPT_MEDIUM,
));
let _ = config_resource.get_local_path()?;
let _ = vocab_resource.get_local_path()?;
let _ = merges_resource.get_local_path()?;
let _ = weights_resource.get_local_path()?;
Ok(())
}
fn download_t5_small() -> anyhow::Result<()> {
// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/text-to-text-transfer-transformer.
let config_resource =
Resource::Remote(RemoteResource::from_pretrained(T5ConfigResources::T5_SMALL));
let vocab_resource =
Resource::Remote(RemoteResource::from_pretrained(T5VocabResources::T5_SMALL));
let weights_resource =
Resource::Remote(RemoteResource::from_pretrained(T5ModelResources::T5_SMALL));
let _ = config_resource.get_local_path()?;
let _ = vocab_resource.get_local_path()?;
let _ = weights_resource.get_local_path()?;
Ok(())
}
fn download_roberta_qa() -> anyhow::Result<()> {
// Shared under Apache 2.0 license by [deepset](https://deepset.ai) at https://huggingface.co/deepset/roberta-base-squad2.
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaConfigResources::ROBERTA_QA,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaVocabResources::ROBERTA_QA,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaModelResources::ROBERTA_QA,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaMergesResources::ROBERTA_QA,
));
let _ = config_resource.get_local_path()?;
let _ = vocab_resource.get_local_path()?;
let _ = merges_resource.get_local_path()?;
let _ = weights_resource.get_local_path()?;
Ok(())
}
fn download_bert_qa() -> anyhow::Result<()> {
// Shared under Apache 2.0 license by [deepset](https://deepset.ai) at https://huggingface.co/deepset/roberta-base-squad2.
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
BertConfigResources::BERT_QA,
));
let vocab_resource =
Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT_QA));
let weights_resource =
Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT_QA));
let _ = config_resource.get_local_path()?;
let _ = vocab_resource.get_local_path()?;
let _ = weights_resource.get_local_path()?;
Ok(())
}
fn download_xlm_roberta_ner_german() -> anyhow::Result<()> {
// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models.
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaConfigResources::XLM_ROBERTA_NER_DE,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaVocabResources::XLM_ROBERTA_NER_DE,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaModelResources::XLM_ROBERTA_NER_DE,
));
let _ = config_resource.get_local_path()?;
let _ = vocab_resource.get_local_path()?;
let _ = weights_resource.get_local_path()?;
Ok(())
}
fn download_xlnet_base_cased() -> anyhow::Result<()> {
// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models.
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
XLNetConfigResources::XLNET_BASE_CASED,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
XLNetVocabResources::XLNET_BASE_CASED,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
XLNetModelResources::XLNET_BASE_CASED,
));
let _ = config_resource.get_local_path()?;
let _ = vocab_resource.get_local_path()?;
let _ = weights_resource.get_local_path()?;
Ok(())
}
fn main() {
let _ = download_distil_gpt2();
let _ = download_distilbert_sst2();
let _ = download_distilbert_qa();
let _ = download_distilbert();
let _ = download_gpt2();
let _ = download_gpt();
let _ = download_roberta();
let _ = download_bert();
let _ = download_bert_ner();
let _ = download_bart();
let _ = download_bart_cnn();
let _ = download_electra_generator();
let _ = download_electra_discriminator();
let _ = download_albert_base_v2();
let _ = download_t5_small();
let _ = download_roberta_qa();
let _ = download_bert_qa();
let _ = download_xlm_roberta_ner_german();
let _ = download_xlnet_base_cased();
}

View File

@@ -1,96 +0,0 @@
// Copyright 2020 The Google Research Authors.
// Copyright 2019-present, the HuggingFace Inc. team
// Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
// Copyright 2019 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use rust_bert::electra::{
ElectraConfig, ElectraConfigResources, ElectraDiscriminator, ElectraModelResources,
ElectraVocabResources,
};
use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::Config;
use rust_tokenizers::tokenizer::{
BertTokenizer, MultiThreadedTokenizer, Tokenizer, TruncationStrategy,
};
use tch::{nn, no_grad, Device, Tensor};
fn main() -> anyhow::Result<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
ElectraConfigResources::BASE_DISCRIMINATOR,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
ElectraVocabResources::BASE_DISCRIMINATOR,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
ElectraModelResources::BASE_DISCRIMINATOR,
));
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
let weights_path = weights_resource.get_local_path()?;
// Set-up masked LM model
let device = Device::Cpu;
let mut vs = nn::VarStore::new(device);
let tokenizer: BertTokenizer =
BertTokenizer::from_file(vocab_path.to_str().unwrap(), true, true)?;
let config = ElectraConfig::from_file(config_path);
let electra_model = ElectraDiscriminator::new(&vs.root(), &config);
vs.load(weights_path)?;
// Define input
let input = ["One Two Three Ten Five Six Seven Eight"];
let tokenized_input = MultiThreadedTokenizer::encode_list(
&tokenizer,
&input,
128,
&TruncationStrategy::LongestFirst,
0,
);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let encoded_input = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
})
.map(|input| Tensor::of_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(encoded_input.as_slice(), 0).to(device);
// Forward pass
let model_output =
no_grad(|| electra_model.forward_t(Some(input_tensor), None, None, None, None, false));
// Print model predictions
for (position, token) in tokenized_input[0].token_ids.iter().enumerate() {
let probability = model_output.probabilities.double_value(&[position as i64]);
let generated = if probability > 0.5 {
"generated"
} else {
"original"
};
println!(
"{:?}: {} ({:.1}%)",
tokenizer.decode([*token].to_vec(), false, false),
generated,
100f64 * probability
)
}
Ok(())
}

View File

@@ -1,93 +0,0 @@
// Copyright 2020 The Google Research Authors.
// Copyright 2019-present, the HuggingFace Inc. team
// Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
// Copyright 2019 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use rust_bert::electra::{
ElectraConfig, ElectraConfigResources, ElectraForMaskedLM, ElectraModelResources,
ElectraVocabResources,
};
use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::Config;
use rust_tokenizers::tokenizer::{BertTokenizer, MultiThreadedTokenizer, TruncationStrategy};
use rust_tokenizers::vocab::Vocab;
use tch::{nn, no_grad, Device, Tensor};
fn main() -> anyhow::Result<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
ElectraConfigResources::BASE_GENERATOR,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
ElectraVocabResources::BASE_GENERATOR,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
ElectraModelResources::BASE_GENERATOR,
));
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
let weights_path = weights_resource.get_local_path()?;
// Set-up masked LM model
let device = Device::Cpu;
let mut vs = nn::VarStore::new(device);
let tokenizer: BertTokenizer =
BertTokenizer::from_file(vocab_path.to_str().unwrap(), true, true)?;
let config = ElectraConfig::from_file(config_path);
let electra_model = ElectraForMaskedLM::new(&vs.root(), &config);
vs.load(weights_path)?;
// Define input
let input = [
"Looks like one [MASK] is missing",
"It was a very nice and [MASK] day",
];
let tokenized_input = tokenizer.encode_list(&input, 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let tokenized_input = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
})
.map(|input| Tensor::of_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let model_output =
no_grad(|| electra_model.forward_t(Some(input_tensor), None, None, None, None, false));
// Print masked tokens
let index_1 = model_output
.prediction_scores
.get(0)
.get(4)
.argmax(0, false);
let index_2 = model_output
.prediction_scores
.get(1)
.get(7)
.argmax(0, false);
let word_1 = tokenizer.vocab().id_to_token(&index_1.int64_value(&[]));
let word_2 = tokenizer.vocab().id_to_token(&index_2.int64_value(&[]));
println!("{}", word_1); // Outputs "thing" : "Looks like one [thing] is missing"
println!("{}", word_2); // Outputs "sunny" : "It was a very nice and [sunny] day"
Ok(())
}

View File

@@ -16,7 +16,7 @@ use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel};
fn main() -> anyhow::Result<()> {
// Set-up masked LM model
// Set-up model
let generate_config = TextGenerationConfig {
model_type: ModelType::GPT2,
max_length: 30,

View File

@@ -22,7 +22,7 @@ use rust_bert::reformer::{
use rust_bert::resources::{RemoteResource, Resource};
fn main() -> anyhow::Result<()> {
// Set-up masked LM model
// Set-up model
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
ReformerConfigResources::CRIME_AND_PUNISHMENT,

View File

@@ -20,7 +20,7 @@ use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::xlnet::{XLNetConfigResources, XLNetModelResources, XLNetVocabResources};
fn main() -> anyhow::Result<()> {
// Set-up masked LM model
// Set-up model
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
XLNetConfigResources::XLNET_BASE_CASED,

View File

@@ -1,97 +0,0 @@
// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
// Copyright 2019 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
extern crate anyhow;
use rust_bert::gpt2::{
GPT2LMHeadModel, Gpt2Config, Gpt2ConfigResources, Gpt2MergesResources, Gpt2ModelResources,
Gpt2VocabResources,
};
use rust_bert::pipelines::generation_utils::{Cache, LMHeadModel};
use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::Config;
use rust_tokenizers::tokenizer::{Gpt2Tokenizer, Tokenizer, TruncationStrategy};
use tch::{nn, Device, Tensor};
fn main() -> anyhow::Result<()> {
// Resources set-up
let config_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
let vocab_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
let merges_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
let weights_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
let merges_path = merges_resource.get_local_path()?;
let weights_path = weights_resource.get_local_path()?;
// Set-up masked LM model
let device = Device::Cpu;
let mut vs = nn::VarStore::new(device);
let tokenizer: Gpt2Tokenizer = Gpt2Tokenizer::from_file(
vocab_path.to_str().unwrap(),
merges_path.to_str().unwrap(),
false,
)?;
let config = Gpt2Config::from_file(config_path);
let gpt2_model = GPT2LMHeadModel::new(&vs.root(), &config);
vs.load(weights_path)?;
// Define input
let input = ["One two three four five six seven eight nine ten eleven"];
let tokenized_input = tokenizer.encode_list(&input, 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let tokenized_input = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
})
.map(|input| Tensor::of_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let model_output = gpt2_model
.forward_t(
&Some(input_tensor),
Cache::None,
&None,
&None,
&None,
&None,
None,
&None,
false,
)
.unwrap();
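// Take the logits at the last position and greedily pick the highest-scoring token.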
let next_word_id = model_output
.lm_logits
.get(0)
.get(-1)
.argmax(-1, true)
.int64_value(&[0]);
let next_word = tokenizer.decode(vec![next_word_id], true, true);
println!("Provided input: {}", input[0]);
println!("Next word: {}", next_word);
Ok(())
}

View File

@@ -1,101 +0,0 @@
// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
// Copyright 2019 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
extern crate anyhow;
use rust_bert::mobilebert::{
MobileBertConfig, MobileBertConfigResources, MobileBertForMaskedLM, MobileBertModelResources,
MobileBertVocabResources,
};
use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::Config;
use rust_tokenizers::tokenizer::{BertTokenizer, MultiThreadedTokenizer, TruncationStrategy};
use rust_tokenizers::vocab::Vocab;
use tch::{nn, no_grad, Device, Tensor};
fn main() -> anyhow::Result<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
MobileBertConfigResources::MOBILEBERT_UNCASED,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
MobileBertVocabResources::MOBILEBERT_UNCASED,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
MobileBertModelResources::MOBILEBERT_UNCASED,
));
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
let weights_path = weights_resource.get_local_path()?;
// Set-up masked LM model
let device = Device::Cpu;
let mut vs = nn::VarStore::new(device);
let tokenizer: BertTokenizer =
BertTokenizer::from_file(vocab_path.to_str().unwrap(), true, true)?;
let config = MobileBertConfig::from_file(config_path);
let mobilebert_model = MobileBertForMaskedLM::new(&vs.root(), &config);
vs.load(weights_path)?;
// Define input
let input = [
"Looks like one [MASK] is missing",
"It was a very nice and [MASK] day",
];
let tokenized_input = tokenizer.encode_list(&input, 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let tokenized_input = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
})
.map(|input| Tensor::of_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let model_output =
no_grad(|| mobilebert_model.forward_t(Some(&input_tensor), None, None, None, None, false))?;
// Print masked tokens
let index_1 = model_output.logits.get(0).get(4).argmax(0, false);
let index_2 = model_output.logits.get(1).get(7).argmax(0, false);
let word_1 = tokenizer.vocab().id_to_token(&index_1.int64_value(&[]));
let word_2 = tokenizer.vocab().id_to_token(&index_2.int64_value(&[]));
println!("{}", word_1); // Outputs "thing" : "Looks like one [thing] is missing"
println!(
"score: {}",
model_output
.logits
.get(0)
.get(4)
.double_value(&[i64::from(&index_1)])
); // 10.0558
println!("{}", word_2); // Outputs "sunny" : "It was a very nice and [sunny] day"
println!(
"score: {}",
model_output
.logits
.get(1)
.get(7)
.double_value(&[i64::from(&index_2)])
); // 14.2708
Ok(())
}

View File

@@ -1,102 +0,0 @@
// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
// Copyright 2019 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
extern crate anyhow;
use rust_bert::gpt2::Gpt2Config;
use rust_bert::openai_gpt::{
OpenAIGPTLMHeadModel, OpenAiGptConfigResources, OpenAiGptMergesResources,
OpenAiGptModelResources, OpenAiGptVocabResources,
};
use rust_bert::pipelines::generation_utils::{Cache, LMHeadModel};
use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::Config;
use rust_tokenizers::tokenizer::{OpenAiGptTokenizer, Tokenizer, TruncationStrategy};
use tch::{nn, Device, Tensor};
fn main() -> anyhow::Result<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
OpenAiGptConfigResources::GPT,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
OpenAiGptVocabResources::GPT,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
OpenAiGptMergesResources::GPT,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
OpenAiGptModelResources::GPT,
));
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
let merges_path = merges_resource.get_local_path()?;
let weights_path = weights_resource.get_local_path()?;
// Set-up masked LM model
let device = Device::Cpu;
let mut vs = nn::VarStore::new(device);
let tokenizer = OpenAiGptTokenizer::from_file(
vocab_path.to_str().unwrap(),
merges_path.to_str().unwrap(),
true,
)?;
let config = Gpt2Config::from_file(config_path);
let openai_gpt = OpenAIGPTLMHeadModel::new(&vs.root(), &config);
vs.load(weights_path)?;
// Define input
let input = ["Wondering what the next word will"];
let tokenized_input = tokenizer.encode_list(&input, 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let tokenized_input = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
})
.map(|input| Tensor::of_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let model_output = openai_gpt
.forward_t(
&Some(input_tensor),
Cache::None,
&None,
&None,
&None,
&None,
None,
&None,
false,
)
.unwrap();
let next_word_id = model_output
.lm_logits
.get(0)
.get(-1)
.argmax(-1, true)
.int64_value(&[0]);
let next_word = tokenizer.decode(vec![next_word_id], true, true);
println!("Provided input: {}", input[0]);
println!("Next word: {}", next_word);
Ok(())
}

View File

@@ -1,74 +0,0 @@
// Copyright 2018 Google AI and Google Brain team.
// Copyright 2018 Carnegie Mellon University Authors.
// Copyright 2020-present, the HuggingFace Inc. team.
// Copyright 2020 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
extern crate anyhow;
use rust_bert::reformer::{
ReformerConfig, ReformerConfigResources, ReformerModelResources, ReformerModelWithLMHead,
ReformerVocabResources,
};
use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::Config;
use rust_tokenizers::tokenizer::{MultiThreadedTokenizer, ReformerTokenizer, TruncationStrategy};
use tch::{nn, Device, Tensor};
fn main() -> anyhow::Result<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
ReformerConfigResources::CRIME_AND_PUNISHMENT,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
ReformerVocabResources::CRIME_AND_PUNISHMENT,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
ReformerModelResources::CRIME_AND_PUNISHMENT,
));
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
let weights_path = weights_resource.get_local_path()?;
// Set-up masked LM model
let device = Device::cuda_if_available();
let mut vs = nn::VarStore::new(device);
let tokenizer = ReformerTokenizer::from_file(vocab_path.to_str().unwrap(), false)?;
let config = ReformerConfig::from_file(config_path);
let reformer_model = ReformerModelWithLMHead::new(&vs.root(), &config)?;
vs.load(weights_path)?;
// Define input
let input = ["One two three four five six seven eight nine ten eleven One two three four five six seven eight nine ten eleven One two three four five six seven eight nine ten eleven"];
let tokenized_input = tokenizer.encode_list(&input, 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let tokenized_input = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
})
.map(|input| Tensor::of_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let model_output =
reformer_model.forward_t(Some(&input_tensor), None, None, None, None, None, false)?;
model_output.logits.print();
Ok(())
}

View File

@@ -1,119 +0,0 @@
// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
// Copyright 2019 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
extern crate anyhow;
use rust_bert::bert::BertConfig;
use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::roberta::{
RobertaConfigResources, RobertaForMaskedLM, RobertaMergesResources, RobertaModelResources,
RobertaVocabResources,
};
use rust_bert::Config;
use rust_tokenizers::tokenizer::{RobertaTokenizer, Tokenizer, TruncationStrategy};
use rust_tokenizers::vocab::Vocab;
use tch::{nn, no_grad, Device, Tensor};
fn main() -> anyhow::Result<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaConfigResources::ROBERTA,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaVocabResources::ROBERTA,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaMergesResources::ROBERTA,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaModelResources::ROBERTA,
));
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
let merges_path = merges_resource.get_local_path()?;
let weights_path = weights_resource.get_local_path()?;
// Set-up masked LM model
let device = Device::Cpu;
let mut vs = nn::VarStore::new(device);
let tokenizer: RobertaTokenizer = RobertaTokenizer::from_file(
vocab_path.to_str().unwrap(),
merges_path.to_str().unwrap(),
true,
false,
)?;
let config = BertConfig::from_file(config_path);
let bert_model = RobertaForMaskedLM::new(&vs.root(), &config);
vs.load(weights_path)?;
// Define input
let input = [
"<pad> Looks like one thing is missing",
"It\'s like comparing oranges to apples",
];
let tokenized_input = tokenizer.encode_list(&input, 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let mut tokenized_input = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
})
.collect::<Vec<_>>();
// Masking the token [thing] of sentence 1 and [oranges] of sentence 2
tokenized_input[0][4] = 103;
tokenized_input[1][5] = 103;
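// Note: 103 is BERT's [MASK] id; the standard RoBERTa vocabulary uses a different
// mask id (50264), which may explain the loose predictions printed below.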
let tokenized_input = tokenized_input
.iter()
.map(|input| Tensor::of_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let model_output = no_grad(|| {
bert_model.forward_t(
Some(input_tensor),
None,
None,
None,
None,
&None,
&None,
false,
)
});
// Print masked tokens
let index_1 = model_output
.prediction_scores
.get(0)
.get(4)
.argmax(0, false);
let index_2 = model_output
.prediction_scores
.get(1)
.get(5)
.argmax(0, false);
let word_1 = tokenizer.vocab().id_to_token(&index_1.int64_value(&[]));
let word_2 = tokenizer.vocab().id_to_token(&index_2.int64_value(&[]));
println!("{}", word_1); // Outputs "some" : "Looks like [some] thing is missing"
println!("{}", word_2); // Outputs "apple" : "It\'s like comparing [apple] to apples"
Ok(())
}

View File

@@ -0,0 +1,59 @@
// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
// Copyright 2019 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
extern crate anyhow;
use rust_bert::mbart::{
MBartConfigResources, MBartGenerator, MBartModelResources, MBartVocabResources,
};
use rust_bert::pipelines::generation_utils::{GenerateConfig, LanguageGenerator};
use rust_bert::resources::{RemoteResource, Resource};
fn main() -> anyhow::Result<()> {
let generate_config = GenerateConfig {
max_length: 56,
model_resource: Resource::Remote(RemoteResource::from_pretrained(
MBartModelResources::MBART50_MANY_TO_MANY,
)),
config_resource: Resource::Remote(RemoteResource::from_pretrained(
MBartConfigResources::MBART50_MANY_TO_MANY,
)),
vocab_resource: Resource::Remote(RemoteResource::from_pretrained(
MBartVocabResources::MBART50_MANY_TO_MANY,
)),
merges_resource: Resource::Remote(RemoteResource::from_pretrained(
MBartVocabResources::MBART50_MANY_TO_MANY,
)),
do_sample: false,
num_beams: 1,
..Default::default()
};
let model = MBartGenerator::new(generate_config)?;
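// MBart-50 expects the source text to be prefixed with its language code (here `en_XX`).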
let input_context_1 = "en_XX The quick brown fox jumps over the lazy dog.";
let target_language = model.get_tokenizer().convert_tokens_to_ids(["de_DE"])[0];
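// `target_language` is passed as the new `forced_bos_token_id` argument below, forcing German output.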
let output = model.generate(
Some(&[input_context_1]),
None,
None,
None,
None,
target_language,
None,
);
for sentence in output {
println!("{:?}", sentence);
}
Ok(())
}

View File

@@ -35,13 +35,13 @@ fn main() -> anyhow::Result<()> {
..Default::default()
};
// Set-up masked LM model
// Set-up model
let t5_model = T5Generator::new(generate_config)?;
// Define input
let input = ["translate English to German: This sentence will get translated to German"];
let output = t5_model.generate(Some(input.to_vec()), None, None, None, None, None);
let output = t5_model.generate(Some(input.to_vec()), None, None, None, None, None, None);
println!("{:?}", output);
Ok(())

View File

@@ -1,99 +0,0 @@
// Copyright 2018 Google AI and Google Brain team.
// Copyright 2018 Carnegie Mellon University Authors.
// Copyright 2020-present, the HuggingFace Inc. team.
// Copyright 2020 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
extern crate anyhow;
use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::xlnet::{
XLNetConfig, XLNetConfigResources, XLNetLMHeadModel, XLNetModelResources, XLNetVocabResources,
};
use rust_bert::Config;
use rust_tokenizers::tokenizer::{MultiThreadedTokenizer, TruncationStrategy, XLNetTokenizer};
use rust_tokenizers::vocab::Vocab;
use tch::{nn, no_grad, Device, Kind, Tensor};
fn main() -> anyhow::Result<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
XLNetConfigResources::XLNET_BASE_CASED,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
XLNetVocabResources::XLNET_BASE_CASED,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
XLNetModelResources::XLNET_BASE_CASED,
));
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
let weights_path = weights_resource.get_local_path()?;
// Set-up masked LM model
let device = Device::cuda_if_available();
let mut vs = nn::VarStore::new(device);
let tokenizer: XLNetTokenizer =
XLNetTokenizer::from_file(vocab_path.to_str().unwrap(), false, true)?;
let config = XLNetConfig::from_file(config_path);
let xlnet_model = XLNetLMHeadModel::new(&vs.root(), &config);
vs.load(weights_path)?;
// Define input
let input = ["One two three four"];
let tokenized_input = tokenizer.encode_list(&input, 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let tokenized_input = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
})
.map(|input| Tensor::of_slice(&(input[..input.len() - 2])))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
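// Predict the last of the four input tokens: perm_mask hides position 3 from every
// query position, and target_mapping selects position 3 as the only prediction target.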
let perm_mask = Tensor::zeros(&[1, 4, 4], (Kind::Float, device));
let _ = perm_mask.narrow(2, 3, 1).fill_(1.0);
let target_mapping = Tensor::zeros(&[1, 1, 4], (Kind::Float, device));
let _ = target_mapping.narrow(2, 3, 1).fill_(1.0);
let model_output = no_grad(|| {
xlnet_model
.forward_t(
Some(&input_tensor),
None,
None,
Some(perm_mask.as_ref()),
Some(target_mapping.as_ref()),
None,
None,
false,
)
.unwrap()
});
let index_1 = model_output
.lm_logits
.get(0)
.argmax(1, false)
.int64_value(&[]);
let score_1 = model_output.lm_logits.double_value(&[0, 0, index_1]);
let word_1 = tokenizer.vocab().id_to_token(&index_1);
println!("{}, {}, {}", index_1, score_1, word_1);
Ok(())
}

View File

@@ -11,7 +11,6 @@
//!
//! # Model set-up and pre-trained weights loading
//!
//! A full working example is provided in `examples/albert`, run with `cargo run --example albert`.
//! The example below illustrates a masked language model; the structure is similar for other models.
//! All models expect the following resources:
//! - Configuration file expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers)

View File

@ -1112,7 +1112,7 @@ impl PrivateLanguageGenerator<BartForConditionalGeneration, RobertaVocab, Robert
fn get_model(&self) -> &BartForConditionalGeneration {
&self.model
}
fn get_tokenizer(&self) -> &TokenizerOption {
fn _get_tokenizer(&self) -> &TokenizerOption {
&self.tokenizer
}
fn get_var_store(&self) -> &nn::VarStore {
@ -1201,7 +1201,7 @@ impl PrivateLanguageGenerator<BartForConditionalGeneration, RobertaVocab, Robert
where
S: AsRef<[&'a str]>,
{
let tokens = self.get_tokenizer().encode_list(
let tokens = self._get_tokenizer().encode_list(
prompt_text.as_ref(),
max_len as usize,
&TruncationStrategy::LongestFirst,
@ -1217,7 +1217,7 @@ impl PrivateLanguageGenerator<BartForConditionalGeneration, RobertaVocab, Robert
let pad_token = match pad_token_id {
Some(value) => value,
None => self
.get_tokenizer()
._get_tokenizer()
.convert_tokens_to_ids(&[RobertaVocab::unknown_value()])[0],
};

View File

@ -6,8 +6,7 @@
//!
//! # Model set-up and pre-trained weights loading
//!
//! A full working example is provided in `examples/bart`, run with `cargo run --example bart`.
//! Alternatively, the summarization capabilities are illustrated in `examples/summarization.rs`, run with `cargo run --example summarization`.
//! The summarization capabilities are illustrated in `examples/summarization_bart`, run with `cargo run --example summarization_bart`.
//! All models expect the following resources:
//! - Configuration file expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers)
//! - Model weights are expected to have a structure and parameter names following the [Transformers library](https://github.com/huggingface/transformers). A conversion using the Python utility scripts is required to convert the `.bin` weights to the `.ot` format.
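
Since the standalone `examples/bart` walkthrough is removed, a minimal summarization sketch may help orient readers. This is a hedged illustration built on the public `SummarizationModel` pipeline with its default configuration, not a reproduction of the deleted example:

use rust_bert::pipelines::summarization::SummarizationModel;

fn main() -> anyhow::Result<()> {
    // Assumption: the default SummarizationConfig resolves to a BART-family
    // summarization checkpoint, per the pipeline's documented defaults.
    let model = SummarizationModel::new(Default::default())?;
    let input = ["In findings published Tuesday, astronomers reported water vapour \
                  in the atmosphere of the exoplanet K2-18b."];
    // One summary string is returned per input document.
    let output = model.summarize(&input);
    println!("{:?}", output);
    Ok(())
}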

View File

@ -1240,7 +1240,7 @@ mod test {
#[test]
#[ignore] // compilation is enough, no need to run
fn bart_model_send() {
fn bert_model_send() {
let config_resource =
Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT));
let config_path = config_resource.get_local_path().expect("");

View File

@ -10,7 +10,7 @@
//!
//! # Model set-up and pre-trained weights loading
//!
//! A full working example is provided in `examples/bert`, run with `cargo run --example bert`.
//! A full working example is provided in `examples/masked_language_model_bert`, run with `cargo run --example masked_language_model_bert`.
//! The example below illustrates a masked language model; the structure is similar for other models.
//! All models expect the following resources:
//! - Configuration file expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers)

View File

@ -9,7 +9,6 @@
//!
//! # Model set-up and pre-trained weights loading
//!
//! A full working example is provided in `examples/distilbert_masked_lm.rs`, run with `cargo run --example distilbert_masked_lm`.
//! The example below illustrates a DistilBERT masked language model; the structure is similar for other models.
//! All models expect the following resources:
//! - Configuration file expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers)

View File

@ -14,7 +14,6 @@
//!
//! # Model set-up and pre-trained weights loading
//!
//! A full working example is provided in `examples/electra_masked_lm.rs`, run with `cargo run --example electra_masked_lm`.
//! The example below illustrates a masked language model; the structure is similar for other models (e.g. the discriminator).
//! All models expect the following resources:
//! - Configuration file expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers)

View File

@ -752,7 +752,7 @@ impl PrivateLanguageGenerator<GPT2LMHeadModel, Gpt2Vocab, Gpt2Tokenizer> for GPT
fn get_model(&self) -> &GPT2LMHeadModel {
&self.model
}
fn get_tokenizer(&self) -> &TokenizerOption {
fn _get_tokenizer(&self) -> &TokenizerOption {
&self.tokenizer
}
fn get_var_store(&self) -> &nn::VarStore {

View File

@ -6,7 +6,7 @@
//!
//! # Model set-up and pre-trained weights loading
//!
//! A full working example is provided in `examples/generation.rs`, run with `cargo run --example generation`.
//! A full working example is provided in `examples/generation_gpt2`, run with `cargo run --example generation_gpt2`.
//! All models expect the following resources:
//! - Configuration file expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers)
//! - Model weights are expected to have a structure and parameter names following the [Transformers library](https://github.com/huggingface/transformers). A conversion using the Python utility scripts is required to convert the `.bin` weights to the `.ot` format.

View File

@ -729,7 +729,7 @@ impl PrivateLanguageGenerator<GptNeoForCausalLM, Gpt2Vocab, Gpt2Tokenizer> for G
fn get_model(&self) -> &GptNeoForCausalLM {
&self.model
}
fn get_tokenizer(&self) -> &TokenizerOption {
fn _get_tokenizer(&self) -> &TokenizerOption {
&self.tokenizer
}
fn get_var_store(&self) -> &nn::VarStore {

View File

@ -5,6 +5,7 @@
//!
//! # Model set-up and pre-trained weights loading
//!
//! A full working example is provided in `examples/generation_gpt_neo`, run with `cargo run --example generation_gpt_neo`.
//! All models expect the following resources:
//! - Configuration file expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers)
//! - Model weights are expected to have a structure and parameter names following the [Transformers library](https://github.com/huggingface/transformers). A conversion using the Python utility scripts is required to convert the `.bin` weights to the `.ot` format.

View File

@ -58,6 +58,7 @@
//! GPT-Neo| | | |✅ | | | |
//! BART|✅| | |✅ |✅| | |
//! Marian| | | | | |✅| |
//! MBart|✅| | |✅ | | | |
//! Electra | |✅| | | | |✅|
//! ALBERT |✅|✅|✅| | | |✅|
//! T5 | | | |✅ |✅|✅| |

View File

@ -10,7 +10,7 @@
//!
//! # Model set-up and pre-trained weights loading
//!
//! A full working example (generation) is provided in `examples/question_answering_longformer`, run with `cargo run --example question_answering_longformer`.
//! A full working example (question answering) is provided in `examples/question_answering_longformer`, run with `cargo run --example question_answering_longformer`.
//! All models expect the following resources:
//! - Configuration file expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers)
//! - Model weights are expected to have a structure and parameter names following the [Transformers library](https://github.com/huggingface/transformers). A conversion using the Python utility scripts is required to convert the `.bin` weights to the `.ot` format.

View File

@ -859,7 +859,7 @@ impl PrivateLanguageGenerator<MarianForConditionalGeneration, MarianVocab, Maria
fn get_model(&self) -> &MarianForConditionalGeneration {
&self.model
}
fn get_tokenizer(&self) -> &TokenizerOption {
fn _get_tokenizer(&self) -> &TokenizerOption {
&self.tokenizer
}
fn get_var_store(&self) -> &nn::VarStore {
@ -950,7 +950,7 @@ impl PrivateLanguageGenerator<MarianForConditionalGeneration, MarianVocab, Maria
where
T: AsRef<[&'a str]>,
{
let tokens = self.get_tokenizer().encode_list(
let tokens = self._get_tokenizer().encode_list(
prompt_text.as_ref(),
max_len as usize,
&TruncationStrategy::LongestFirst,
@ -965,7 +965,7 @@ impl PrivateLanguageGenerator<MarianForConditionalGeneration, MarianVocab, Maria
let pad_token = match pad_token_id {
Some(value) => value,
None => self.get_tokenizer().get_unk_id(),
None => self._get_tokenizer().get_unk_id(),
};
let token_ids = token_ids

View File

@ -6,7 +6,7 @@
//!
//! # Model set-up and pre-trained weights loading
//!
//! A full working example is provided in `examples/translation.rs`, run with `cargo run --example translation`.
//! A full working example is provided in `examples/translation_marian`, run with `cargo run --example translation_marian`.
//! All models expect the following resources:
//! - Configuration file expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers)
//! - Model weights are expected to have a structure and parameter names following the [Transformers library](https://github.com/huggingface/transformers). A conversion using the Python utility scripts is required to convert the `.bin` weights to the `.ot` format.

View File

@ -698,7 +698,7 @@ impl LMHeadModel for MBartForConditionalGeneration {
/// # let device = Device::Cpu;
/// # let vs = nn::VarStore::new(device);
/// # let config = MBartConfig::from_file(config_path);
/// # let bart_model: MBartForConditionalGeneration = MBartForConditionalGeneration::new(&vs.root(), &config);
/// # let mbart_model: MBartForConditionalGeneration = MBartForConditionalGeneration::new(&vs.root(), &config);
/// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56);
/// let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device));
/// let target_tensor = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device));
@ -915,7 +915,7 @@ impl PrivateLanguageGenerator<MBartForConditionalGeneration, MBart50Vocab, MBart
fn get_model(&self) -> &MBartForConditionalGeneration {
&self.model
}
fn get_tokenizer(&self) -> &TokenizerOption {
fn _get_tokenizer(&self) -> &TokenizerOption {
&self.tokenizer
}
fn get_var_store(&self) -> &nn::VarStore {
@ -1002,7 +1002,7 @@ impl PrivateLanguageGenerator<MBartForConditionalGeneration, MBart50Vocab, MBart
where
S: AsRef<[&'a str]>,
{
let tokens = self.get_tokenizer().encode_list(
let tokens = self._get_tokenizer().encode_list(
prompt_text.as_ref(),
max_len as usize,
&TruncationStrategy::LongestFirst,
@ -1018,7 +1018,7 @@ impl PrivateLanguageGenerator<MBartForConditionalGeneration, MBart50Vocab, MBart
let pad_token = match pad_token_id {
Some(value) => value,
None => self
.get_tokenizer()
._get_tokenizer()
.convert_tokens_to_ids(&[MBart50Vocab::unknown_value()])[0],
};

View File

@ -1,3 +1,55 @@
//! # MBart (Liu et al.)
//!
//! Implementation of the MBart language model ([Multilingual Denoising Pre-training for Neural Machine Translation](https://arxiv.org/abs/2001.08210) Liu, Gu, Goyal, Li, Edunov, Ghazvininejad, Lewis, Zettlemoyer, 2020).
//! The base model is implemented in the `mbart_model::MBartModel` struct. The model also includes a language model head: `mbart_model::MBartForConditionalGeneration`
//! implementing the common `generation_utils::LMHeadModel` trait shared between the models used for generation (see `pipelines` for more information).
//!
//! # Model set-up and pre-trained weights loading
//!
//! The translation capabilities are illustrated in `examples/translation_mbart`, run with `cargo run --example translation_mbart`.
//! All models expect the following resources:
//! - Configuration file expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers)
//! - Model weights are expected to have a structure and parameter names following the [Transformers library](https://github.com/huggingface/transformers). A conversion using the Python utility scripts is required to convert the `.bin` weights to the `.ot` format.
//! - `MBart50Tokenizer` using a `spiece.model` SentencePiece model
//! Pretrained models are available and can be downloaded using RemoteResources.
//!
//! ```no_run
//! # fn main() -> anyhow::Result<()> {
//! #
//! use tch::{nn, Device};
//! # use std::path::PathBuf;
//! use rust_bert::resources::{LocalResource, Resource};
//! use rust_bert::Config;
//! use rust_tokenizers::tokenizer::MBart50Tokenizer;
//! use rust_bert::mbart::{MBartConfig, MBartModel};
//!
//! let config_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/config.json"),
//! });
//! let vocab_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/vocab.txt"),
//! });
//! let weights_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/model.ot"),
//! });
//! let config_path = config_resource.get_local_path()?;
//! let vocab_path = vocab_resource.get_local_path()?;
//! let weights_path = weights_resource.get_local_path()?;
//!
//! let device = Device::cuda_if_available();
//! let mut vs = nn::VarStore::new(device);
//! let tokenizer: MBart50Tokenizer = MBart50Tokenizer::from_file(
//! vocab_path.to_str().unwrap(),
//! false,
//! )?;
//! let config = MBartConfig::from_file(config_path);
//! let mbart_model = MBartModel::new(&vs.root(), &config);
//! vs.load(weights_path)?;
//!
//! # Ok(())
//! # }
//! ```
mod attention;
mod decoder;
mod embeddings;

View File

@ -9,7 +9,6 @@
//!
//! # Model set-up and pre-trained weights loading
//!
//! A full working example (generation) is provided in `examples/mobilebert_masked_lm`, run with `cargo run --example mobilebert_masked_lm`.
//! All models expect the following resources:
//! - Configuration file expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers)
//! - Model weights are expected to have a structure and parameter names following the [Transformers library](https://github.com/huggingface/transformers). A conversion using the Python utility scripts is required to convert the `.bin` weights to the `.ot` format.

View File

@ -6,7 +6,6 @@
//!
//! # Model set-up and pre-trained weights loading
//!
//! A full working example is provided in `examples/openai_gpt`, run with `cargo run --example openai_gpt`.
//! All models expect the following resources:
//! - Configuration file expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers)
//! - Model weights are expected to have a structure and parameter names following the [Transformers library](https://github.com/huggingface/transformers). A conversion using the Python utility scripts is required to convert the `.bin` weights to the `.ot` format.

View File

@ -576,7 +576,7 @@ impl PrivateLanguageGenerator<OpenAIGPTLMHeadModel, OpenAiGptVocab, OpenAiGptTok
fn get_model(&self) -> &OpenAIGPTLMHeadModel {
&self.model
}
fn get_tokenizer(&self) -> &TokenizerOption {
fn _get_tokenizer(&self) -> &TokenizerOption {
&self.tokenizer
}
fn get_var_store(&self) -> &nn::VarStore {

View File

@ -691,7 +691,7 @@ impl PrivateLanguageGenerator<PegasusForConditionalGeneration, PegasusVocab, Peg
fn get_model(&self) -> &PegasusForConditionalGeneration {
&self.model
}
fn get_tokenizer(&self) -> &TokenizerOption {
fn _get_tokenizer(&self) -> &TokenizerOption {
&self.tokenizer
}
fn get_var_store(&self) -> &nn::VarStore {
@ -775,7 +775,7 @@ impl PrivateLanguageGenerator<PegasusForConditionalGeneration, PegasusVocab, Peg
where
S: AsRef<[&'a str]>,
{
let tokens = self.get_tokenizer().encode_list(
let tokens = self._get_tokenizer().encode_list(
prompt_text.as_ref(),
max_len as usize,
&TruncationStrategy::LongestFirst,
@ -791,7 +791,7 @@ impl PrivateLanguageGenerator<PegasusForConditionalGeneration, PegasusVocab, Peg
let pad_token = match pad_token_id {
Some(value) => value,
None => self
.get_tokenizer()
._get_tokenizer()
.convert_tokens_to_ids(&[PegasusVocab::pad_value()])[0],
};

View File

@ -698,7 +698,7 @@ impl ConversationOption {
pub fn get_tokenizer(&self) -> &TokenizerOption {
match self {
Self::GPT2(model_ref) => model_ref.get_tokenizer(),
Self::GPT2(model_ref) => model_ref._get_tokenizer(),
}
}

View File

@ -36,6 +36,7 @@
//! let min_length = Some(32);
//! let max_length = Some(128);
//! let decoder_start_id = None;
//! let forced_bos_token_id = None;
//!
//! let input_context = "The dog";
//! let second_input_context = "The cat was";
@ -45,6 +46,7 @@
//! min_length,
//! max_length,
//! decoder_start_id,
//! forced_bos_token_id,
//! None,
//! );
//! # Ok(())
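
For context, the reason this extra argument is threaded through the doc examples is the MBart/M2M-style translation case, where the first generated token is forced to a target-language code. A hedged sketch (the language codes and call shape mirror the `tests/mbart.rs` integration test added in this commit; `model` stands for any generator implementing `LanguageGenerator`):

// Look up the id of the target-language token, then force it as the first generated token.
let target_language = model.get_tokenizer().convert_tokens_to_ids(&["de_DE"])[0];
let output = model.generate(
    Some(&["en_XX The quick brown fox jumps over the lazy dog."]),
    None,                   // attention mask
    None,                   // min length
    None,                   // max length
    None,                   // decoder start token id
    Some(target_language),  // forced BOS token id (new in this commit)
    None,                   // prefix_allowed_tokens_fn
);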
@ -86,6 +88,7 @@ use crate::t5::LayerState as T5LayerState;
use crate::xlnet::LayerState as XLNetLayerState;
use self::ordered_float::OrderedFloat;
use crate::pipelines::common::TokenizerOption;
extern crate ordered_float;
@ -272,7 +275,7 @@ pub(crate) mod private_generation_utils {
pub trait PrivateLanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>> {
fn get_model(&self) -> &T;
fn get_tokenizer(&self) -> &TokenizerOption;
fn _get_tokenizer(&self) -> &TokenizerOption;
fn get_var_store(&self) -> &nn::VarStore;
fn get_config(&self) -> &GenerateConfig;
fn get_bos_id(&self) -> &Option<i64>;
@ -322,10 +325,10 @@ pub(crate) mod private_generation_utils {
where
S: AsRef<[&'a str]>,
{
let tokens = self.get_tokenizer().tokenize_list(prompt_text.as_ref());
let tokens = self._get_tokenizer().tokenize_list(prompt_text.as_ref());
let token_ids = tokens
.into_iter()
.map(|prompt_tokens| self.get_tokenizer().convert_tokens_to_ids(&prompt_tokens))
.map(|prompt_tokens| self._get_tokenizer().convert_tokens_to_ids(&prompt_tokens))
.collect::<Vec<Vec<i64>>>();
let num_truncated_tokens = token_ids
@ -365,7 +368,7 @@ pub(crate) mod private_generation_utils {
let pad_token = match pad_token_id {
Some(value) => value,
None => self.get_tokenizer().get_unk_id(),
None => self._get_tokenizer().get_unk_id(),
};
let token_ids = token_ids
@ -1219,7 +1222,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
/// num_return_sequences: 3,
/// ..Default::default()
/// };
/// let mut gpt2_generator = GPT2Generator::new(generate_config)?;
/// let gpt2_generator = GPT2Generator::new(generate_config)?;
/// let input_context = "The dog";
/// let second_input_context = "The cat was";
///
@ -1227,6 +1230,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
/// let min_length = 32;
/// let max_length = 128;
/// let decoder_start_token_id = None;
/// let forced_bos_token_id = None;
///
/// //Example custom function for fine-grained generation control
/// fn force_one_paragraph(_batch_id: i64, previous_token_ids: &Tensor) -> Vec<i64> {
@ -1251,6 +1255,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
/// min_length,
/// max_length,
/// decoder_start_token_id,
/// forced_bos_token_id,
/// Some(&force_one_paragraph)
/// );
/// # Ok(())
@ -1293,7 +1298,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
);
let mut output = Vec::with_capacity(generated.len());
for generated_sequence in generated {
output.push(self.get_tokenizer().decode(generated_sequence, true, true));
output.push(self._get_tokenizer().decode(generated_sequence, true, true));
}
output
}
@ -1337,13 +1342,14 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
/// num_return_sequences: 3,
/// ..Default::default()
/// };
/// let mut gpt2_generator = GPT2Generator::new(generate_config)?;
/// let gpt2_generator = GPT2Generator::new(generate_config)?;
/// let input_context = "The dog";
/// let second_input_context = "The cat was";
/// let attention_mask = None;
/// let min_length = 32;
/// let max_length = 128;
/// let decoder_start_token_id = None;
/// let forced_bos_token_id = None;
///
/// //Example custom function for fine-grained generation control
/// fn force_one_paragraph(_batch_id: i64, previous_token_ids: &Tensor) -> Vec<i64> {
@ -1368,6 +1374,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
/// min_length,
/// max_length,
/// decoder_start_token_id,
/// forced_bos_token_id,
/// Some(&force_one_paragraph),
/// );
/// # Ok(())
@ -1462,13 +1469,14 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
/// num_return_sequences: 3,
/// ..Default::default()
/// };
/// let mut gpt2_generator = GPT2Generator::new(generate_config)?;
/// let gpt2_generator = GPT2Generator::new(generate_config)?;
/// let input_context = "The dog";
/// let second_input_context = "The cat was";
/// let attention_mask = None;
/// let min_length = 32;
/// let max_length = 128;
/// let decoder_start_token_id = None;
/// let forced_bos_token_id = None;
///
/// //Example custom function for fine-grained generation control
/// fn force_one_paragraph(_batch_id: i64, previous_token_ids: &Tensor) -> Vec<i64> {
@ -1493,6 +1501,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
/// min_length,
/// max_length,
/// decoder_start_token_id,
/// forced_bos_token_id,
/// Some(&force_one_paragraph),
/// );
/// # Ok(())
@ -1674,6 +1683,46 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
}
output_ids
}
/// Returns a reference to the text generator's tokenizer
///
/// # Returns
/// * `&TokenizerOption` Reference to the generator's tokenizer.
///
/// # Example
///
/// ```no_run
/// # use std::path::PathBuf;
/// # use tch::Device;
/// # fn main() -> anyhow::Result<()> {
/// use rust_bert::gpt2::GPT2Generator;
/// use rust_bert::pipelines::generation_utils::{GenerateConfig, LanguageGenerator};
/// use tch::Tensor;
/// # let mut home: PathBuf = dirs::home_dir().unwrap();
/// # home.push("rustbert");
/// # home.push("gpt2");
/// # let config_path = &home.as_path().join("config.json");
/// # let vocab_path = &home.as_path().join("vocab.txt");
/// # let merges_path = &home.as_path().join("merges.txt");
/// # let weights_path = &home.as_path().join("model.ot");
/// let device = Device::cuda_if_available();
/// let generate_config = GenerateConfig {
/// max_length: 30,
/// do_sample: true,
/// num_beams: 5,
/// temperature: 1.1,
/// num_return_sequences: 3,
/// ..Default::default()
/// };
/// let gpt2_generator = GPT2Generator::new(generate_config)?;
/// let tokenizer = gpt2_generator.get_tokenizer();
/// tokenizer.tokenize("Hello, world!");
/// # Ok(())
/// # }
/// ```
fn get_tokenizer(&self) -> &TokenizerOption {
self._get_tokenizer()
}
}
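
In short, the trait-level accessor was renamed to `_get_tokenizer` so that this public, documented `get_tokenizer` could be exposed on `LanguageGenerator` (and forwarded by pipeline wrappers such as `ConversationOption` and `TextGenerationOption`) without the two methods colliding.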
#[derive(Debug)]

View File

@ -229,11 +229,11 @@ impl TextGenerationOption {
/// Interface method to access tokenizer
pub fn get_tokenizer(&self) -> &TokenizerOption {
match self {
Self::GPT(model_ref) => model_ref.get_tokenizer(),
Self::GPT2(model_ref) => model_ref.get_tokenizer(),
Self::GPTNeo(model_ref) => model_ref.get_tokenizer(),
Self::XLNet(model_ref) => model_ref.get_tokenizer(),
Self::Reformer(model_ref) => model_ref.get_tokenizer(),
Self::GPT(model_ref) => model_ref._get_tokenizer(),
Self::GPT2(model_ref) => model_ref._get_tokenizer(),
Self::GPTNeo(model_ref) => model_ref._get_tokenizer(),
Self::XLNet(model_ref) => model_ref._get_tokenizer(),
Self::Reformer(model_ref) => model_ref._get_tokenizer(),
}
}

View File

@ -7,7 +7,7 @@
//!
//! # Model set-up and pre-trained weights loading
//!
//! A full working example (generation) is provided in `examples/summarization_prophetnet`, run with `cargo run --example summarization_prophetnet`.
//! A full working example (summarization) is provided in `examples/summarization_prophetnet`, run with `cargo run --example summarization_prophetnet`.
//! All models expect the following resources:
//! - Configuration file expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers)
//! - Model weights are expected to have a structure and parameter names following the [Transformers library](https://github.com/huggingface/transformers). A conversion using the Python utility scripts is required to convert the `.bin` weights to the `.ot` format.

View File

@ -993,7 +993,7 @@ impl
fn get_model(&self) -> &ProphetNetForConditionalGeneration {
&self.model
}
fn get_tokenizer(&self) -> &TokenizerOption {
fn _get_tokenizer(&self) -> &TokenizerOption {
&self.tokenizer
}
fn get_var_store(&self) -> &nn::VarStore {
@ -1069,7 +1069,7 @@ impl
where
S: AsRef<[&'a str]>,
{
let tokens = self.get_tokenizer().encode_list(
let tokens = self._get_tokenizer().encode_list(
prompt_text.as_ref(),
max_len as usize,
&TruncationStrategy::LongestFirst,
@ -1085,7 +1085,7 @@ impl
let pad_token = match pad_token_id {
Some(value) => value,
None => self
.get_tokenizer()
._get_tokenizer()
.convert_tokens_to_ids(&[ProphetNetVocab::unknown_value()])[0],
};

View File

@ -1114,7 +1114,7 @@ impl PrivateLanguageGenerator<ReformerModelWithLMHead, ReformerVocab, ReformerTo
fn get_model(&self) -> &ReformerModelWithLMHead {
&self.model
}
fn get_tokenizer(&self) -> &TokenizerOption {
fn _get_tokenizer(&self) -> &TokenizerOption {
&self.tokenizer
}
fn get_var_store(&self) -> &nn::VarStore {

View File

@ -10,7 +10,6 @@
//!
//! # Model set-up and pre-trained weights loading
//!
//! A full working example is provided in `examples/roberta.rs`, run with `cargo run --example roberta`.
//! The example below illustrates a masked language model; the structure is similar for other models.
//! All models expect the following resources:
//! - Configuration file expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers)

View File

@ -6,7 +6,7 @@
//!
//! # Model set-up and pre-trained weights loading
//!
//! A full working example (translation) is provided in `examples/t5`, run with `cargo run --example t5`.
//! A full working example (summarization) is provided in `examples/summarization_t5`, run with `cargo run --example summarization_t5`.
//! All models expect the following resources:
//! - Configuration file expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers)
//! - Model weights are expected to have a structure and parameter names following the [Transformers library](https://github.com/huggingface/transformers). A conversion using the Python utility scripts is required to convert the `.bin` weights to the `.ot` format.

View File

@ -778,7 +778,7 @@ impl PrivateLanguageGenerator<T5ForConditionalGeneration, T5Vocab, T5Tokenizer>
fn get_model(&self) -> &T5ForConditionalGeneration {
&self.model
}
fn get_tokenizer(&self) -> &TokenizerOption {
fn _get_tokenizer(&self) -> &TokenizerOption {
&self.tokenizer
}
fn get_var_store(&self) -> &nn::VarStore {
@ -850,7 +850,7 @@ impl PrivateLanguageGenerator<T5ForConditionalGeneration, T5Vocab, T5Tokenizer>
where
S: AsRef<[&'a str]>,
{
let tokens = self.get_tokenizer().encode_list(
let tokens = self._get_tokenizer().encode_list(
prompt_text.as_ref(),
max_len as usize,
&TruncationStrategy::LongestFirst,
@ -865,7 +865,7 @@ impl PrivateLanguageGenerator<T5ForConditionalGeneration, T5Vocab, T5Tokenizer>
let pad_token = match pad_token_id {
Some(value) => value,
None => self.get_tokenizer().get_unk_id(),
None => self._get_tokenizer().get_unk_id(),
};
let token_ids = token_ids

View File

@ -1615,7 +1615,7 @@ impl PrivateLanguageGenerator<XLNetLMHeadModel, XLNetVocab, XLNetTokenizer> for
fn get_model(&self) -> &XLNetLMHeadModel {
&self.model
}
fn get_tokenizer(&self) -> &TokenizerOption {
fn _get_tokenizer(&self) -> &TokenizerOption {
&self.tokenizer
}
fn get_var_store(&self) -> &nn::VarStore {

View File

@ -77,7 +77,6 @@ fn bart_lm_model() -> anyhow::Result<()> {
#[test]
fn bart_summarization_greedy() -> anyhow::Result<()> {
// Set-up masked LM model
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
BartConfigResources::DISTILBART_CNN_6_6,
));
@ -139,7 +138,6 @@ about exoplanets like K2-18b."];
#[test]
fn bart_summarization_beam_search() -> anyhow::Result<()> {
// Set-up masked LM model
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
BartConfigResources::DISTILBART_CNN_6_6,
));
@ -202,7 +200,7 @@ about exoplanets like K2-18b."];
#[test]
#[cfg_attr(not(feature = "all-tests"), ignore)]
fn bart_zero_shot_classification() -> anyhow::Result<()> {
// Set-up model model
// Set-up model
let zero_shot_config = ZeroShotClassificationConfig {
device: Device::Cpu,
..Default::default()
@ -235,7 +233,7 @@ fn bart_zero_shot_classification() -> anyhow::Result<()> {
#[test]
#[cfg_attr(not(feature = "all-tests"), ignore)]
fn bart_zero_shot_classification_multilabel() -> anyhow::Result<()> {
// Set-up model model
// Set-up model
let zero_shot_config = ZeroShotClassificationConfig {
device: Device::Cpu,
..Default::default()

View File

@ -28,7 +28,7 @@ fn distilgpt2_lm_model() -> anyhow::Result<()> {
let merges_path = merges_resource.get_local_path()?;
let weights_path = weights_resource.get_local_path()?;
// Set-up masked LM model
// Set-up model
let device = Device::Cpu;
let mut vs = nn::VarStore::new(device);
let tokenizer: Gpt2Tokenizer = Gpt2Tokenizer::from_file(

View File

@ -426,6 +426,7 @@ fn gpt2_prefix_allowed_token_greedy() -> anyhow::Result<()> {
None,
None,
None,
None,
Some(&force_one_paragraph),
);
@ -490,6 +491,7 @@ fn gpt2_prefix_allowed_token_beam_search() -> anyhow::Result<()> {
None,
None,
None,
None,
Some(&force_one_paragraph),
);

View File

@ -29,7 +29,7 @@ fn gpt_neo_lm() -> anyhow::Result<()> {
let merges_path = merges_resource.get_local_path()?;
let weights_path = weights_resource.get_local_path()?;
// Set-up masked LM model
// Set-up model
let device = Device::Cpu;
let mut vs = nn::VarStore::new(device);
let tokenizer: Gpt2Tokenizer = Gpt2Tokenizer::from_file(
@ -122,7 +122,7 @@ fn test_generation_gpt_neo() -> anyhow::Result<()> {
GptNeoModelResources::GPT_NEO_125M,
));
// Set-up translation model
// Set-up model
let generation_config = TextGenerationConfig {
model_type: ModelType::GPTNeo,
model_resource,

tests/mbart.rs Normal file (110 lines added)
View File

@ -0,0 +1,110 @@
use rust_bert::mbart::{
MBartConfig, MBartConfigResources, MBartGenerator, MBartModel, MBartModelResources,
MBartVocabResources,
};
use rust_bert::pipelines::generation_utils::{GenerateConfig, LanguageGenerator};
use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::Config;
use rust_tokenizers::tokenizer::{MBart50Tokenizer, Tokenizer, TruncationStrategy};
use tch::{nn, Device, Tensor};
#[test]
fn mbart_lm_model() -> anyhow::Result<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
MBartConfigResources::MBART50_MANY_TO_MANY,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
MBartVocabResources::MBART50_MANY_TO_MANY,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
MBartModelResources::MBART50_MANY_TO_MANY,
));
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
let weights_path = weights_resource.get_local_path()?;
// Set-up model
let device = Device::Cpu;
let mut vs = nn::VarStore::new(device);
let tokenizer = MBart50Tokenizer::from_file(vocab_path.to_str().unwrap(), false)?;
let config = MBartConfig::from_file(config_path);
let mbart_model = MBartModel::new(&vs.root() / "model", &config);
vs.load(weights_path)?;
// Define input
let input = ["One two three four"];
let tokenized_input = tokenizer.encode_list(&input, 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let tokenized_input = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
})
.map(|input| Tensor::of_slice(&input))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let model_output =
mbart_model.forward_t(Some(&input_tensor), None, None, None, None, None, false);
assert_eq!(model_output.decoder_output.size(), vec!(1, 5, 1024));
assert_eq!(
model_output.encoder_hidden_state.unwrap().size(),
vec!(1, 5, 1024)
);
assert!((model_output.decoder_output.double_value(&[0, 0, 0]) - -0.8936).abs() < 1e-4);
Ok(())
}
#[test]
fn mbart_translation() -> anyhow::Result<()> {
// Resources paths
let generate_config = GenerateConfig {
max_length: 56,
model_resource: Resource::Remote(RemoteResource::from_pretrained(
MBartModelResources::MBART50_MANY_TO_MANY,
)),
config_resource: Resource::Remote(RemoteResource::from_pretrained(
MBartConfigResources::MBART50_MANY_TO_MANY,
)),
vocab_resource: Resource::Remote(RemoteResource::from_pretrained(
MBartVocabResources::MBART50_MANY_TO_MANY,
)),
merges_resource: Resource::Remote(RemoteResource::from_pretrained(
MBartVocabResources::MBART50_MANY_TO_MANY,
)),
do_sample: false,
num_beams: 3,
..Default::default()
};
let model = MBartGenerator::new(generate_config)?;
let input_context = "en_XX The quick brown fox jumps over the lazy dog.";
let target_language = model.get_tokenizer().convert_tokens_to_ids(&["de_DE"])[0];
let output = model.generate(
Some(&[input_context]),
None,
None,
None,
None,
target_language,
None,
);
assert_eq!(output.len(), 1);
assert_eq!(
output[0],
"de_DE Der schnelle braune Fuchs springt über den faulen Hund."
);
Ok(())
}
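
A note on the design exercised by `mbart_translation`: the source language is signalled by prepending its code (`en_XX`) to the input text, while the target language is selected by converting its code (`de_DE`) to a token id and passing it as the forced BOS token, so decoding starts in the requested language. Disabling sampling and using `num_beams: 3` keeps the asserted German output deterministic.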

View File

@ -117,7 +117,7 @@ fn openai_gpt_generation_greedy() -> anyhow::Result<()> {
OpenAiGptModelResources::GPT,
));
// Set-up masked LM model
// Set-up model
let generate_config = TextGenerationConfig {
model_type: ModelType::OpenAiGpt,
model_resource,
@ -159,7 +159,7 @@ fn openai_gpt_generation_beam_search() -> anyhow::Result<()> {
OpenAiGptModelResources::GPT,
));
// Set-up masked LM model
// Set-up model
let generate_config = TextGenerationConfig {
model_type: ModelType::OpenAiGpt,
model_resource,
@ -211,7 +211,7 @@ fn openai_gpt_generation_beam_search_multiple_prompts_without_padding() -> anyho
OpenAiGptModelResources::GPT,
));
// Set-up masked LM model
// Set-up model
let generate_config = TextGenerationConfig {
model_type: ModelType::OpenAiGpt,
model_resource,
@ -279,7 +279,7 @@ fn openai_gpt_generation_beam_search_multiple_prompts_with_padding() -> anyhow::
OpenAiGptModelResources::GPT,
));
// Set-up masked LM model
// Set-up model
let generate_config = TextGenerationConfig {
model_type: ModelType::OpenAiGpt,
model_resource,

View File

@ -7,7 +7,7 @@ use tch::Device;
#[test]
fn pegasus_summarization_greedy() -> anyhow::Result<()> {
// Set-up masked LM model
// Set-up model
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
PegasusConfigResources::CNN_DAILYMAIL,
));

View File

@ -9,7 +9,7 @@ use tch::Device;
#[test]
fn prophetnet_summarization_greedy() -> anyhow::Result<()> {
// Set-up masked LM model
// Set-up model
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
ProphetNetConfigResources::PROPHETNET_LARGE_CNN_DM,
));