mirror of
https://github.com/guillaume-be/rust-bert.git
synced 2024-07-14 16:00:23 +03:00
Updated documentation, cleaned examples, added integration tests
This commit is contained in:
parent
698e7143e8
commit
5907b7d954
@ -3,7 +3,9 @@ All notable changes to this project will be documented in this file. The format
|
||||
|
||||
## [Unreleased]
|
||||
## Added
|
||||
- (BREAKING) Support for `prefix_allowed_tokens_fn` arguments for generation, allowing users to control the generation via custom functions
|
||||
- (BREAKING) Support for `prefix_allowed_tokens_fn` argument for generation, allowing users to control the generation via custom functions
|
||||
- (BREAKING) Support for `forced_bos_token_id` argument for generation, allowing users to force a given BOS token for generation (useful for MBart/M2M-class models)
|
||||
- Addition of the MBart Language model and support for text generation / direct translation between 50 language
|
||||
|
||||
## [0.15.1] - 2021-06-01
|
||||
### Fixed
|
||||
|
@ -46,8 +46,9 @@ RoBERTa|✅|✅|✅| | | |✅|
|
||||
GPT| | | |✅ | | | |
|
||||
GPT2| | | |✅ | | | |
|
||||
GPT-Neo| | | |✅ | | | |
|
||||
BART|✅| | |✅ |✅| | |
|
||||
BART|✅| | |✅ |✅| | |
|
||||
Marian| | | | | |✅| |
|
||||
MBart|✅| | |✅ | | | |
|
||||
Electra | |✅| | | | |✅|
|
||||
ALBERT |✅|✅|✅| | | |✅|
|
||||
T5 | | | |✅ |✅|✅| |
|
||||
|
@ -1,97 +0,0 @@
|
||||
// Copyright 2018 Google AI and Google Brain team.
|
||||
// Copyright 2020-present, the HuggingFace Inc. team.
|
||||
// Copyright 2020 Guillaume Becquin
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
extern crate anyhow;
|
||||
|
||||
use rust_bert::albert::{
|
||||
AlbertConfig, AlbertConfigResources, AlbertForMaskedLM, AlbertModelResources,
|
||||
AlbertVocabResources,
|
||||
};
|
||||
use rust_bert::resources::{RemoteResource, Resource};
|
||||
use rust_bert::Config;
|
||||
use rust_tokenizers::tokenizer::{AlbertTokenizer, Tokenizer, TruncationStrategy};
|
||||
use rust_tokenizers::vocab::Vocab;
|
||||
use tch::{nn, no_grad, Device, Tensor};
|
||||
|
||||
fn main() -> anyhow::Result<()> {
|
||||
// Resources paths
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
AlbertConfigResources::ALBERT_BASE_V2,
|
||||
));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
AlbertVocabResources::ALBERT_BASE_V2,
|
||||
));
|
||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
AlbertModelResources::ALBERT_BASE_V2,
|
||||
));
|
||||
let config_path = config_resource.get_local_path()?;
|
||||
let vocab_path = vocab_resource.get_local_path()?;
|
||||
let weights_path = weights_resource.get_local_path()?;
|
||||
|
||||
// Set-up masked LM model
|
||||
let device = Device::Cpu;
|
||||
let mut vs = nn::VarStore::new(device);
|
||||
let tokenizer: AlbertTokenizer =
|
||||
AlbertTokenizer::from_file(vocab_path.to_str().unwrap(), true, false)?;
|
||||
let config = AlbertConfig::from_file(config_path);
|
||||
let albert_model = AlbertForMaskedLM::new(&vs.root(), &config);
|
||||
vs.load(weights_path)?;
|
||||
|
||||
// Define input
|
||||
let input = [
|
||||
"Looks like one [MASK] is missing",
|
||||
"It was a very nice and [MASK] day",
|
||||
];
|
||||
let tokenized_input = tokenizer.encode_list(&input, 128, &TruncationStrategy::LongestFirst, 0);
|
||||
let max_len = tokenized_input
|
||||
.iter()
|
||||
.map(|input| input.token_ids.len())
|
||||
.max()
|
||||
.unwrap();
|
||||
let tokenized_input = tokenized_input
|
||||
.iter()
|
||||
.map(|input| input.token_ids.clone())
|
||||
.map(|mut input| {
|
||||
input.extend(vec![0; max_len - input.len()]);
|
||||
input
|
||||
})
|
||||
.map(|input| Tensor::of_slice(&(input)))
|
||||
.collect::<Vec<_>>();
|
||||
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
|
||||
|
||||
// Forward pass
|
||||
let model_output =
|
||||
no_grad(|| albert_model.forward_t(Some(input_tensor), None, None, None, None, false));
|
||||
println!(
|
||||
"{:?}",
|
||||
model_output.prediction_scores.double_value(&[0, 0, 0])
|
||||
);
|
||||
// Print masked tokens
|
||||
let index_1 = model_output
|
||||
.prediction_scores
|
||||
.get(0)
|
||||
.get(4)
|
||||
.argmax(0, false);
|
||||
let index_2 = model_output
|
||||
.prediction_scores
|
||||
.get(1)
|
||||
.get(7)
|
||||
.argmax(0, false);
|
||||
let word_1 = tokenizer.vocab().id_to_token(&index_1.int64_value(&[]));
|
||||
let word_2 = tokenizer.vocab().id_to_token(&index_2.int64_value(&[]));
|
||||
|
||||
println!("{} - {}", &index_1.int64_value(&[]), word_1); // Outputs "_them" : "Looks like one [them] is missing"
|
||||
println!("{} - {}", &index_2.int64_value(&[]), word_2); // Outputs "_enjoyable" : "It was a very nice and [enjoyable] day"
|
||||
|
||||
Ok(())
|
||||
}
|
@ -1,81 +0,0 @@
|
||||
// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
|
||||
// Copyright 2019 Guillaume Becquin
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
extern crate anyhow;
|
||||
|
||||
use rust_bert::bart::{
|
||||
BartConfig, BartConfigResources, BartMergesResources, BartModel, BartModelResources,
|
||||
BartVocabResources,
|
||||
};
|
||||
use rust_bert::resources::{RemoteResource, Resource};
|
||||
use rust_bert::Config;
|
||||
use rust_tokenizers::tokenizer::{RobertaTokenizer, Tokenizer, TruncationStrategy};
|
||||
use tch::{nn, no_grad, Device, Tensor};
|
||||
|
||||
fn main() -> anyhow::Result<()> {
|
||||
// Resources paths
|
||||
let config_resource =
|
||||
Resource::Remote(RemoteResource::from_pretrained(BartConfigResources::BART));
|
||||
let vocab_resource =
|
||||
Resource::Remote(RemoteResource::from_pretrained(BartVocabResources::BART));
|
||||
let merges_resource =
|
||||
Resource::Remote(RemoteResource::from_pretrained(BartMergesResources::BART));
|
||||
let weights_resource =
|
||||
Resource::Remote(RemoteResource::from_pretrained(BartModelResources::BART));
|
||||
let config_path = config_resource.get_local_path()?;
|
||||
let vocab_path = vocab_resource.get_local_path()?;
|
||||
let merges_path = merges_resource.get_local_path()?;
|
||||
let weights_path = weights_resource.get_local_path()?;
|
||||
|
||||
// Set-up masked LM model
|
||||
let device = Device::cuda_if_available();
|
||||
let mut vs = nn::VarStore::new(device);
|
||||
let tokenizer = RobertaTokenizer::from_file(
|
||||
vocab_path.to_str().unwrap(),
|
||||
merges_path.to_str().unwrap(),
|
||||
false,
|
||||
false,
|
||||
)?;
|
||||
let config = BartConfig::from_file(config_path);
|
||||
let bart_model = BartModel::new(&vs.root(), &config);
|
||||
vs.load(weights_path)?;
|
||||
|
||||
// Define input
|
||||
let input = ["One two three four"];
|
||||
|
||||
let tokenized_input = tokenizer.encode_list(input, 1024, &TruncationStrategy::LongestFirst, 0);
|
||||
let max_len = tokenized_input
|
||||
.iter()
|
||||
.map(|input| input.token_ids.len())
|
||||
.max()
|
||||
.unwrap();
|
||||
let tokenized_input = tokenized_input
|
||||
.iter()
|
||||
.map(|input| input.token_ids.clone())
|
||||
.map(|mut input| {
|
||||
input.extend(vec![0; max_len - input.len()]);
|
||||
input
|
||||
})
|
||||
.map(|input| Tensor::of_slice(&(input)))
|
||||
.collect::<Vec<_>>();
|
||||
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
|
||||
|
||||
// Forward pass
|
||||
let model_output =
|
||||
no_grad(|| bart_model.forward_t(Some(&input_tensor), None, None, None, None, None, false));
|
||||
|
||||
// Print masked tokens
|
||||
println!("{:?}", model_output.encoder_hidden_state);
|
||||
println!("{:?}", model_output.decoder_output);
|
||||
println!("{:?}", model_output.decoder_output.double_value(&[0, 0, 0]));
|
||||
Ok(())
|
||||
}
|
@ -1,102 +0,0 @@
|
||||
// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
|
||||
// Copyright 2019 Guillaume Becquin
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
extern crate anyhow;
|
||||
|
||||
use rust_bert::distilbert::{
|
||||
DistilBertConfig, DistilBertConfigResources, DistilBertModelMaskedLM, DistilBertModelResources,
|
||||
DistilBertVocabResources,
|
||||
};
|
||||
use rust_bert::resources::{RemoteResource, Resource};
|
||||
use rust_bert::Config;
|
||||
use rust_tokenizers::tokenizer::{BertTokenizer, MultiThreadedTokenizer, TruncationStrategy};
|
||||
use rust_tokenizers::vocab::Vocab;
|
||||
use tch::{nn, no_grad, Device, Tensor};
|
||||
|
||||
fn main() -> anyhow::Result<()> {
|
||||
// Resources paths
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
DistilBertConfigResources::DISTIL_BERT,
|
||||
));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
DistilBertVocabResources::DISTIL_BERT,
|
||||
));
|
||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
DistilBertModelResources::DISTIL_BERT,
|
||||
));
|
||||
let config_path = config_resource.get_local_path()?;
|
||||
let vocab_path = vocab_resource.get_local_path()?;
|
||||
let weights_path = weights_resource.get_local_path()?;
|
||||
|
||||
// Set-up masked LM model
|
||||
let device = Device::Cpu;
|
||||
let mut vs = nn::VarStore::new(device);
|
||||
let tokenizer: BertTokenizer =
|
||||
BertTokenizer::from_file(vocab_path.to_str().unwrap(), true, true)?;
|
||||
let config = DistilBertConfig::from_file(config_path);
|
||||
let distil_bert_model = DistilBertModelMaskedLM::new(&vs.root(), &config);
|
||||
vs.load(weights_path)?;
|
||||
|
||||
// Define input
|
||||
let input = [
|
||||
"Looks like one thing is missing",
|
||||
"It\'s like comparing oranges to apples",
|
||||
];
|
||||
let tokenized_input = tokenizer.encode_list(input, 128, &TruncationStrategy::LongestFirst, 0);
|
||||
let max_len = tokenized_input
|
||||
.iter()
|
||||
.map(|input| input.token_ids.len())
|
||||
.max()
|
||||
.unwrap();
|
||||
let mut tokenized_input = tokenized_input
|
||||
.iter()
|
||||
.map(|input| input.token_ids.clone())
|
||||
.map(|mut input| {
|
||||
input.extend(vec![0; max_len - input.len()]);
|
||||
input
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
// Masking the token [thing] of sentence 1 and [oranges] of sentence 2
|
||||
tokenized_input[0][4] = 103;
|
||||
tokenized_input[1][6] = 103;
|
||||
let tokenized_input = tokenized_input
|
||||
.iter()
|
||||
.map(|input| Tensor::of_slice(&(input)))
|
||||
.collect::<Vec<_>>();
|
||||
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
|
||||
|
||||
// Forward pass
|
||||
let model_output = no_grad(|| {
|
||||
distil_bert_model
|
||||
.forward_t(Some(input_tensor), None, None, false)
|
||||
.unwrap()
|
||||
});
|
||||
|
||||
// Print masked tokens
|
||||
let index_1 = model_output
|
||||
.prediction_scores
|
||||
.get(0)
|
||||
.get(4)
|
||||
.argmax(0, false);
|
||||
let index_2 = model_output
|
||||
.prediction_scores
|
||||
.get(1)
|
||||
.get(6)
|
||||
.argmax(0, false);
|
||||
let word_1 = tokenizer.vocab().id_to_token(&index_1.int64_value(&[]));
|
||||
let word_2 = tokenizer.vocab().id_to_token(&index_2.int64_value(&[]));
|
||||
|
||||
println!("{}", word_1); // Outputs "person" : "Looks like one [person] is missing"
|
||||
println!("{}", word_2); // Outputs "pear" : "It\'s like comparing [pear] to apples"
|
||||
|
||||
Ok(())
|
||||
}
|
@ -1,406 +0,0 @@
|
||||
extern crate anyhow;
|
||||
|
||||
use rust_bert::albert::{AlbertConfigResources, AlbertModelResources, AlbertVocabResources};
|
||||
use rust_bert::bart::{
|
||||
BartConfigResources, BartMergesResources, BartModelResources, BartVocabResources,
|
||||
};
|
||||
use rust_bert::bert::{BertConfigResources, BertModelResources, BertVocabResources};
|
||||
use rust_bert::distilbert::{
|
||||
DistilBertConfigResources, DistilBertModelResources, DistilBertVocabResources,
|
||||
};
|
||||
use rust_bert::electra::{ElectraConfigResources, ElectraModelResources, ElectraVocabResources};
|
||||
use rust_bert::gpt2::{
|
||||
Gpt2ConfigResources, Gpt2MergesResources, Gpt2ModelResources, Gpt2VocabResources,
|
||||
};
|
||||
use rust_bert::openai_gpt::{
|
||||
OpenAiGptConfigResources, OpenAiGptMergesResources, OpenAiGptModelResources,
|
||||
OpenAiGptVocabResources,
|
||||
};
|
||||
use rust_bert::resources::{RemoteResource, Resource};
|
||||
use rust_bert::roberta::{
|
||||
RobertaConfigResources, RobertaMergesResources, RobertaModelResources, RobertaVocabResources,
|
||||
};
|
||||
use rust_bert::t5::{T5ConfigResources, T5ModelResources, T5VocabResources};
|
||||
use rust_bert::xlnet::{XLNetConfigResources, XLNetModelResources, XLNetVocabResources};
|
||||
|
||||
/// This example downloads and caches all dependencies used in model tests. This allows for safe
|
||||
/// multi threaded testing (two test using the same resource would otherwise download the file to
|
||||
/// the same location).
|
||||
|
||||
fn download_distil_gpt2() -> anyhow::Result<()> {
|
||||
// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
Gpt2ConfigResources::DISTIL_GPT2,
|
||||
));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
Gpt2VocabResources::DISTIL_GPT2,
|
||||
));
|
||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
Gpt2MergesResources::DISTIL_GPT2,
|
||||
));
|
||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
Gpt2ModelResources::DISTIL_GPT2,
|
||||
));
|
||||
let _ = config_resource.get_local_path()?;
|
||||
let _ = vocab_resource.get_local_path()?;
|
||||
let _ = merges_resource.get_local_path()?;
|
||||
let _ = weights_resource.get_local_path()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn download_distilbert_sst2() -> anyhow::Result<()> {
|
||||
// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
|
||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
DistilBertModelResources::DISTIL_BERT_SST2,
|
||||
));
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
DistilBertConfigResources::DISTIL_BERT_SST2,
|
||||
));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
DistilBertVocabResources::DISTIL_BERT_SST2,
|
||||
));
|
||||
let _ = config_resource.get_local_path()?;
|
||||
let _ = vocab_resource.get_local_path()?;
|
||||
let _ = weights_resource.get_local_path()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn download_distilbert_qa() -> anyhow::Result<()> {
|
||||
// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
|
||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
DistilBertModelResources::DISTIL_BERT_SQUAD,
|
||||
));
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
DistilBertConfigResources::DISTIL_BERT_SQUAD,
|
||||
));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
DistilBertVocabResources::DISTIL_BERT_SQUAD,
|
||||
));
|
||||
let _ = config_resource.get_local_path()?;
|
||||
let _ = vocab_resource.get_local_path()?;
|
||||
let _ = weights_resource.get_local_path()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn download_distilbert() -> anyhow::Result<()> {
|
||||
// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
|
||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
DistilBertModelResources::DISTIL_BERT,
|
||||
));
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
DistilBertConfigResources::DISTIL_BERT,
|
||||
));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
DistilBertVocabResources::DISTIL_BERT,
|
||||
));
|
||||
let _ = config_resource.get_local_path()?;
|
||||
let _ = vocab_resource.get_local_path()?;
|
||||
let _ = weights_resource.get_local_path()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn download_gpt2() -> anyhow::Result<()> {
|
||||
// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2. Modified with conversion to C-array format.
|
||||
let config_resource =
|
||||
Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
|
||||
let vocab_resource =
|
||||
Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
|
||||
let merges_resource =
|
||||
Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
|
||||
let weights_resource =
|
||||
Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
|
||||
let _ = config_resource.get_local_path()?;
|
||||
let _ = vocab_resource.get_local_path()?;
|
||||
let _ = merges_resource.get_local_path()?;
|
||||
let _ = weights_resource.get_local_path()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn download_gpt() -> anyhow::Result<()> {
|
||||
// Shared under MIT license by the OpenAI team at https://github.com/openai/finetune-transformer-lm. Modified with conversion to C-array format.
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
OpenAiGptConfigResources::GPT,
|
||||
));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
OpenAiGptVocabResources::GPT,
|
||||
));
|
||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
OpenAiGptMergesResources::GPT,
|
||||
));
|
||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
OpenAiGptModelResources::GPT,
|
||||
));
|
||||
let _ = config_resource.get_local_path()?;
|
||||
let _ = vocab_resource.get_local_path()?;
|
||||
let _ = merges_resource.get_local_path()?;
|
||||
let _ = weights_resource.get_local_path()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn download_roberta() -> anyhow::Result<()> {
|
||||
// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
RobertaConfigResources::ROBERTA,
|
||||
));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
RobertaVocabResources::ROBERTA,
|
||||
));
|
||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
RobertaMergesResources::ROBERTA,
|
||||
));
|
||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
RobertaModelResources::ROBERTA,
|
||||
));
|
||||
let _ = config_resource.get_local_path()?;
|
||||
let _ = vocab_resource.get_local_path()?;
|
||||
let _ = merges_resource.get_local_path()?;
|
||||
let _ = weights_resource.get_local_path()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn download_bert() -> anyhow::Result<()> {
|
||||
// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/bert. Modified with conversion to C-array format.
|
||||
let config_resource =
|
||||
Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT));
|
||||
let vocab_resource =
|
||||
Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
|
||||
let weights_resource =
|
||||
Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT));
|
||||
let _ = config_resource.get_local_path()?;
|
||||
let _ = vocab_resource.get_local_path()?;
|
||||
let _ = weights_resource.get_local_path()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn download_bert_ner() -> anyhow::Result<()> {
|
||||
// Shared under MIT license by the MDZ Digital Library team at the Bavarian State Library at https://github.com/dbmdz/berts. Modified with conversion to C-array format.
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
BertConfigResources::BERT_NER,
|
||||
));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
BertVocabResources::BERT_NER,
|
||||
));
|
||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
BertModelResources::BERT_NER,
|
||||
));
|
||||
let _ = config_resource.get_local_path()?;
|
||||
let _ = vocab_resource.get_local_path()?;
|
||||
let _ = weights_resource.get_local_path()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn download_bart() -> anyhow::Result<()> {
|
||||
// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
|
||||
let config_resource =
|
||||
Resource::Remote(RemoteResource::from_pretrained(BartConfigResources::BART));
|
||||
let vocab_resource =
|
||||
Resource::Remote(RemoteResource::from_pretrained(BartVocabResources::BART));
|
||||
let merges_resource =
|
||||
Resource::Remote(RemoteResource::from_pretrained(BartMergesResources::BART));
|
||||
let weights_resource =
|
||||
Resource::Remote(RemoteResource::from_pretrained(BartModelResources::BART));
|
||||
let _ = config_resource.get_local_path()?;
|
||||
let _ = vocab_resource.get_local_path()?;
|
||||
let _ = merges_resource.get_local_path()?;
|
||||
let _ = weights_resource.get_local_path()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn download_bart_cnn() -> anyhow::Result<()> {
|
||||
// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
BartConfigResources::BART_CNN,
|
||||
));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
BartVocabResources::BART_CNN,
|
||||
));
|
||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
BartMergesResources::BART_CNN,
|
||||
));
|
||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
BartModelResources::BART_CNN,
|
||||
));
|
||||
let _ = config_resource.get_local_path()?;
|
||||
let _ = vocab_resource.get_local_path()?;
|
||||
let _ = merges_resource.get_local_path()?;
|
||||
let _ = weights_resource.get_local_path()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn download_electra_generator() -> anyhow::Result<()> {
|
||||
// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/electra. Modified with conversion to C-array format.
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
ElectraConfigResources::BASE_GENERATOR,
|
||||
));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
ElectraVocabResources::BASE_GENERATOR,
|
||||
));
|
||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
ElectraModelResources::BASE_GENERATOR,
|
||||
));
|
||||
let _ = config_resource.get_local_path()?;
|
||||
let _ = vocab_resource.get_local_path()?;
|
||||
let _ = weights_resource.get_local_path()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn download_electra_discriminator() -> anyhow::Result<()> {
|
||||
// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/electra. Modified with conversion to C-array format.
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
ElectraConfigResources::BASE_DISCRIMINATOR,
|
||||
));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
ElectraVocabResources::BASE_DISCRIMINATOR,
|
||||
));
|
||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
ElectraModelResources::BASE_DISCRIMINATOR,
|
||||
));
|
||||
let _ = config_resource.get_local_path()?;
|
||||
let _ = vocab_resource.get_local_path()?;
|
||||
let _ = weights_resource.get_local_path()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn download_albert_base_v2() -> anyhow::Result<()> {
|
||||
// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/ALBERT. Modified with conversion to C-array format.
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
AlbertConfigResources::ALBERT_BASE_V2,
|
||||
));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
AlbertVocabResources::ALBERT_BASE_V2,
|
||||
));
|
||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
AlbertModelResources::ALBERT_BASE_V2,
|
||||
));
|
||||
let _ = config_resource.get_local_path()?;
|
||||
let _ = vocab_resource.get_local_path()?;
|
||||
let _ = weights_resource.get_local_path()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn _download_dialogpt() -> anyhow::Result<()> {
|
||||
// Shared under MIT license by the Microsoft team at https://huggingface.co/microsoft/DialoGPT-medium. Modified with conversion to C-array format.
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
Gpt2ConfigResources::DIALOGPT_MEDIUM,
|
||||
));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
Gpt2VocabResources::DIALOGPT_MEDIUM,
|
||||
));
|
||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
Gpt2MergesResources::DIALOGPT_MEDIUM,
|
||||
));
|
||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
Gpt2ModelResources::DIALOGPT_MEDIUM,
|
||||
));
|
||||
let _ = config_resource.get_local_path()?;
|
||||
let _ = vocab_resource.get_local_path()?;
|
||||
let _ = merges_resource.get_local_path()?;
|
||||
let _ = weights_resource.get_local_path()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn download_t5_small() -> anyhow::Result<()> {
|
||||
// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/text-to-text-transfer-transformer.
|
||||
let config_resource =
|
||||
Resource::Remote(RemoteResource::from_pretrained(T5ConfigResources::T5_SMALL));
|
||||
let vocab_resource =
|
||||
Resource::Remote(RemoteResource::from_pretrained(T5VocabResources::T5_SMALL));
|
||||
let weights_resource =
|
||||
Resource::Remote(RemoteResource::from_pretrained(T5ModelResources::T5_SMALL));
|
||||
let _ = config_resource.get_local_path()?;
|
||||
let _ = vocab_resource.get_local_path()?;
|
||||
let _ = weights_resource.get_local_path()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn download_roberta_qa() -> anyhow::Result<()> {
|
||||
// Shared under Apache 2.0 license by [deepset](https://deepset.ai) at https://huggingface.co/deepset/roberta-base-squad2.
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
RobertaConfigResources::ROBERTA_QA,
|
||||
));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
RobertaVocabResources::ROBERTA_QA,
|
||||
));
|
||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
RobertaModelResources::ROBERTA_QA,
|
||||
));
|
||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
RobertaMergesResources::ROBERTA_QA,
|
||||
));
|
||||
let _ = config_resource.get_local_path()?;
|
||||
let _ = vocab_resource.get_local_path()?;
|
||||
let _ = merges_resource.get_local_path()?;
|
||||
let _ = weights_resource.get_local_path()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn download_bert_qa() -> anyhow::Result<()> {
|
||||
// Shared under Apache 2.0 license by [deepset](https://deepset.ai) at https://huggingface.co/deepset/roberta-base-squad2.
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
BertConfigResources::BERT_QA,
|
||||
));
|
||||
let vocab_resource =
|
||||
Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT_QA));
|
||||
let weights_resource =
|
||||
Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT_QA));
|
||||
let _ = config_resource.get_local_path()?;
|
||||
let _ = vocab_resource.get_local_path()?;
|
||||
let _ = weights_resource.get_local_path()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn download_xlm_roberta_ner_german() -> anyhow::Result<()> {
|
||||
// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models.
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
RobertaConfigResources::XLM_ROBERTA_NER_DE,
|
||||
));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
RobertaVocabResources::XLM_ROBERTA_NER_DE,
|
||||
));
|
||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
RobertaModelResources::XLM_ROBERTA_NER_DE,
|
||||
));
|
||||
let _ = config_resource.get_local_path()?;
|
||||
let _ = vocab_resource.get_local_path()?;
|
||||
let _ = weights_resource.get_local_path()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn download_xlnet_base_cased() -> anyhow::Result<()> {
|
||||
// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models.
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
XLNetConfigResources::XLNET_BASE_CASED,
|
||||
));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
XLNetVocabResources::XLNET_BASE_CASED,
|
||||
));
|
||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
XLNetModelResources::XLNET_BASE_CASED,
|
||||
));
|
||||
let _ = config_resource.get_local_path()?;
|
||||
let _ = vocab_resource.get_local_path()?;
|
||||
let _ = weights_resource.get_local_path()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn main() {
|
||||
let _ = download_distil_gpt2();
|
||||
let _ = download_distilbert_sst2();
|
||||
let _ = download_distilbert_qa();
|
||||
let _ = download_distilbert();
|
||||
let _ = download_gpt2();
|
||||
let _ = download_gpt();
|
||||
let _ = download_roberta();
|
||||
let _ = download_bert();
|
||||
let _ = download_bert_ner();
|
||||
let _ = download_bart();
|
||||
let _ = download_bart_cnn();
|
||||
let _ = download_electra_generator();
|
||||
let _ = download_electra_discriminator();
|
||||
let _ = download_albert_base_v2();
|
||||
let _ = download_t5_small();
|
||||
let _ = download_roberta_qa();
|
||||
let _ = download_bert_qa();
|
||||
let _ = download_xlm_roberta_ner_german();
|
||||
let _ = download_xlnet_base_cased();
|
||||
}
|
@ -1,96 +0,0 @@
|
||||
// Copyright 2020 The Google Research Authors.
|
||||
// Copyright 2019-present, the HuggingFace Inc. team
|
||||
// Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
||||
// Copyright 2019 Guillaume Becquin
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use rust_bert::electra::{
|
||||
ElectraConfig, ElectraConfigResources, ElectraDiscriminator, ElectraModelResources,
|
||||
ElectraVocabResources,
|
||||
};
|
||||
use rust_bert::resources::{RemoteResource, Resource};
|
||||
use rust_bert::Config;
|
||||
use rust_tokenizers::tokenizer::{
|
||||
BertTokenizer, MultiThreadedTokenizer, Tokenizer, TruncationStrategy,
|
||||
};
|
||||
use tch::{nn, no_grad, Device, Tensor};
|
||||
|
||||
fn main() -> anyhow::Result<()> {
|
||||
// Resources paths
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
ElectraConfigResources::BASE_DISCRIMINATOR,
|
||||
));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
ElectraVocabResources::BASE_DISCRIMINATOR,
|
||||
));
|
||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
ElectraModelResources::BASE_DISCRIMINATOR,
|
||||
));
|
||||
let config_path = config_resource.get_local_path()?;
|
||||
let vocab_path = vocab_resource.get_local_path()?;
|
||||
let weights_path = weights_resource.get_local_path()?;
|
||||
|
||||
// Set-up masked LM model
|
||||
let device = Device::Cpu;
|
||||
let mut vs = nn::VarStore::new(device);
|
||||
let tokenizer: BertTokenizer =
|
||||
BertTokenizer::from_file(vocab_path.to_str().unwrap(), true, true)?;
|
||||
let config = ElectraConfig::from_file(config_path);
|
||||
let electra_model = ElectraDiscriminator::new(&vs.root(), &config);
|
||||
vs.load(weights_path)?;
|
||||
|
||||
// Define input
|
||||
let input = ["One Two Three Ten Five Six Seven Eight"];
|
||||
let tokenized_input = MultiThreadedTokenizer::encode_list(
|
||||
&tokenizer,
|
||||
&input,
|
||||
128,
|
||||
&TruncationStrategy::LongestFirst,
|
||||
0,
|
||||
);
|
||||
let max_len = tokenized_input
|
||||
.iter()
|
||||
.map(|input| input.token_ids.len())
|
||||
.max()
|
||||
.unwrap();
|
||||
let encoded_input = tokenized_input
|
||||
.iter()
|
||||
.map(|input| input.token_ids.clone())
|
||||
.map(|mut input| {
|
||||
input.extend(vec![0; max_len - input.len()]);
|
||||
input
|
||||
})
|
||||
.map(|input| Tensor::of_slice(&(input)))
|
||||
.collect::<Vec<_>>();
|
||||
let input_tensor = Tensor::stack(encoded_input.as_slice(), 0).to(device);
|
||||
|
||||
// Forward pass
|
||||
let model_output =
|
||||
no_grad(|| electra_model.forward_t(Some(input_tensor), None, None, None, None, false));
|
||||
|
||||
// Print model predictions
|
||||
for (position, token) in tokenized_input[0].token_ids.iter().enumerate() {
|
||||
let probability = model_output.probabilities.double_value(&[position as i64]);
|
||||
let generated = if probability > 0.5 {
|
||||
"generated"
|
||||
} else {
|
||||
"original"
|
||||
};
|
||||
println!(
|
||||
"{:?}: {} ({:.1}%)",
|
||||
tokenizer.decode([*token].to_vec(), false, false),
|
||||
generated,
|
||||
100f64 * probability
|
||||
)
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
@ -1,93 +0,0 @@
|
||||
// Copyright 2020 The Google Research Authors.
|
||||
// Copyright 2019-present, the HuggingFace Inc. team
|
||||
// Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
||||
// Copyright 2019 Guillaume Becquin
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use rust_bert::electra::{
|
||||
ElectraConfig, ElectraConfigResources, ElectraForMaskedLM, ElectraModelResources,
|
||||
ElectraVocabResources,
|
||||
};
|
||||
use rust_bert::resources::{RemoteResource, Resource};
|
||||
use rust_bert::Config;
|
||||
use rust_tokenizers::tokenizer::{BertTokenizer, MultiThreadedTokenizer, TruncationStrategy};
|
||||
use rust_tokenizers::vocab::Vocab;
|
||||
use tch::{nn, no_grad, Device, Tensor};
|
||||
|
||||
fn main() -> anyhow::Result<()> {
|
||||
// Resources paths
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
ElectraConfigResources::BASE_GENERATOR,
|
||||
));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
ElectraVocabResources::BASE_GENERATOR,
|
||||
));
|
||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
ElectraModelResources::BASE_GENERATOR,
|
||||
));
|
||||
let config_path = config_resource.get_local_path()?;
|
||||
let vocab_path = vocab_resource.get_local_path()?;
|
||||
let weights_path = weights_resource.get_local_path()?;
|
||||
|
||||
// Set-up masked LM model
|
||||
let device = Device::Cpu;
|
||||
let mut vs = nn::VarStore::new(device);
|
||||
let tokenizer: BertTokenizer =
|
||||
BertTokenizer::from_file(vocab_path.to_str().unwrap(), true, true)?;
|
||||
let config = ElectraConfig::from_file(config_path);
|
||||
let electra_model = ElectraForMaskedLM::new(&vs.root(), &config);
|
||||
vs.load(weights_path)?;
|
||||
|
||||
// Define input
|
||||
let input = [
|
||||
"Looks like one [MASK] is missing",
|
||||
"It was a very nice and [MASK] day",
|
||||
];
|
||||
let tokenized_input = tokenizer.encode_list(&input, 128, &TruncationStrategy::LongestFirst, 0);
|
||||
let max_len = tokenized_input
|
||||
.iter()
|
||||
.map(|input| input.token_ids.len())
|
||||
.max()
|
||||
.unwrap();
|
||||
let tokenized_input = tokenized_input
|
||||
.iter()
|
||||
.map(|input| input.token_ids.clone())
|
||||
.map(|mut input| {
|
||||
input.extend(vec![0; max_len - input.len()]);
|
||||
input
|
||||
})
|
||||
.map(|input| Tensor::of_slice(&(input)))
|
||||
.collect::<Vec<_>>();
|
||||
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
|
||||
|
||||
// Forward pass
|
||||
let model_output =
|
||||
no_grad(|| electra_model.forward_t(Some(input_tensor), None, None, None, None, false));
|
||||
|
||||
// Print masked tokens
|
||||
let index_1 = model_output
|
||||
.prediction_scores
|
||||
.get(0)
|
||||
.get(4)
|
||||
.argmax(0, false);
|
||||
let index_2 = model_output
|
||||
.prediction_scores
|
||||
.get(1)
|
||||
.get(7)
|
||||
.argmax(0, false);
|
||||
let word_1 = tokenizer.vocab().id_to_token(&index_1.int64_value(&[]));
|
||||
let word_2 = tokenizer.vocab().id_to_token(&index_2.int64_value(&[]));
|
||||
|
||||
println!("{}", word_1); // Outputs "thing" : "Looks like one [thing] is missing"
|
||||
println!("{}", word_2); // Outputs "sunny" : "It was a very nice and [sunny] day"
|
||||
|
||||
Ok(())
|
||||
}
|
@ -16,7 +16,7 @@ use rust_bert::pipelines::common::ModelType;
|
||||
use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel};
|
||||
|
||||
fn main() -> anyhow::Result<()> {
|
||||
// Set-up masked LM model
|
||||
// Set-up model
|
||||
let generate_config = TextGenerationConfig {
|
||||
model_type: ModelType::GPT2,
|
||||
max_length: 30,
|
@ -22,7 +22,7 @@ use rust_bert::reformer::{
|
||||
use rust_bert::resources::{RemoteResource, Resource};
|
||||
|
||||
fn main() -> anyhow::Result<()> {
|
||||
// Set-up masked LM model
|
||||
// Set-up model
|
||||
// Resources paths
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
ReformerConfigResources::CRIME_AND_PUNISHMENT,
|
||||
|
@ -20,7 +20,7 @@ use rust_bert::resources::{RemoteResource, Resource};
|
||||
use rust_bert::xlnet::{XLNetConfigResources, XLNetModelResources, XLNetVocabResources};
|
||||
|
||||
fn main() -> anyhow::Result<()> {
|
||||
// Set-up masked LM model
|
||||
// Set-up model
|
||||
// Resources paths
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
XLNetConfigResources::XLNET_BASE_CASED,
|
||||
|
@ -1,97 +0,0 @@
|
||||
// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
|
||||
// Copyright 2019 Guillaume Becquin
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
extern crate anyhow;
|
||||
|
||||
use rust_bert::gpt2::{
|
||||
GPT2LMHeadModel, Gpt2Config, Gpt2ConfigResources, Gpt2MergesResources, Gpt2ModelResources,
|
||||
Gpt2VocabResources,
|
||||
};
|
||||
use rust_bert::pipelines::generation_utils::{Cache, LMHeadModel};
|
||||
use rust_bert::resources::{RemoteResource, Resource};
|
||||
use rust_bert::Config;
|
||||
use rust_tokenizers::tokenizer::{Gpt2Tokenizer, Tokenizer, TruncationStrategy};
|
||||
use tch::{nn, Device, Tensor};
|
||||
|
||||
fn main() -> anyhow::Result<()> {
|
||||
// Resources set-up
|
||||
let config_resource =
|
||||
Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
|
||||
let vocab_resource =
|
||||
Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
|
||||
let merges_resource =
|
||||
Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
|
||||
let weights_resource =
|
||||
Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
|
||||
let config_path = config_resource.get_local_path()?;
|
||||
let vocab_path = vocab_resource.get_local_path()?;
|
||||
let merges_path = merges_resource.get_local_path()?;
|
||||
let weights_path = weights_resource.get_local_path()?;
|
||||
|
||||
// Set-up masked LM model
|
||||
let device = Device::Cpu;
|
||||
let mut vs = nn::VarStore::new(device);
|
||||
let tokenizer: Gpt2Tokenizer = Gpt2Tokenizer::from_file(
|
||||
vocab_path.to_str().unwrap(),
|
||||
merges_path.to_str().unwrap(),
|
||||
false,
|
||||
)?;
|
||||
let config = Gpt2Config::from_file(config_path);
|
||||
let gpt2_model = GPT2LMHeadModel::new(&vs.root(), &config);
|
||||
vs.load(weights_path)?;
|
||||
|
||||
// Define input
|
||||
let input = ["One two three four five six seven eight nine ten eleven"];
|
||||
let tokenized_input = tokenizer.encode_list(&input, 128, &TruncationStrategy::LongestFirst, 0);
|
||||
let max_len = tokenized_input
|
||||
.iter()
|
||||
.map(|input| input.token_ids.len())
|
||||
.max()
|
||||
.unwrap();
|
||||
let tokenized_input = tokenized_input
|
||||
.iter()
|
||||
.map(|input| input.token_ids.clone())
|
||||
.map(|mut input| {
|
||||
input.extend(vec![0; max_len - input.len()]);
|
||||
input
|
||||
})
|
||||
.map(|input| Tensor::of_slice(&(input)))
|
||||
.collect::<Vec<_>>();
|
||||
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
|
||||
|
||||
// Forward pass
|
||||
let model_output = gpt2_model
|
||||
.forward_t(
|
||||
&Some(input_tensor),
|
||||
Cache::None,
|
||||
&None,
|
||||
&None,
|
||||
&None,
|
||||
&None,
|
||||
None,
|
||||
&None,
|
||||
false,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let next_word_id = model_output
|
||||
.lm_logits
|
||||
.get(0)
|
||||
.get(-1)
|
||||
.argmax(-1, true)
|
||||
.int64_value(&[0]);
|
||||
let next_word = tokenizer.decode(vec![next_word_id], true, true);
|
||||
println!("Provided input: {}", input[0]);
|
||||
println!("Next word: {}", next_word);
|
||||
|
||||
Ok(())
|
||||
}
|
@ -1,101 +0,0 @@
|
||||
// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
|
||||
// Copyright 2019 Guillaume Becquin
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
extern crate anyhow;
|
||||
|
||||
use rust_bert::mobilebert::{
|
||||
MobileBertConfig, MobileBertConfigResources, MobileBertForMaskedLM, MobileBertModelResources,
|
||||
MobileBertVocabResources,
|
||||
};
|
||||
use rust_bert::resources::{RemoteResource, Resource};
|
||||
use rust_bert::Config;
|
||||
use rust_tokenizers::tokenizer::{BertTokenizer, MultiThreadedTokenizer, TruncationStrategy};
|
||||
use rust_tokenizers::vocab::Vocab;
|
||||
use tch::{nn, no_grad, Device, Tensor};
|
||||
|
||||
fn main() -> anyhow::Result<()> {
|
||||
// Resources paths
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
MobileBertConfigResources::MOBILEBERT_UNCASED,
|
||||
));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
MobileBertVocabResources::MOBILEBERT_UNCASED,
|
||||
));
|
||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
MobileBertModelResources::MOBILEBERT_UNCASED,
|
||||
));
|
||||
let config_path = config_resource.get_local_path()?;
|
||||
let vocab_path = vocab_resource.get_local_path()?;
|
||||
let weights_path = weights_resource.get_local_path()?;
|
||||
|
||||
// Set-up masked LM model
|
||||
let device = Device::Cpu;
|
||||
let mut vs = nn::VarStore::new(device);
|
||||
let tokenizer: BertTokenizer =
|
||||
BertTokenizer::from_file(vocab_path.to_str().unwrap(), true, true)?;
|
||||
let config = MobileBertConfig::from_file(config_path);
|
||||
let mobilebert_model = MobileBertForMaskedLM::new(&vs.root(), &config);
|
||||
vs.load(weights_path)?;
|
||||
|
||||
// Define input
|
||||
let input = [
|
||||
"Looks like one [MASK] is missing",
|
||||
"It was a very nice and [MASK] day",
|
||||
];
|
||||
let tokenized_input = tokenizer.encode_list(&input, 128, &TruncationStrategy::LongestFirst, 0);
|
||||
let max_len = tokenized_input
|
||||
.iter()
|
||||
.map(|input| input.token_ids.len())
|
||||
.max()
|
||||
.unwrap();
|
||||
let tokenized_input = tokenized_input
|
||||
.iter()
|
||||
.map(|input| input.token_ids.clone())
|
||||
.map(|mut input| {
|
||||
input.extend(vec![0; max_len - input.len()]);
|
||||
input
|
||||
})
|
||||
.map(|input| Tensor::of_slice(&(input)))
|
||||
.collect::<Vec<_>>();
|
||||
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
|
||||
|
||||
// Forward pass
|
||||
let model_output =
|
||||
no_grad(|| mobilebert_model.forward_t(Some(&input_tensor), None, None, None, None, false))?;
|
||||
|
||||
// Print masked tokens
|
||||
let index_1 = model_output.logits.get(0).get(4).argmax(0, false);
|
||||
let index_2 = model_output.logits.get(1).get(7).argmax(0, false);
|
||||
let word_1 = tokenizer.vocab().id_to_token(&index_1.int64_value(&[]));
|
||||
let word_2 = tokenizer.vocab().id_to_token(&index_2.int64_value(&[]));
|
||||
|
||||
println!("{}", word_1); // Outputs "thing" : "Looks like one [thing] is missing"
|
||||
println!(
|
||||
"score: {}",
|
||||
model_output
|
||||
.logits
|
||||
.get(0)
|
||||
.get(4)
|
||||
.double_value(&[i64::from(&index_1)])
|
||||
); // 10.0558
|
||||
|
||||
println!("{}", word_2); // Outputs "sunny" : "It was a very nice and [sunny] day"
|
||||
println!(
|
||||
"score: {}",
|
||||
model_output
|
||||
.logits
|
||||
.get(1)
|
||||
.get(7)
|
||||
.double_value(&[i64::from(&index_2)])
|
||||
); // 14.2708
|
||||
Ok(())
|
||||
}
|
@ -1,102 +0,0 @@
|
||||
// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
|
||||
// Copyright 2019 Guillaume Becquin
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
extern crate anyhow;
|
||||
|
||||
use rust_bert::gpt2::Gpt2Config;
|
||||
use rust_bert::openai_gpt::{
|
||||
OpenAIGPTLMHeadModel, OpenAiGptConfigResources, OpenAiGptMergesResources,
|
||||
OpenAiGptModelResources, OpenAiGptVocabResources,
|
||||
};
|
||||
use rust_bert::pipelines::generation_utils::{Cache, LMHeadModel};
|
||||
use rust_bert::resources::{RemoteResource, Resource};
|
||||
use rust_bert::Config;
|
||||
use rust_tokenizers::tokenizer::{OpenAiGptTokenizer, Tokenizer, TruncationStrategy};
|
||||
use tch::{nn, Device, Tensor};
|
||||
|
||||
fn main() -> anyhow::Result<()> {
|
||||
// Resources paths
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
OpenAiGptConfigResources::GPT,
|
||||
));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
OpenAiGptVocabResources::GPT,
|
||||
));
|
||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
OpenAiGptMergesResources::GPT,
|
||||
));
|
||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
OpenAiGptModelResources::GPT,
|
||||
));
|
||||
let config_path = config_resource.get_local_path()?;
|
||||
let vocab_path = vocab_resource.get_local_path()?;
|
||||
let merges_path = merges_resource.get_local_path()?;
|
||||
let weights_path = weights_resource.get_local_path()?;
|
||||
|
||||
// Set-up masked LM model
|
||||
let device = Device::Cpu;
|
||||
let mut vs = nn::VarStore::new(device);
|
||||
let tokenizer = OpenAiGptTokenizer::from_file(
|
||||
vocab_path.to_str().unwrap(),
|
||||
merges_path.to_str().unwrap(),
|
||||
true,
|
||||
)?;
|
||||
let config = Gpt2Config::from_file(config_path);
|
||||
let openai_gpt = OpenAIGPTLMHeadModel::new(&vs.root(), &config);
|
||||
vs.load(weights_path)?;
|
||||
|
||||
// Define input
|
||||
let input = ["Wondering what the next word will"];
|
||||
let tokenized_input = tokenizer.encode_list(&input, 128, &TruncationStrategy::LongestFirst, 0);
|
||||
let max_len = tokenized_input
|
||||
.iter()
|
||||
.map(|input| input.token_ids.len())
|
||||
.max()
|
||||
.unwrap();
|
||||
let tokenized_input = tokenized_input
|
||||
.iter()
|
||||
.map(|input| input.token_ids.clone())
|
||||
.map(|mut input| {
|
||||
input.extend(vec![0; max_len - input.len()]);
|
||||
input
|
||||
})
|
||||
.map(|input| Tensor::of_slice(&(input)))
|
||||
.collect::<Vec<_>>();
|
||||
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
|
||||
|
||||
// Forward pass
|
||||
let model_output = openai_gpt
|
||||
.forward_t(
|
||||
&Some(input_tensor),
|
||||
Cache::None,
|
||||
&None,
|
||||
&None,
|
||||
&None,
|
||||
&None,
|
||||
None,
|
||||
&None,
|
||||
false,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let next_word_id = model_output
|
||||
.lm_logits
|
||||
.get(0)
|
||||
.get(-1)
|
||||
.argmax(-1, true)
|
||||
.int64_value(&[0]);
|
||||
let next_word = tokenizer.decode(vec![next_word_id], true, true);
|
||||
println!("Provided input: {}", input[0]);
|
||||
println!("Next word: {}", next_word);
|
||||
|
||||
Ok(())
|
||||
}
|
@ -1,74 +0,0 @@
|
||||
// Copyright 2018 Google AI and Google Brain team.
|
||||
// Copyright 2018 Carnegie Mellon University Authors.
|
||||
// Copyright 2020-present, the HuggingFace Inc. team.
|
||||
// Copyright 2020 Guillaume Becquin
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
extern crate anyhow;
|
||||
|
||||
use rust_bert::reformer::{
|
||||
ReformerConfig, ReformerConfigResources, ReformerModelResources, ReformerModelWithLMHead,
|
||||
ReformerVocabResources,
|
||||
};
|
||||
use rust_bert::resources::{RemoteResource, Resource};
|
||||
use rust_bert::Config;
|
||||
use rust_tokenizers::tokenizer::{MultiThreadedTokenizer, ReformerTokenizer, TruncationStrategy};
|
||||
use tch::{nn, Device, Tensor};
|
||||
|
||||
fn main() -> anyhow::Result<()> {
|
||||
// Resources paths
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
ReformerConfigResources::CRIME_AND_PUNISHMENT,
|
||||
));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
ReformerVocabResources::CRIME_AND_PUNISHMENT,
|
||||
));
|
||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
ReformerModelResources::CRIME_AND_PUNISHMENT,
|
||||
));
|
||||
let config_path = config_resource.get_local_path()?;
|
||||
let vocab_path = vocab_resource.get_local_path()?;
|
||||
let weights_path = weights_resource.get_local_path()?;
|
||||
|
||||
// Set-up masked LM model
|
||||
let device = Device::cuda_if_available();
|
||||
let mut vs = nn::VarStore::new(device);
|
||||
let tokenizer = ReformerTokenizer::from_file(vocab_path.to_str().unwrap(), false)?;
|
||||
let config = ReformerConfig::from_file(config_path);
|
||||
let reformer_model = ReformerModelWithLMHead::new(&vs.root(), &config)?;
|
||||
vs.load(weights_path)?;
|
||||
|
||||
// Define input
|
||||
let input = ["One two three four five six seven eight nine ten eleven One two three four five six seven eight nine ten eleven One two three four five six seven eight nine ten eleven"];
|
||||
let tokenized_input = tokenizer.encode_list(&input, 128, &TruncationStrategy::LongestFirst, 0);
|
||||
let max_len = tokenized_input
|
||||
.iter()
|
||||
.map(|input| input.token_ids.len())
|
||||
.max()
|
||||
.unwrap();
|
||||
let tokenized_input = tokenized_input
|
||||
.iter()
|
||||
.map(|input| input.token_ids.clone())
|
||||
.map(|mut input| {
|
||||
input.extend(vec![0; max_len - input.len()]);
|
||||
input
|
||||
})
|
||||
.map(|input| Tensor::of_slice(&(input)))
|
||||
.collect::<Vec<_>>();
|
||||
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
|
||||
|
||||
// Forward pass
|
||||
let _model_output =
|
||||
reformer_model.forward_t(Some(&input_tensor), None, None, None, None, None, false)?;
|
||||
|
||||
_model_output.logits.print();
|
||||
Ok(())
|
||||
}
|
@ -1,119 +0,0 @@
|
||||
// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
|
||||
// Copyright 2019 Guillaume Becquin
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
extern crate anyhow;
|
||||
|
||||
use rust_bert::bert::BertConfig;
|
||||
use rust_bert::resources::{RemoteResource, Resource};
|
||||
use rust_bert::roberta::{
|
||||
RobertaConfigResources, RobertaForMaskedLM, RobertaMergesResources, RobertaModelResources,
|
||||
RobertaVocabResources,
|
||||
};
|
||||
use rust_bert::Config;
|
||||
use rust_tokenizers::tokenizer::{RobertaTokenizer, Tokenizer, TruncationStrategy};
|
||||
use rust_tokenizers::vocab::Vocab;
|
||||
use tch::{nn, no_grad, Device, Tensor};
|
||||
|
||||
fn main() -> anyhow::Result<()> {
|
||||
// Resources paths
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
RobertaConfigResources::ROBERTA,
|
||||
));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
RobertaVocabResources::ROBERTA,
|
||||
));
|
||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
RobertaMergesResources::ROBERTA,
|
||||
));
|
||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
RobertaModelResources::ROBERTA,
|
||||
));
|
||||
let config_path = config_resource.get_local_path()?;
|
||||
let vocab_path = vocab_resource.get_local_path()?;
|
||||
let merges_path = merges_resource.get_local_path()?;
|
||||
let weights_path = weights_resource.get_local_path()?;
|
||||
|
||||
// Set-up masked LM model
|
||||
let device = Device::Cpu;
|
||||
let mut vs = nn::VarStore::new(device);
|
||||
let tokenizer: RobertaTokenizer = RobertaTokenizer::from_file(
|
||||
vocab_path.to_str().unwrap(),
|
||||
merges_path.to_str().unwrap(),
|
||||
true,
|
||||
false,
|
||||
)?;
|
||||
let config = BertConfig::from_file(config_path);
|
||||
let bert_model = RobertaForMaskedLM::new(&vs.root(), &config);
|
||||
vs.load(weights_path)?;
|
||||
|
||||
// Define input
|
||||
let input = [
|
||||
"<pad> Looks like one thing is missing",
|
||||
"It\'s like comparing oranges to apples",
|
||||
];
|
||||
let tokenized_input = tokenizer.encode_list(&input, 128, &TruncationStrategy::LongestFirst, 0);
|
||||
let max_len = tokenized_input
|
||||
.iter()
|
||||
.map(|input| input.token_ids.len())
|
||||
.max()
|
||||
.unwrap();
|
||||
let mut tokenized_input = tokenized_input
|
||||
.iter()
|
||||
.map(|input| input.token_ids.clone())
|
||||
.map(|mut input| {
|
||||
input.extend(vec![0; max_len - input.len()]);
|
||||
input
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
// Masking the token [thing] of sentence 1 and [oranges] of sentence 2
|
||||
tokenized_input[0][4] = 103;
|
||||
tokenized_input[1][5] = 103;
|
||||
let tokenized_input = tokenized_input
|
||||
.iter()
|
||||
.map(|input| Tensor::of_slice(&(input)))
|
||||
.collect::<Vec<_>>();
|
||||
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
|
||||
|
||||
// Forward pass
|
||||
let model_output = no_grad(|| {
|
||||
bert_model.forward_t(
|
||||
Some(input_tensor),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
&None,
|
||||
&None,
|
||||
false,
|
||||
)
|
||||
});
|
||||
|
||||
// Print masked tokens
|
||||
let index_1 = model_output
|
||||
.prediction_scores
|
||||
.get(0)
|
||||
.get(4)
|
||||
.argmax(0, false);
|
||||
let index_2 = model_output
|
||||
.prediction_scores
|
||||
.get(1)
|
||||
.get(5)
|
||||
.argmax(0, false);
|
||||
let word_1 = tokenizer.vocab().id_to_token(&index_1.int64_value(&[]));
|
||||
let word_2 = tokenizer.vocab().id_to_token(&index_2.int64_value(&[]));
|
||||
|
||||
println!("{}", word_1); // Outputs "some" : "Looks like [some] thing is missing"
|
||||
println!("{}", word_2); // Outputs "apple" : "It\'s like comparing [apple] to apples"
|
||||
|
||||
Ok(())
|
||||
}
|
59
examples/translation_mbart.rs
Normal file
59
examples/translation_mbart.rs
Normal file
@ -0,0 +1,59 @@
|
||||
// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
|
||||
// Copyright 2019 Guillaume Becquin
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
extern crate anyhow;
|
||||
|
||||
use rust_bert::mbart::{
|
||||
MBartConfigResources, MBartGenerator, MBartModelResources, MBartVocabResources,
|
||||
};
|
||||
use rust_bert::pipelines::generation_utils::{GenerateConfig, LanguageGenerator};
|
||||
use rust_bert::resources::{RemoteResource, Resource};
|
||||
|
||||
fn main() -> anyhow::Result<()> {
|
||||
let generate_config = GenerateConfig {
|
||||
max_length: 56,
|
||||
model_resource: Resource::Remote(RemoteResource::from_pretrained(
|
||||
MBartModelResources::MBART50_MANY_TO_MANY,
|
||||
)),
|
||||
config_resource: Resource::Remote(RemoteResource::from_pretrained(
|
||||
MBartConfigResources::MBART50_MANY_TO_MANY,
|
||||
)),
|
||||
vocab_resource: Resource::Remote(RemoteResource::from_pretrained(
|
||||
MBartVocabResources::MBART50_MANY_TO_MANY,
|
||||
)),
|
||||
merges_resource: Resource::Remote(RemoteResource::from_pretrained(
|
||||
MBartVocabResources::MBART50_MANY_TO_MANY,
|
||||
)),
|
||||
do_sample: false,
|
||||
num_beams: 1,
|
||||
..Default::default()
|
||||
};
|
||||
let model = MBartGenerator::new(generate_config)?;
|
||||
|
||||
let input_context_1 = "en_XX The quick brown fox jumps over the lazy dog.";
|
||||
let target_language = model.get_tokenizer().convert_tokens_to_ids(["de_DE"])[0];
|
||||
|
||||
let output = model.generate(
|
||||
Some(&[input_context_1]),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
target_language,
|
||||
None,
|
||||
);
|
||||
|
||||
for sentence in output {
|
||||
println!("{:?}", sentence);
|
||||
}
|
||||
Ok(())
|
||||
}
|
@ -35,13 +35,13 @@ fn main() -> anyhow::Result<()> {
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
// Set-up masked LM model
|
||||
// Set-up model
|
||||
let t5_model = T5Generator::new(generate_config)?;
|
||||
|
||||
// Define input
|
||||
let input = ["translate English to German: This sentence will get translated to German"];
|
||||
|
||||
let output = t5_model.generate(Some(input.to_vec()), None, None, None, None, None);
|
||||
let output = t5_model.generate(Some(input.to_vec()), None, None, None, None, None, None);
|
||||
println!("{:?}", output);
|
||||
|
||||
Ok(())
|
@ -1,99 +0,0 @@
|
||||
// Copyright 2018 Google AI and Google Brain team.
|
||||
// Copyright 2018 Carnegie Mellon University Authors.
|
||||
// Copyright 2020-present, the HuggingFace Inc. team.
|
||||
// Copyright 2020 Guillaume Becquin
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
extern crate anyhow;
|
||||
|
||||
use rust_bert::resources::{RemoteResource, Resource};
|
||||
use rust_bert::xlnet::{
|
||||
XLNetConfig, XLNetConfigResources, XLNetLMHeadModel, XLNetModelResources, XLNetVocabResources,
|
||||
};
|
||||
use rust_bert::Config;
|
||||
use rust_tokenizers::tokenizer::{MultiThreadedTokenizer, TruncationStrategy, XLNetTokenizer};
|
||||
use rust_tokenizers::vocab::Vocab;
|
||||
use tch::{nn, no_grad, Device, Kind, Tensor};
|
||||
|
||||
fn main() -> anyhow::Result<()> {
|
||||
// Resources paths
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
XLNetConfigResources::XLNET_BASE_CASED,
|
||||
));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
XLNetVocabResources::XLNET_BASE_CASED,
|
||||
));
|
||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
XLNetModelResources::XLNET_BASE_CASED,
|
||||
));
|
||||
let config_path = config_resource.get_local_path()?;
|
||||
let vocab_path = vocab_resource.get_local_path()?;
|
||||
let weights_path = weights_resource.get_local_path()?;
|
||||
|
||||
// Set-up masked LM model
|
||||
let device = Device::cuda_if_available();
|
||||
let mut vs = nn::VarStore::new(device);
|
||||
let tokenizer: XLNetTokenizer =
|
||||
XLNetTokenizer::from_file(vocab_path.to_str().unwrap(), false, true)?;
|
||||
let config = XLNetConfig::from_file(config_path);
|
||||
let xlnet_model = XLNetLMHeadModel::new(&vs.root(), &config);
|
||||
vs.load(weights_path)?;
|
||||
|
||||
// Define input
|
||||
let input = ["One two three four"];
|
||||
let tokenized_input = tokenizer.encode_list(&input, 128, &TruncationStrategy::LongestFirst, 0);
|
||||
let max_len = tokenized_input
|
||||
.iter()
|
||||
.map(|input| input.token_ids.len())
|
||||
.max()
|
||||
.unwrap();
|
||||
let tokenized_input = tokenized_input
|
||||
.iter()
|
||||
.map(|input| input.token_ids.clone())
|
||||
.map(|mut input| {
|
||||
input.extend(vec![0; max_len - input.len()]);
|
||||
input
|
||||
})
|
||||
.map(|input| Tensor::of_slice(&(input[..input.len() - 2])))
|
||||
.collect::<Vec<_>>();
|
||||
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
|
||||
|
||||
// Forward pass
|
||||
let perm_mask = Tensor::zeros(&[1, 4, 4], (Kind::Float, device));
|
||||
let _ = perm_mask.narrow(2, 3, 1).fill_(1.0);
|
||||
|
||||
let target_mapping = Tensor::zeros(&[1, 1, 4], (Kind::Float, device));
|
||||
let _ = target_mapping.narrow(2, 3, 1).fill_(1.0);
|
||||
let model_output = no_grad(|| {
|
||||
xlnet_model
|
||||
.forward_t(
|
||||
Some(&input_tensor),
|
||||
None,
|
||||
None,
|
||||
Some(perm_mask.as_ref()),
|
||||
Some(target_mapping.as_ref()),
|
||||
None,
|
||||
None,
|
||||
false,
|
||||
)
|
||||
.unwrap()
|
||||
});
|
||||
|
||||
let index_1 = model_output
|
||||
.lm_logits
|
||||
.get(0)
|
||||
.argmax(1, false)
|
||||
.int64_value(&[]);
|
||||
let score_1 = model_output.lm_logits.double_value(&[0, 0, index_1]);
|
||||
let word_1 = tokenizer.vocab().id_to_token(&index_1);
|
||||
println!("{}, {}, {}", index_1, score_1, word_1);
|
||||
Ok(())
|
||||
}
|
@ -11,7 +11,6 @@
|
||||
//!
|
||||
//! # Model set-up and pre-trained weights loading
|
||||
//!
|
||||
//! A full working example is provided in `examples/albert`, run with `cargo run --example albert`.
|
||||
//! The example below illustrate a Masked language model example, the structure is similar for other models.
|
||||
//! All models expect the following resources:
|
||||
//! - Configuration file expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers)
|
||||
|
@ -1112,7 +1112,7 @@ impl PrivateLanguageGenerator<BartForConditionalGeneration, RobertaVocab, Robert
|
||||
fn get_model(&self) -> &BartForConditionalGeneration {
|
||||
&self.model
|
||||
}
|
||||
fn get_tokenizer(&self) -> &TokenizerOption {
|
||||
fn _get_tokenizer(&self) -> &TokenizerOption {
|
||||
&self.tokenizer
|
||||
}
|
||||
fn get_var_store(&self) -> &nn::VarStore {
|
||||
@ -1201,7 +1201,7 @@ impl PrivateLanguageGenerator<BartForConditionalGeneration, RobertaVocab, Robert
|
||||
where
|
||||
S: AsRef<[&'a str]>,
|
||||
{
|
||||
let tokens = self.get_tokenizer().encode_list(
|
||||
let tokens = self._get_tokenizer().encode_list(
|
||||
prompt_text.as_ref(),
|
||||
max_len as usize,
|
||||
&TruncationStrategy::LongestFirst,
|
||||
@ -1217,7 +1217,7 @@ impl PrivateLanguageGenerator<BartForConditionalGeneration, RobertaVocab, Robert
|
||||
let pad_token = match pad_token_id {
|
||||
Some(value) => value,
|
||||
None => self
|
||||
.get_tokenizer()
|
||||
._get_tokenizer()
|
||||
.convert_tokens_to_ids(&[RobertaVocab::unknown_value()])[0],
|
||||
};
|
||||
|
||||
|
@ -6,8 +6,7 @@
|
||||
//!
|
||||
//! # Model set-up and pre-trained weights loading
|
||||
//!
|
||||
//! A full working example is provided in `examples/bart`, run with `cargo run --example bart`.
|
||||
//! Alternatively, the summarization capabilities are illustrated in `examples/summarization.rs`, run with `cargo run --example summarization`.
|
||||
//! The summarization capabilities are illustrated in `examples/summarization_bart`, run with `cargo run --example summarization_bart`.
|
||||
//! All models expect the following resources:
|
||||
//! - Configuration file expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers)
|
||||
//! - Model weights are expected to have a structure and parameter names following the [Transformers library](https://github.com/huggingface/transformers). A conversion using the Python utility scripts is required to convert the `.bin` weights to the `.ot` format.
|
||||
|
@ -1240,7 +1240,7 @@ mod test {
|
||||
|
||||
#[test]
|
||||
#[ignore] // compilation is enough, no need to run
|
||||
fn bart_model_send() {
|
||||
fn bert_model_send() {
|
||||
let config_resource =
|
||||
Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT));
|
||||
let config_path = config_resource.get_local_path().expect("");
|
||||
|
@ -10,7 +10,7 @@
|
||||
//!
|
||||
//! # Model set-up and pre-trained weights loading
|
||||
//!
|
||||
//! A full working example is provided in `examples/bert`, run with `cargo run --example bert`.
|
||||
//! A full working example is provided in `examples/masked_language_model_bert`, run with `cargo run --example masked_language_model_bert`.
|
||||
//! The example below illustrate a Masked language model example, the structure is similar for other models.
|
||||
//! All models expect the following resources:
|
||||
//! - Configuration file expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers)
|
||||
|
@ -9,7 +9,6 @@
|
||||
//!
|
||||
//! # Model set-up and pre-trained weights loading
|
||||
//!
|
||||
//! A full working example is provided in `examples/distilbert_masked_lm.rs`, run with `cargo run --example distilbert_masked_lm`.
|
||||
//! The example below illustrate a DistilBERT Masked language model example, the structure is similar for other models.
|
||||
//! All models expect the following resources:
|
||||
//! - Configuration file expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers)
|
||||
|
@ -14,7 +14,6 @@
|
||||
//!
|
||||
//! # Model set-up and pre-trained weights loading
|
||||
//!
|
||||
//! A full working example is provided in `examples/electra_masked_lm.rs`, run with `cargo run --example electra_masked_lm`.
|
||||
//! The example below illustrate a Masked language model example, the structure is similar for other models (e.g. discriminator).
|
||||
//! All models expect the following resources:
|
||||
//! - Configuration file expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers)
|
||||
|
@ -752,7 +752,7 @@ impl PrivateLanguageGenerator<GPT2LMHeadModel, Gpt2Vocab, Gpt2Tokenizer> for GPT
|
||||
fn get_model(&self) -> &GPT2LMHeadModel {
|
||||
&self.model
|
||||
}
|
||||
fn get_tokenizer(&self) -> &TokenizerOption {
|
||||
fn _get_tokenizer(&self) -> &TokenizerOption {
|
||||
&self.tokenizer
|
||||
}
|
||||
fn get_var_store(&self) -> &nn::VarStore {
|
||||
|
@ -6,7 +6,7 @@
|
||||
//!
|
||||
//! # Model set-up and pre-trained weights loading
|
||||
//!
|
||||
//! A full working example is provided in `examples/generation.rs`, run with `cargo run --example generation`.
|
||||
//! A full working example is provided in `examples/generation_gpt2`, run with `cargo run --example generation_gpt2`.
|
||||
//! All models expect the following resources:
|
||||
//! - Configuration file expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers)
|
||||
//! - Model weights are expected to have a structure and parameter names following the [Transformers library](https://github.com/huggingface/transformers). A conversion using the Python utility scripts is required to convert the `.bin` weights to the `.ot` format.
|
||||
|
@ -729,7 +729,7 @@ impl PrivateLanguageGenerator<GptNeoForCausalLM, Gpt2Vocab, Gpt2Tokenizer> for G
|
||||
fn get_model(&self) -> &GptNeoForCausalLM {
|
||||
&self.model
|
||||
}
|
||||
fn get_tokenizer(&self) -> &TokenizerOption {
|
||||
fn _get_tokenizer(&self) -> &TokenizerOption {
|
||||
&self.tokenizer
|
||||
}
|
||||
fn get_var_store(&self) -> &nn::VarStore {
|
||||
|
@ -5,6 +5,7 @@
|
||||
//!
|
||||
//! # Model set-up and pre-trained weights loading
|
||||
//!
|
||||
//! A full working example is provided in `examples/generation_gpt_neo`, run with `cargo run --example generation_gpt_neo`.
|
||||
//! All models expect the following resources:
|
||||
//! - Configuration file expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers)
|
||||
//! - Model weights are expected to have a structure and parameter names following the [Transformers library](https://github.com/huggingface/transformers). A conversion using the Python utility scripts is required to convert the `.bin` weights to the `.ot` format.
|
||||
|
@ -58,6 +58,7 @@
|
||||
//! GPT-Neo| | | |✅ | | | |
|
||||
//! BART|✅| | |✅ |✅| | |
|
||||
//! Marian| | | | | |✅| |
|
||||
//! MBart|✅| | |✅ | | | |
|
||||
//! Electra | |✅| | | | |✅|
|
||||
//! ALBERT |✅|✅|✅| | | |✅|
|
||||
//! T5 | | | |✅ |✅|✅| |
|
||||
|
@ -10,7 +10,7 @@
|
||||
//!
|
||||
//! # Model set-up and pre-trained weights loading
|
||||
//!
|
||||
//! A full working example (generation) is provided in `examples/question_answering_longformer`, run with `cargo run --example question_answering_longformer`.
|
||||
//! A full working example (question answering) is provided in `examples/question_answering_longformer`, run with `cargo run --example question_answering_longformer`.
|
||||
//! All models expect the following resources:
|
||||
//! - Configuration file expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers)
|
||||
//! - Model weights are expected to have a structure and parameter names following the [Transformers library](https://github.com/huggingface/transformers). A conversion using the Python utility scripts is required to convert the `.bin` weights to the `.ot` format.
|
||||
|
@ -859,7 +859,7 @@ impl PrivateLanguageGenerator<MarianForConditionalGeneration, MarianVocab, Maria
|
||||
fn get_model(&self) -> &MarianForConditionalGeneration {
|
||||
&self.model
|
||||
}
|
||||
fn get_tokenizer(&self) -> &TokenizerOption {
|
||||
fn _get_tokenizer(&self) -> &TokenizerOption {
|
||||
&self.tokenizer
|
||||
}
|
||||
fn get_var_store(&self) -> &nn::VarStore {
|
||||
@ -950,7 +950,7 @@ impl PrivateLanguageGenerator<MarianForConditionalGeneration, MarianVocab, Maria
|
||||
where
|
||||
T: AsRef<[&'a str]>,
|
||||
{
|
||||
let tokens = self.get_tokenizer().encode_list(
|
||||
let tokens = self._get_tokenizer().encode_list(
|
||||
prompt_text.as_ref(),
|
||||
max_len as usize,
|
||||
&TruncationStrategy::LongestFirst,
|
||||
@ -965,7 +965,7 @@ impl PrivateLanguageGenerator<MarianForConditionalGeneration, MarianVocab, Maria
|
||||
|
||||
let pad_token = match pad_token_id {
|
||||
Some(value) => value,
|
||||
None => self.get_tokenizer().get_unk_id(),
|
||||
None => self._get_tokenizer().get_unk_id(),
|
||||
};
|
||||
|
||||
let token_ids = token_ids
|
||||
|
@ -6,7 +6,7 @@
|
||||
//!
|
||||
//! # Model set-up and pre-trained weights loading
|
||||
//!
|
||||
//! A full working example is provided in `examples/translation.rs`, run with `cargo run --example translation`.
|
||||
//! A full working example is provided in `examples/translation_marian`, run with `cargo run --example translation_marian`.
|
||||
//! All models expect the following resources:
|
||||
//! - Configuration file expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers)
|
||||
//! - Model weights are expected to have a structure and parameter names following the [Transformers library](https://github.com/huggingface/transformers). A conversion using the Python utility scripts is required to convert the `.bin` weights to the `.ot` format.
|
||||
|
@ -698,7 +698,7 @@ impl LMHeadModel for MBartForConditionalGeneration {
|
||||
/// # let device = Device::Cpu;
|
||||
/// # let vs = nn::VarStore::new(device);
|
||||
/// # let config = MBartConfig::from_file(config_path);
|
||||
/// # let bart_model: MBartForConditionalGeneration = MBartForConditionalGeneration::new(&vs.root(), &config);
|
||||
/// # let mbart_model: MBartForConditionalGeneration = MBartForConditionalGeneration::new(&vs.root(), &config);
|
||||
/// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56);
|
||||
/// let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device));
|
||||
/// let target_tensor = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device));
|
||||
@ -915,7 +915,7 @@ impl PrivateLanguageGenerator<MBartForConditionalGeneration, MBart50Vocab, MBart
|
||||
fn get_model(&self) -> &MBartForConditionalGeneration {
|
||||
&self.model
|
||||
}
|
||||
fn get_tokenizer(&self) -> &TokenizerOption {
|
||||
fn _get_tokenizer(&self) -> &TokenizerOption {
|
||||
&self.tokenizer
|
||||
}
|
||||
fn get_var_store(&self) -> &nn::VarStore {
|
||||
@ -1002,7 +1002,7 @@ impl PrivateLanguageGenerator<MBartForConditionalGeneration, MBart50Vocab, MBart
|
||||
where
|
||||
S: AsRef<[&'a str]>,
|
||||
{
|
||||
let tokens = self.get_tokenizer().encode_list(
|
||||
let tokens = self._get_tokenizer().encode_list(
|
||||
prompt_text.as_ref(),
|
||||
max_len as usize,
|
||||
&TruncationStrategy::LongestFirst,
|
||||
@ -1018,7 +1018,7 @@ impl PrivateLanguageGenerator<MBartForConditionalGeneration, MBart50Vocab, MBart
|
||||
let pad_token = match pad_token_id {
|
||||
Some(value) => value,
|
||||
None => self
|
||||
.get_tokenizer()
|
||||
._get_tokenizer()
|
||||
.convert_tokens_to_ids(&[MBart50Vocab::unknown_value()])[0],
|
||||
};
|
||||
|
||||
|
@ -1,3 +1,55 @@
|
||||
//! # MBart (Liu et al.)
|
||||
//!
|
||||
//! Implementation of the MBart language model ([Multilingual Denoising Pre-training for Neural Machine Translation](https://arxiv.org/abs/2001.08210) Liu, Gu, Goyal, Li, Edunov, Ghazvininejad, Lewis, Zettlemoyer, 2020).
|
||||
//! The base model is implemented in the `mbart_model::MBartModel` struct. The model also includes a language model head: `mbart_model::MBartForConditionalGeneration`
|
||||
//! implementing the common `generation_utils::LMHeadModel` trait shared between the models used for generation (see `pipelines` for more information).
|
||||
//!
|
||||
//! # Model set-up and pre-trained weights loading
|
||||
//!
|
||||
//! The summarization capabilities are illustrated in `examples/translation_mbart`, run with `cargo run --example translation_mbart`.
|
||||
//! All models expect the following resources:
|
||||
//! - Configuration file expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers)
|
||||
//! - Model weights are expected to have a structure and parameter names following the [Transformers library](https://github.com/huggingface/transformers). A conversion using the Python utility scripts is required to convert the `.bin` weights to the `.ot` format.
|
||||
//! - `MBart50Tokenizer` using a `spiece.model` SentencePiece model
|
||||
//! Pretrained models are available and can be downloaded using RemoteResources.
|
||||
//!
|
||||
//! ```no_run
|
||||
//! # fn main() -> anyhow::Result<()> {
|
||||
//! #
|
||||
//! use tch::{nn, Device};
|
||||
//! # use std::path::PathBuf;
|
||||
//! use rust_bert::resources::{LocalResource, Resource};
|
||||
//! use rust_bert::Config;
|
||||
//! use rust_tokenizers::tokenizer::MBart50Tokenizer;
|
||||
//! use rust_bert::mbart::{MBartConfig, MBartModel};
|
||||
//!
|
||||
//! let config_resource = Resource::Local(LocalResource {
|
||||
//! local_path: PathBuf::from("path/to/config.json"),
|
||||
//! });
|
||||
//! let vocab_resource = Resource::Local(LocalResource {
|
||||
//! local_path: PathBuf::from("path/to/vocab.txt"),
|
||||
//! });
|
||||
//! let weights_resource = Resource::Local(LocalResource {
|
||||
//! local_path: PathBuf::from("path/to/model.ot"),
|
||||
//! });
|
||||
//! let config_path = config_resource.get_local_path()?;
|
||||
//! let vocab_path = vocab_resource.get_local_path()?;
|
||||
//! let weights_path = weights_resource.get_local_path()?;
|
||||
//!
|
||||
//! let device = Device::cuda_if_available();
|
||||
//! let mut vs = nn::VarStore::new(device);
|
||||
//! let tokenizer: MBart50Tokenizer = MBart50Tokenizer::from_file(
|
||||
//! vocab_path.to_str().unwrap(),
|
||||
//! false,
|
||||
//! )?;
|
||||
//! let config = MBartConfig::from_file(config_path);
|
||||
//! let bart_model = MBartModel::new(&vs.root(), &config);
|
||||
//! vs.load(weights_path)?;
|
||||
//!
|
||||
//! # Ok(())
|
||||
//! # }
|
||||
//! ```
|
||||
|
||||
mod attention;
|
||||
mod decoder;
|
||||
mod embeddings;
|
||||
|
@ -9,7 +9,6 @@
|
||||
//!
|
||||
//! # Model set-up and pre-trained weights loading
|
||||
//!
|
||||
//! A full working example (generation) is provided in `examples/mobilebert_masked_lm`, run with `cargo run --example mobilebert_masked_lm`.
|
||||
//! All models expect the following resources:
|
||||
//! - Configuration file expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers)
|
||||
//! - Model weights are expected to have a structure and parameter names following the [Transformers library](https://github.com/huggingface/transformers). A conversion using the Python utility scripts is required to convert the `.bin` weights to the `.ot` format.
|
||||
|
@ -6,7 +6,6 @@
|
||||
//!
|
||||
//! # Model set-up and pre-trained weights loading
|
||||
//!
|
||||
//! A full working example is provided in `examples/openai_gpt`, run with `cargo run --example openai_gpt`.
|
||||
//! All models expect the following resources:
|
||||
//! - Configuration file expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers)
|
||||
//! - Model weights are expected to have a structure and parameter names following the [Transformers library](https://github.com/huggingface/transformers). A conversion using the Python utility scripts is required to convert the `.bin` weights to the `.ot` format.
|
||||
|
@ -576,7 +576,7 @@ impl PrivateLanguageGenerator<OpenAIGPTLMHeadModel, OpenAiGptVocab, OpenAiGptTok
|
||||
fn get_model(&self) -> &OpenAIGPTLMHeadModel {
|
||||
&self.model
|
||||
}
|
||||
fn get_tokenizer(&self) -> &TokenizerOption {
|
||||
fn _get_tokenizer(&self) -> &TokenizerOption {
|
||||
&self.tokenizer
|
||||
}
|
||||
fn get_var_store(&self) -> &nn::VarStore {
|
||||
|
@ -691,7 +691,7 @@ impl PrivateLanguageGenerator<PegasusForConditionalGeneration, PegasusVocab, Peg
|
||||
fn get_model(&self) -> &PegasusForConditionalGeneration {
|
||||
&self.model
|
||||
}
|
||||
fn get_tokenizer(&self) -> &TokenizerOption {
|
||||
fn _get_tokenizer(&self) -> &TokenizerOption {
|
||||
&self.tokenizer
|
||||
}
|
||||
fn get_var_store(&self) -> &nn::VarStore {
|
||||
@ -775,7 +775,7 @@ impl PrivateLanguageGenerator<PegasusForConditionalGeneration, PegasusVocab, Peg
|
||||
where
|
||||
S: AsRef<[&'a str]>,
|
||||
{
|
||||
let tokens = self.get_tokenizer().encode_list(
|
||||
let tokens = self._get_tokenizer().encode_list(
|
||||
prompt_text.as_ref(),
|
||||
max_len as usize,
|
||||
&TruncationStrategy::LongestFirst,
|
||||
@ -791,7 +791,7 @@ impl PrivateLanguageGenerator<PegasusForConditionalGeneration, PegasusVocab, Peg
|
||||
let pad_token = match pad_token_id {
|
||||
Some(value) => value,
|
||||
None => self
|
||||
.get_tokenizer()
|
||||
._get_tokenizer()
|
||||
.convert_tokens_to_ids(&[PegasusVocab::pad_value()])[0],
|
||||
};
|
||||
|
||||
|
@ -698,7 +698,7 @@ impl ConversationOption {
|
||||
|
||||
pub fn get_tokenizer(&self) -> &TokenizerOption {
|
||||
match self {
|
||||
Self::GPT2(model_ref) => model_ref.get_tokenizer(),
|
||||
Self::GPT2(model_ref) => model_ref._get_tokenizer(),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -36,6 +36,7 @@
|
||||
//! let min_length = Some(32);
|
||||
//! let max_length = Some(128);
|
||||
//! let decoder_start_id = None;
|
||||
//! let forced_bos_token_id = None;
|
||||
//!
|
||||
//! let input_context = "The dog";
|
||||
//! let second_input_context = "The cat was";
|
||||
@ -45,6 +46,7 @@
|
||||
//! min_length,
|
||||
//! max_length,
|
||||
//! decoder_start_id,
|
||||
//! forced_bos_token_id,
|
||||
//! None,
|
||||
//! );
|
||||
//! # Ok(())
|
||||
@ -86,6 +88,7 @@ use crate::t5::LayerState as T5LayerState;
|
||||
use crate::xlnet::LayerState as XLNetLayerState;
|
||||
|
||||
use self::ordered_float::OrderedFloat;
|
||||
use crate::pipelines::common::TokenizerOption;
|
||||
|
||||
extern crate ordered_float;
|
||||
|
||||
@ -272,7 +275,7 @@ pub(crate) mod private_generation_utils {
|
||||
|
||||
pub trait PrivateLanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>> {
|
||||
fn get_model(&self) -> &T;
|
||||
fn get_tokenizer(&self) -> &TokenizerOption;
|
||||
fn _get_tokenizer(&self) -> &TokenizerOption;
|
||||
fn get_var_store(&self) -> &nn::VarStore;
|
||||
fn get_config(&self) -> &GenerateConfig;
|
||||
fn get_bos_id(&self) -> &Option<i64>;
|
||||
@ -322,10 +325,10 @@ pub(crate) mod private_generation_utils {
|
||||
where
|
||||
S: AsRef<[&'a str]>,
|
||||
{
|
||||
let tokens = self.get_tokenizer().tokenize_list(prompt_text.as_ref());
|
||||
let tokens = self._get_tokenizer().tokenize_list(prompt_text.as_ref());
|
||||
let token_ids = tokens
|
||||
.into_iter()
|
||||
.map(|prompt_tokens| self.get_tokenizer().convert_tokens_to_ids(&prompt_tokens))
|
||||
.map(|prompt_tokens| self._get_tokenizer().convert_tokens_to_ids(&prompt_tokens))
|
||||
.collect::<Vec<Vec<i64>>>();
|
||||
|
||||
let num_truncated_tokens = token_ids
|
||||
@ -365,7 +368,7 @@ pub(crate) mod private_generation_utils {
|
||||
|
||||
let pad_token = match pad_token_id {
|
||||
Some(value) => value,
|
||||
None => self.get_tokenizer().get_unk_id(),
|
||||
None => self._get_tokenizer().get_unk_id(),
|
||||
};
|
||||
|
||||
let token_ids = token_ids
|
||||
@ -1219,7 +1222,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
|
||||
/// num_return_sequences: 3,
|
||||
/// ..Default::default()
|
||||
/// };
|
||||
/// let mut gpt2_generator = GPT2Generator::new(generate_config)?;
|
||||
/// let gpt2_generator = GPT2Generator::new(generate_config)?;
|
||||
/// let input_context = "The dog";
|
||||
/// let second_input_context = "The cat was";
|
||||
///
|
||||
@ -1227,6 +1230,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
|
||||
/// let min_length = 32;
|
||||
/// let max_length = 128;
|
||||
/// let decoder_start_token_id = None;
|
||||
/// let forced_bos_token_id = None;
|
||||
///
|
||||
/// //Example custom function for fine-grained generation control
|
||||
/// fn force_one_paragraph(_batch_id: i64, previous_token_ids: &Tensor) -> Vec<i64> {
|
||||
@ -1251,6 +1255,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
|
||||
/// min_length,
|
||||
/// max_length,
|
||||
/// decoder_start_token_id,
|
||||
/// forced_bos_token_id,
|
||||
/// Some(&force_one_paragraph)
|
||||
/// );
|
||||
/// # Ok(())
|
||||
@ -1293,7 +1298,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
|
||||
);
|
||||
let mut output = Vec::with_capacity(generated.len());
|
||||
for generated_sequence in generated {
|
||||
output.push(self.get_tokenizer().decode(generated_sequence, true, true));
|
||||
output.push(self._get_tokenizer().decode(generated_sequence, true, true));
|
||||
}
|
||||
output
|
||||
}
|
||||
@ -1337,13 +1342,14 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
|
||||
/// num_return_sequences: 3,
|
||||
/// ..Default::default()
|
||||
/// };
|
||||
/// let mut gpt2_generator = GPT2Generator::new(generate_config)?;
|
||||
/// let gpt2_generator = GPT2Generator::new(generate_config)?;
|
||||
/// let input_context = "The dog";
|
||||
/// let second_input_context = "The cat was";
|
||||
/// let attention_mask = None;
|
||||
/// let min_length = 32;
|
||||
/// let max_length = 128;
|
||||
/// let decoder_start_token_id = None;
|
||||
/// let forced_bos_token_id = None;
|
||||
///
|
||||
/// //Example custom function for fine-grained generation control
|
||||
/// fn force_one_paragraph(_batch_id: i64, previous_token_ids: &Tensor) -> Vec<i64> {
|
||||
@ -1368,6 +1374,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
|
||||
/// min_length,
|
||||
/// max_length,
|
||||
/// decoder_start_token_id,
|
||||
/// forced_bos_token_id,
|
||||
/// Some(&force_one_paragraph),
|
||||
/// );
|
||||
/// # Ok(())
|
||||
@ -1462,13 +1469,14 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
|
||||
/// num_return_sequences: 3,
|
||||
/// ..Default::default()
|
||||
/// };
|
||||
/// let mut gpt2_generator = GPT2Generator::new(generate_config)?;
|
||||
/// let gpt2_generator = GPT2Generator::new(generate_config)?;
|
||||
/// let input_context = "The dog";
|
||||
/// let second_input_context = "The cat was";
|
||||
/// let attention_mask = None;
|
||||
/// let min_length = 32;
|
||||
/// let max_length = 128;
|
||||
/// let decoder_start_token_id = None;
|
||||
/// let forced_bos_token_id = None;
|
||||
///
|
||||
/// //Example custom function for fine-grained generation control
|
||||
/// fn force_one_paragraph(_batch_id: i64, previous_token_ids: &Tensor) -> Vec<i64> {
|
||||
@ -1493,6 +1501,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
|
||||
/// min_length,
|
||||
/// max_length,
|
||||
/// decoder_start_token_id,
|
||||
/// forced_bos_token_id,
|
||||
/// Some(&force_one_paragraph),
|
||||
/// );
|
||||
/// # Ok(())
|
||||
@ -1674,6 +1683,46 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
|
||||
}
|
||||
output_ids
|
||||
}
|
||||
|
||||
/// Returns a reference to the text generator's tokenizer
|
||||
///
|
||||
/// # Returns
|
||||
/// * `&TokenizerOption` Reference to the generator's tokenizer.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```no_run
|
||||
/// # use std::path::PathBuf;
|
||||
/// # use tch::Device;
|
||||
/// # fn main() -> anyhow::Result<()> {
|
||||
/// use rust_bert::gpt2::GPT2Generator;
|
||||
/// use rust_bert::pipelines::generation_utils::{GenerateConfig, LanguageGenerator};
|
||||
/// use tch::Tensor;
|
||||
/// # let mut home: PathBuf = dirs::home_dir().unwrap();
|
||||
/// # home.push("rustbert");
|
||||
/// # home.push("gpt2");
|
||||
/// # let config_path = &home.as_path().join("config.json");
|
||||
/// # let vocab_path = &home.as_path().join("vocab.txt");
|
||||
/// # let merges_path = &home.as_path().join("merges.txt");
|
||||
/// # let weights_path = &home.as_path().join("model.ot");
|
||||
/// let device = Device::cuda_if_available();
|
||||
/// let generate_config = GenerateConfig {
|
||||
/// max_length: 30,
|
||||
/// do_sample: true,
|
||||
/// num_beams: 5,
|
||||
/// temperature: 1.1,
|
||||
/// num_return_sequences: 3,
|
||||
/// ..Default::default()
|
||||
/// };
|
||||
/// let gpt2_generator = GPT2Generator::new(generate_config)?;
|
||||
/// let tokenizer = gpt2_generator.get_tokenizer();
|
||||
/// tokenizer.tokenize("Hello, world!");
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
fn get_tokenizer(&self) -> &TokenizerOption {
|
||||
self._get_tokenizer()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
|
@ -229,11 +229,11 @@ impl TextGenerationOption {
|
||||
/// Interface method to access tokenizer
|
||||
pub fn get_tokenizer(&self) -> &TokenizerOption {
|
||||
match self {
|
||||
Self::GPT(model_ref) => model_ref.get_tokenizer(),
|
||||
Self::GPT2(model_ref) => model_ref.get_tokenizer(),
|
||||
Self::GPTNeo(model_ref) => model_ref.get_tokenizer(),
|
||||
Self::XLNet(model_ref) => model_ref.get_tokenizer(),
|
||||
Self::Reformer(model_ref) => model_ref.get_tokenizer(),
|
||||
Self::GPT(model_ref) => model_ref._get_tokenizer(),
|
||||
Self::GPT2(model_ref) => model_ref._get_tokenizer(),
|
||||
Self::GPTNeo(model_ref) => model_ref._get_tokenizer(),
|
||||
Self::XLNet(model_ref) => model_ref._get_tokenizer(),
|
||||
Self::Reformer(model_ref) => model_ref._get_tokenizer(),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -7,7 +7,7 @@
|
||||
//!
|
||||
//! # Model set-up and pre-trained weights loading
|
||||
//!
|
||||
//! A full working example (generation) is provided in `examples/summarization_prophetnet`, run with `cargo run --example summarization_prophetnet`.
|
||||
//! A full working example (summarization) is provided in `examples/summarization_prophetnet`, run with `cargo run --example summarization_prophetnet`.
|
||||
//! All models expect the following resources:
|
||||
//! - Configuration file expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers)
|
||||
//! - Model weights are expected to have a structure and parameter names following the [Transformers library](https://github.com/huggingface/transformers). A conversion using the Python utility scripts is required to convert the `.bin` weights to the `.ot` format.
|
||||
|
@ -993,7 +993,7 @@ impl
|
||||
fn get_model(&self) -> &ProphetNetForConditionalGeneration {
|
||||
&self.model
|
||||
}
|
||||
fn get_tokenizer(&self) -> &TokenizerOption {
|
||||
fn _get_tokenizer(&self) -> &TokenizerOption {
|
||||
&self.tokenizer
|
||||
}
|
||||
fn get_var_store(&self) -> &nn::VarStore {
|
||||
@ -1069,7 +1069,7 @@ impl
|
||||
where
|
||||
S: AsRef<[&'a str]>,
|
||||
{
|
||||
let tokens = self.get_tokenizer().encode_list(
|
||||
let tokens = self._get_tokenizer().encode_list(
|
||||
prompt_text.as_ref(),
|
||||
max_len as usize,
|
||||
&TruncationStrategy::LongestFirst,
|
||||
@ -1085,7 +1085,7 @@ impl
|
||||
let pad_token = match pad_token_id {
|
||||
Some(value) => value,
|
||||
None => self
|
||||
.get_tokenizer()
|
||||
._get_tokenizer()
|
||||
.convert_tokens_to_ids(&[ProphetNetVocab::unknown_value()])[0],
|
||||
};
|
||||
|
||||
|
@ -1114,7 +1114,7 @@ impl PrivateLanguageGenerator<ReformerModelWithLMHead, ReformerVocab, ReformerTo
|
||||
fn get_model(&self) -> &ReformerModelWithLMHead {
|
||||
&self.model
|
||||
}
|
||||
fn get_tokenizer(&self) -> &TokenizerOption {
|
||||
fn _get_tokenizer(&self) -> &TokenizerOption {
|
||||
&self.tokenizer
|
||||
}
|
||||
fn get_var_store(&self) -> &nn::VarStore {
|
||||
|
@ -10,7 +10,6 @@
|
||||
//!
|
||||
//! # Model set-up and pre-trained weights loading
|
||||
//!
|
||||
//! A full working example is provided in `examples/roberta.rs`, run with `cargo run --example roberta`.
|
||||
//! The example below illustrate a Masked language model example, the structure is similar for other models.
|
||||
//! All models expect the following resources:
|
||||
//! - Configuration file expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers)
|
||||
|
@ -6,7 +6,7 @@
|
||||
//!
|
||||
//! # Model set-up and pre-trained weights loading
|
||||
//!
|
||||
//! A full working example (translation) is provided in `examples/t5`, run with `cargo run --example t5`.
|
||||
//! A full working example (summarization) is provided in `examples/summarization_t5`, run with `cargo run --example summarization_t5`.
|
||||
//! All models expect the following resources:
|
||||
//! - Configuration file expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers)
|
||||
//! - Model weights are expected to have a structure and parameter names following the [Transformers library](https://github.com/huggingface/transformers). A conversion using the Python utility scripts is required to convert the `.bin` weights to the `.ot` format.
|
||||
|
@ -778,7 +778,7 @@ impl PrivateLanguageGenerator<T5ForConditionalGeneration, T5Vocab, T5Tokenizer>
|
||||
fn get_model(&self) -> &T5ForConditionalGeneration {
|
||||
&self.model
|
||||
}
|
||||
fn get_tokenizer(&self) -> &TokenizerOption {
|
||||
fn _get_tokenizer(&self) -> &TokenizerOption {
|
||||
&self.tokenizer
|
||||
}
|
||||
fn get_var_store(&self) -> &nn::VarStore {
|
||||
@ -850,7 +850,7 @@ impl PrivateLanguageGenerator<T5ForConditionalGeneration, T5Vocab, T5Tokenizer>
|
||||
where
|
||||
S: AsRef<[&'a str]>,
|
||||
{
|
||||
let tokens = self.get_tokenizer().encode_list(
|
||||
let tokens = self._get_tokenizer().encode_list(
|
||||
prompt_text.as_ref(),
|
||||
max_len as usize,
|
||||
&TruncationStrategy::LongestFirst,
|
||||
@ -865,7 +865,7 @@ impl PrivateLanguageGenerator<T5ForConditionalGeneration, T5Vocab, T5Tokenizer>
|
||||
|
||||
let pad_token = match pad_token_id {
|
||||
Some(value) => value,
|
||||
None => self.get_tokenizer().get_unk_id(),
|
||||
None => self._get_tokenizer().get_unk_id(),
|
||||
};
|
||||
|
||||
let token_ids = token_ids
|
||||
|
@ -1615,7 +1615,7 @@ impl PrivateLanguageGenerator<XLNetLMHeadModel, XLNetVocab, XLNetTokenizer> for
|
||||
fn get_model(&self) -> &XLNetLMHeadModel {
|
||||
&self.model
|
||||
}
|
||||
fn get_tokenizer(&self) -> &TokenizerOption {
|
||||
fn _get_tokenizer(&self) -> &TokenizerOption {
|
||||
&self.tokenizer
|
||||
}
|
||||
fn get_var_store(&self) -> &nn::VarStore {
|
||||
|
@ -77,7 +77,6 @@ fn bart_lm_model() -> anyhow::Result<()> {
|
||||
|
||||
#[test]
|
||||
fn bart_summarization_greedy() -> anyhow::Result<()> {
|
||||
// Set-up masked LM model
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
BartConfigResources::DISTILBART_CNN_6_6,
|
||||
));
|
||||
@ -139,7 +138,6 @@ about exoplanets like K2-18b."];
|
||||
|
||||
#[test]
|
||||
fn bart_summarization_beam_search() -> anyhow::Result<()> {
|
||||
// Set-up masked LM model
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
BartConfigResources::DISTILBART_CNN_6_6,
|
||||
));
|
||||
@ -202,7 +200,7 @@ about exoplanets like K2-18b."];
|
||||
#[test]
|
||||
#[cfg_attr(not(feature = "all-tests"), ignore)]
|
||||
fn bart_zero_shot_classification() -> anyhow::Result<()> {
|
||||
// Set-up model model
|
||||
// Set-up model
|
||||
let zero_shot_config = ZeroShotClassificationConfig {
|
||||
device: Device::Cpu,
|
||||
..Default::default()
|
||||
@ -235,7 +233,7 @@ fn bart_zero_shot_classification() -> anyhow::Result<()> {
|
||||
#[test]
|
||||
#[cfg_attr(not(feature = "all-tests"), ignore)]
|
||||
fn bart_zero_shot_classification_multilabel() -> anyhow::Result<()> {
|
||||
// Set-up model model
|
||||
// Set-up model
|
||||
let zero_shot_config = ZeroShotClassificationConfig {
|
||||
device: Device::Cpu,
|
||||
..Default::default()
|
||||
|
@ -28,7 +28,7 @@ fn distilgpt2_lm_model() -> anyhow::Result<()> {
|
||||
let merges_path = merges_resource.get_local_path()?;
|
||||
let weights_path = weights_resource.get_local_path()?;
|
||||
|
||||
// Set-up masked LM model
|
||||
// Set-up model
|
||||
let device = Device::Cpu;
|
||||
let mut vs = nn::VarStore::new(device);
|
||||
let tokenizer: Gpt2Tokenizer = Gpt2Tokenizer::from_file(
|
||||
|
@ -426,6 +426,7 @@ fn gpt2_prefix_allowed_token_greedy() -> anyhow::Result<()> {
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
Some(&force_one_paragraph),
|
||||
);
|
||||
|
||||
@ -490,6 +491,7 @@ fn gpt2_prefix_allowed_token_beam_search() -> anyhow::Result<()> {
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
Some(&force_one_paragraph),
|
||||
);
|
||||
|
||||
|
@ -29,7 +29,7 @@ fn gpt_neo_lm() -> anyhow::Result<()> {
|
||||
let merges_path = merges_resource.get_local_path()?;
|
||||
let weights_path = weights_resource.get_local_path()?;
|
||||
|
||||
// Set-up masked LM model
|
||||
// Set-up model
|
||||
let device = Device::Cpu;
|
||||
let mut vs = nn::VarStore::new(device);
|
||||
let tokenizer: Gpt2Tokenizer = Gpt2Tokenizer::from_file(
|
||||
@ -122,7 +122,7 @@ fn test_generation_gpt_neo() -> anyhow::Result<()> {
|
||||
GptNeoModelResources::GPT_NEO_125M,
|
||||
));
|
||||
|
||||
// Set-up translation model
|
||||
// Set-up model
|
||||
let generation_config = TextGenerationConfig {
|
||||
model_type: ModelType::GPTNeo,
|
||||
model_resource,
|
||||
|
110
tests/mbart.rs
Normal file
110
tests/mbart.rs
Normal file
@ -0,0 +1,110 @@
|
||||
use rust_bert::mbart::{
|
||||
MBartConfig, MBartConfigResources, MBartGenerator, MBartModel, MBartModelResources,
|
||||
MBartVocabResources,
|
||||
};
|
||||
use rust_bert::pipelines::generation_utils::{GenerateConfig, LanguageGenerator};
|
||||
use rust_bert::pipelines::summarization::{SummarizationConfig, SummarizationModel};
|
||||
use rust_bert::resources::{RemoteResource, Resource};
|
||||
use rust_bert::Config;
|
||||
use rust_tokenizers::tokenizer::{MBart50Tokenizer, Tokenizer, TruncationStrategy};
|
||||
use tch::{nn, Device, Tensor};
|
||||
|
||||
#[test]
|
||||
fn mbart_lm_model() -> anyhow::Result<()> {
|
||||
// Resources paths
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
MBartConfigResources::MBART50_MANY_TO_MANY,
|
||||
));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
MBartVocabResources::MBART50_MANY_TO_MANY,
|
||||
));
|
||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
MBartModelResources::MBART50_MANY_TO_MANY,
|
||||
));
|
||||
let config_path = config_resource.get_local_path()?;
|
||||
let vocab_path = vocab_resource.get_local_path()?;
|
||||
let weights_path = weights_resource.get_local_path()?;
|
||||
|
||||
// Set-up masked LM model
|
||||
let device = Device::Cpu;
|
||||
let mut vs = nn::VarStore::new(device);
|
||||
let tokenizer = MBart50Tokenizer::from_file(vocab_path.to_str().unwrap(), false)?;
|
||||
let config = MBartConfig::from_file(config_path);
|
||||
let mbart_model = MBartModel::new(&vs.root() / "model", &config);
|
||||
vs.load(weights_path)?;
|
||||
|
||||
// Define input
|
||||
let input = ["One two three four"];
|
||||
let tokenized_input = tokenizer.encode_list(&input, 128, &TruncationStrategy::LongestFirst, 0);
|
||||
let max_len = tokenized_input
|
||||
.iter()
|
||||
.map(|input| input.token_ids.len())
|
||||
.max()
|
||||
.unwrap();
|
||||
let tokenized_input = tokenized_input
|
||||
.iter()
|
||||
.map(|input| input.token_ids.clone())
|
||||
.map(|mut input| {
|
||||
input.extend(vec![0; max_len - input.len()]);
|
||||
input
|
||||
})
|
||||
.map(|input| Tensor::of_slice(&(input)))
|
||||
.collect::<Vec<_>>();
|
||||
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
|
||||
|
||||
// Forward pass
|
||||
let model_output =
|
||||
mbart_model.forward_t(Some(&input_tensor), None, None, None, None, None, false);
|
||||
assert_eq!(model_output.decoder_output.size(), vec!(1, 5, 1024));
|
||||
assert_eq!(
|
||||
model_output.encoder_hidden_state.unwrap().size(),
|
||||
vec!(1, 5, 1024)
|
||||
);
|
||||
assert!((model_output.decoder_output.double_value(&[0, 0, 0]) - -0.8936).abs() < 1e-4);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mbart_translation() -> anyhow::Result<()> {
|
||||
// Resources paths
|
||||
let generate_config = GenerateConfig {
|
||||
max_length: 56,
|
||||
model_resource: Resource::Remote(RemoteResource::from_pretrained(
|
||||
MBartModelResources::MBART50_MANY_TO_MANY,
|
||||
)),
|
||||
config_resource: Resource::Remote(RemoteResource::from_pretrained(
|
||||
MBartConfigResources::MBART50_MANY_TO_MANY,
|
||||
)),
|
||||
vocab_resource: Resource::Remote(RemoteResource::from_pretrained(
|
||||
MBartVocabResources::MBART50_MANY_TO_MANY,
|
||||
)),
|
||||
merges_resource: Resource::Remote(RemoteResource::from_pretrained(
|
||||
MBartVocabResources::MBART50_MANY_TO_MANY,
|
||||
)),
|
||||
do_sample: false,
|
||||
num_beams: 3,
|
||||
..Default::default()
|
||||
};
|
||||
let model = MBartGenerator::new(generate_config)?;
|
||||
|
||||
let input_context = "en_XX The quick brown fox jumps over the lazy dog.";
|
||||
let target_language = model.get_tokenizer().convert_tokens_to_ids(["de_DE"])[0];
|
||||
|
||||
let output = model.generate(
|
||||
Some(&[input_context]),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
target_language,
|
||||
None,
|
||||
);
|
||||
|
||||
assert_eq!(output.len(), 1);
|
||||
assert_eq!(
|
||||
output[0],
|
||||
"de_DE Der schnelle braune Fuchs springt über den faulen Hund."
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
@ -117,7 +117,7 @@ fn openai_gpt_generation_greedy() -> anyhow::Result<()> {
|
||||
OpenAiGptModelResources::GPT,
|
||||
));
|
||||
|
||||
// Set-up masked LM model
|
||||
// Set-up model
|
||||
let generate_config = TextGenerationConfig {
|
||||
model_type: ModelType::OpenAiGpt,
|
||||
model_resource,
|
||||
@ -159,7 +159,7 @@ fn openai_gpt_generation_beam_search() -> anyhow::Result<()> {
|
||||
OpenAiGptModelResources::GPT,
|
||||
));
|
||||
|
||||
// Set-up masked LM model
|
||||
// Set-up model
|
||||
let generate_config = TextGenerationConfig {
|
||||
model_type: ModelType::OpenAiGpt,
|
||||
model_resource,
|
||||
@ -211,7 +211,7 @@ fn openai_gpt_generation_beam_search_multiple_prompts_without_padding() -> anyho
|
||||
OpenAiGptModelResources::GPT,
|
||||
));
|
||||
|
||||
// Set-up masked LM model
|
||||
// Set-up model
|
||||
let generate_config = TextGenerationConfig {
|
||||
model_type: ModelType::OpenAiGpt,
|
||||
model_resource,
|
||||
@ -279,7 +279,7 @@ fn openai_gpt_generation_beam_search_multiple_prompts_with_padding() -> anyhow::
|
||||
OpenAiGptModelResources::GPT,
|
||||
));
|
||||
|
||||
// Set-up masked LM model
|
||||
// Set-up model
|
||||
let generate_config = TextGenerationConfig {
|
||||
model_type: ModelType::OpenAiGpt,
|
||||
model_resource,
|
||||
|
@ -7,7 +7,7 @@ use tch::Device;
|
||||
|
||||
#[test]
|
||||
fn pegasus_summarization_greedy() -> anyhow::Result<()> {
|
||||
// Set-up masked LM model
|
||||
// Set-up model
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
PegasusConfigResources::CNN_DAILYMAIL,
|
||||
));
|
||||
|
@ -9,7 +9,7 @@ use tch::Device;
|
||||
|
||||
#[test]
|
||||
fn prophetnet_summarization_greedy() -> anyhow::Result<()> {
|
||||
// Set-up masked LM model
|
||||
// Set-up model
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
ProphetNetConfigResources::PROPHETNET_LARGE_CNN_DM,
|
||||
));
|
||||
|
Loading…
Reference in New Issue
Block a user