Code formatted using rustfmt

Guillaume B 2020-06-23 16:54:46 +02:00
parent 0624a5368c
commit 47e36c4e8c
86 changed files with 9498 additions and 5324 deletions
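For orientation, the hunks below are all mechanical applications of default rustfmt rules: `use` lists are sorted and merged, expressions past the default 100-column limit are broken with one element per line plus a trailing comma, and macro delimiters such as `vec!(...)` are normalized to `vec![...]`. A minimal before/after sketch with hypothetical code (nothing here is taken from the commit):

```rust
// Hand-formatted input that rustfmt would rewrite:
//
//     let xs = vec!(1, 2, 3);
//     let label = describe(xs.iter().map(|x| x * 2).collect::<Vec<_>>().as_slice(), "doubled values", true);
//
// rustfmt output: macro parens normalized, the over-long call broken with
// one argument per line and a trailing comma.
fn describe(values: &[i32], name: &str, verbose: bool) -> String {
    if verbose {
        format!("{}: {:?}", name, values)
    } else {
        format!("{:?}", values)
    }
}

fn main() {
    let xs = vec![1, 2, 3];
    let label = describe(
        xs.iter().map(|x| x * 2).collect::<Vec<_>>().as_slice(),
        "doubled values",
        true,
    );
    println!("{}", label);
}
```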

View File

@ -13,18 +13,26 @@
extern crate failure;
use tch::{Device, nn, Tensor, no_grad};
use rust_tokenizers::{AlbertTokenizer, TruncationStrategy, Tokenizer, Vocab};
use rust_bert::albert::{
AlbertConfig, AlbertConfigResources, AlbertForMaskedLM, AlbertModelResources,
AlbertVocabResources,
};
use rust_bert::resources::{download_resource, RemoteResource, Resource};
use rust_bert::Config;
use rust_bert::resources::{Resource, download_resource, RemoteResource};
use rust_bert::albert::{AlbertConfig, AlbertForMaskedLM, AlbertConfigResources, AlbertVocabResources, AlbertModelResources};
use rust_tokenizers::{AlbertTokenizer, Tokenizer, TruncationStrategy, Vocab};
use tch::{nn, no_grad, Device, Tensor};
fn main() -> failure::Fallible<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(AlbertConfigResources::ALBERT_BASE_V2));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(AlbertVocabResources::ALBERT_BASE_V2));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(AlbertModelResources::ALBERT_BASE_V2));
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
AlbertConfigResources::ALBERT_BASE_V2,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
AlbertVocabResources::ALBERT_BASE_V2,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
AlbertModelResources::ALBERT_BASE_V2,
));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let weights_path = download_resource(&weights_resource)?;
@ -32,37 +40,38 @@ fn main() -> failure::Fallible<()> {
// Set-up masked LM model
let device = Device::Cpu;
let mut vs = nn::VarStore::new(device);
let tokenizer: AlbertTokenizer = AlbertTokenizer::from_file(vocab_path.to_str().unwrap(), true, false);
let tokenizer: AlbertTokenizer =
AlbertTokenizer::from_file(vocab_path.to_str().unwrap(), true, false);
let config = AlbertConfig::from_file(config_path);
let albert_model = AlbertForMaskedLM::new(&vs.root(), &config);
vs.load(weights_path)?;
// Define input
let input = ["Looks like one [MASK] is missing", "It was a very nice and [MASK] day"];
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
let tokenized_input = tokenized_input.
iter().
map(|input| input.token_ids.clone()).
map(|mut input| {
let input = [
"Looks like one [MASK] is missing",
"It was a very nice and [MASK] day",
];
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let tokenized_input = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
}).
map(|input|
Tensor::of_slice(&(input))).
collect::<Vec<_>>();
})
.map(|input| Tensor::of_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let (output, _, _) = no_grad(|| {
albert_model
.forward_t(Some(input_tensor),
None,
None,
None,
None,
false)
});
let (output, _, _) =
no_grad(|| albert_model.forward_t(Some(input_tensor), None, None, None, None, false));
println!("{:?}", output.double_value(&[0, 0, 0]));
// Print masked tokens
let index_1 = output.get(0).get(4).argmax(0, false);

View File

@ -12,19 +12,25 @@
extern crate failure;
use tch::{Device, nn, Tensor, no_grad};
use rust_tokenizers::{RobertaTokenizer, TruncationStrategy, Tokenizer};
use rust_bert::bart::{BartConfig, BartConfigResources, BartVocabResources, BartMergesResources, BartModelResources, BartModel};
use rust_bert::bart::{
BartConfig, BartConfigResources, BartMergesResources, BartModel, BartModelResources,
BartVocabResources,
};
use rust_bert::resources::{download_resource, RemoteResource, Resource};
use rust_bert::Config;
use rust_bert::resources::{Resource, download_resource, RemoteResource};
use rust_tokenizers::{RobertaTokenizer, Tokenizer, TruncationStrategy};
use tch::{nn, no_grad, Device, Tensor};
fn main() -> failure::Fallible<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(BartConfigResources::BART));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(BartVocabResources::BART));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(BartMergesResources::BART));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(BartModelResources::BART));
let config_resource =
Resource::Remote(RemoteResource::from_pretrained(BartConfigResources::BART));
let vocab_resource =
Resource::Remote(RemoteResource::from_pretrained(BartVocabResources::BART));
let merges_resource =
Resource::Remote(RemoteResource::from_pretrained(BartMergesResources::BART));
let weights_resource =
Resource::Remote(RemoteResource::from_pretrained(BartModelResources::BART));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let merges_path = download_resource(&merges_resource)?;
@ -33,7 +39,11 @@ fn main() -> failure::Fallible<()> {
// Set-up masked LM model
let device = Device::cuda_if_available();
let mut vs = nn::VarStore::new(device);
let tokenizer = RobertaTokenizer::from_file(vocab_path.to_str().unwrap(), merges_path.to_str().unwrap(), false);
let tokenizer = RobertaTokenizer::from_file(
vocab_path.to_str().unwrap(),
merges_path.to_str().unwrap(),
false,
);
let config = BartConfig::from_file(config_path);
let bart_model = BartModel::new(&vs.root(), &config, false);
vs.load(weights_path)?;
@ -63,31 +73,27 @@ about exoplanets like K2-18b."];
// Credits: WikiNews, CC BY 2.5 license (https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b)
let tokenized_input = tokenizer.encode_list(input.to_vec(), 1024, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
let tokenized_input = tokenized_input.
iter().
map(|input| input.token_ids.clone()).
map(|mut input| {
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 1024, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let tokenized_input = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
}).
map(|input|
Tensor::of_slice(&(input))).
collect::<Vec<_>>();
})
.map(|input| Tensor::of_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let (decoder_output, encoder_output, _, _, _, _, _) = no_grad(|| {
bart_model
.forward_t(Some(&input_tensor),
None,
None,
None,
None,
None,
false)
});
let (decoder_output, encoder_output, _, _, _, _, _) =
no_grad(|| bart_model.forward_t(Some(&input_tensor), None, None, None, None, None, false));
// Print masked tokens
println!("{:?}", encoder_output);

View File

@ -12,18 +12,22 @@
extern crate failure;
use tch::{Device, nn, Tensor, no_grad};
use rust_tokenizers::{BertTokenizer, TruncationStrategy, Tokenizer, Vocab};
use rust_bert::bert::{
BertConfig, BertConfigResources, BertForMaskedLM, BertModelResources, BertVocabResources,
};
use rust_bert::resources::{download_resource, RemoteResource, Resource};
use rust_bert::Config;
use rust_bert::bert::{BertConfig, BertForMaskedLM, BertConfigResources, BertVocabResources, BertModelResources};
use rust_bert::resources::{Resource, download_resource, RemoteResource};
use rust_tokenizers::{BertTokenizer, Tokenizer, TruncationStrategy, Vocab};
use tch::{nn, no_grad, Device, Tensor};
fn main() -> failure::Fallible<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT));
let config_resource =
Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT));
let vocab_resource =
Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
let weights_resource =
Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let weights_path = download_resource(&weights_resource)?;
@ -37,32 +41,40 @@ fn main() -> failure::Fallible<()> {
vs.load(weights_path)?;
// Define input
let input = ["Looks like one [MASK] is missing", "It was a very nice and [MASK] day"];
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
let tokenized_input = tokenized_input.
iter().
map(|input| input.token_ids.clone()).
map(|mut input| {
let input = [
"Looks like one [MASK] is missing",
"It was a very nice and [MASK] day",
];
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let tokenized_input = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
}).
map(|input|
Tensor::of_slice(&(input))).
collect::<Vec<_>>();
})
.map(|input| Tensor::of_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let (output, _, _) = no_grad(|| {
bert_model
.forward_t(Some(input_tensor),
bert_model.forward_t(
Some(input_tensor),
None,
None,
None,
None,
&None,
&None,
false)
false,
)
});
// Print masked tokens

View File

@ -11,20 +11,28 @@
// limitations under the License.
extern crate failure;
use tch::{Device, Tensor, nn, no_grad};
use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::{Tokenizer, TruncationStrategy};
use rust_tokenizers::bert_tokenizer::BertTokenizer;
use rust_tokenizers::preprocessing::vocab::base_vocab::Vocab;
use rust_bert::distilbert::{
DistilBertConfig, DistilBertConfigResources, DistilBertModelMaskedLM, DistilBertModelResources,
DistilBertVocabResources,
};
use rust_bert::resources::{download_resource, RemoteResource, Resource};
use rust_bert::Config;
use rust_bert::distilbert::{DistilBertConfig, DistilBertModelMaskedLM, DistilBertConfigResources, DistilBertVocabResources, DistilBertModelResources};
use rust_bert::resources::{Resource, download_resource, RemoteResource};
use rust_tokenizers::bert_tokenizer::BertTokenizer;
use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::{Tokenizer, TruncationStrategy};
use rust_tokenizers::preprocessing::vocab::base_vocab::Vocab;
use tch::{nn, no_grad, Device, Tensor};
fn main() -> failure::Fallible<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertConfigResources::DISTIL_BERT));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertVocabResources::DISTIL_BERT));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertModelResources::DISTIL_BERT));
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
DistilBertConfigResources::DISTIL_BERT,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
DistilBertVocabResources::DISTIL_BERT,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
DistilBertModelResources::DISTIL_BERT,
));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let weights_path = download_resource(&weights_resource)?;
@ -38,29 +46,35 @@ fn main() -> failure::Fallible<()> {
vs.load(weights_path)?;
// Define input
let input = ["Looks like one thing is missing", "It\'s like comparing oranges to apples"];
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
let mut tokenized_input = tokenized_input.
iter().
map(|input| input.token_ids.clone()).
map(|mut input| {
let input = [
"Looks like one thing is missing",
"It\'s like comparing oranges to apples",
];
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let mut tokenized_input = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
}).
collect::<Vec<_>>();
})
.collect::<Vec<_>>();
// Masking the token [thing] of sentence 1 and [oranges] of sentence 2
tokenized_input[0][4] = 103;
tokenized_input[1][6] = 103;
let tokenized_input = tokenized_input.
iter().
map(|input|
Tensor::of_slice(&(input))).
collect::<Vec<_>>();
let tokenized_input = tokenized_input
.iter()
.map(|input| Tensor::of_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let (output, _, _) = no_grad(|| {
distil_bert_model

View File

@ -1,26 +1,44 @@
extern crate failure;
use rust_bert::gpt2::{Gpt2ConfigResources, Gpt2VocabResources, Gpt2MergesResources, Gpt2ModelResources};
use rust_bert::distilbert::{DistilBertModelResources, DistilBertConfigResources, DistilBertVocabResources};
use rust_bert::openai_gpt::{OpenAiGptConfigResources, OpenAiGptVocabResources, OpenAiGptMergesResources, OpenAiGptModelResources};
use rust_bert::roberta::{RobertaConfigResources, RobertaVocabResources, RobertaMergesResources, RobertaModelResources};
use rust_bert::bert::{BertConfigResources, BertVocabResources, BertModelResources};
use rust_bert::bart::{BartConfigResources, BartVocabResources, BartMergesResources, BartModelResources};
use rust_bert::resources::{Resource, download_resource, RemoteResource};
use rust_bert::electra::{ElectraConfigResources, ElectraVocabResources, ElectraModelResources};
use rust_bert::albert::{AlbertConfigResources, AlbertVocabResources, AlbertModelResources};
use rust_bert::albert::{AlbertConfigResources, AlbertModelResources, AlbertVocabResources};
use rust_bert::bart::{
BartConfigResources, BartMergesResources, BartModelResources, BartVocabResources,
};
use rust_bert::bert::{BertConfigResources, BertModelResources, BertVocabResources};
use rust_bert::distilbert::{
DistilBertConfigResources, DistilBertModelResources, DistilBertVocabResources,
};
use rust_bert::electra::{ElectraConfigResources, ElectraModelResources, ElectraVocabResources};
use rust_bert::gpt2::{
Gpt2ConfigResources, Gpt2MergesResources, Gpt2ModelResources, Gpt2VocabResources,
};
use rust_bert::openai_gpt::{
OpenAiGptConfigResources, OpenAiGptMergesResources, OpenAiGptModelResources,
OpenAiGptVocabResources,
};
use rust_bert::resources::{download_resource, RemoteResource, Resource};
use rust_bert::roberta::{
RobertaConfigResources, RobertaMergesResources, RobertaModelResources, RobertaVocabResources,
};
/// This example downloads and caches all dependencies used in model tests. This allows for safe
/// multi-threaded testing (two tests using the same resource would otherwise download the file to
/// the same location).
fn download_distil_gpt2() -> failure::Fallible<()> {
// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
let config_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::DISTIL_GPT2));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::DISTIL_GPT2));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::DISTIL_GPT2));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::DISTIL_GPT2));
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
Gpt2ConfigResources::DISTIL_GPT2,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
Gpt2VocabResources::DISTIL_GPT2,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
Gpt2MergesResources::DISTIL_GPT2,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
Gpt2ModelResources::DISTIL_GPT2,
));
let _ = download_resource(&config_resource)?;
let _ = download_resource(&vocab_resource)?;
let _ = download_resource(&merges_resource)?;
@ -30,9 +48,15 @@ fn download_distil_gpt2() -> failure::Fallible<()> {
fn download_distilbert_sst2() -> failure::Fallible<()> {
// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertModelResources::DISTIL_BERT_SST2));
let config_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertConfigResources::DISTIL_BERT_SST2));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertVocabResources::DISTIL_BERT_SST2));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
DistilBertModelResources::DISTIL_BERT_SST2,
));
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
DistilBertConfigResources::DISTIL_BERT_SST2,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
DistilBertVocabResources::DISTIL_BERT_SST2,
));
let _ = download_resource(&config_resource)?;
let _ = download_resource(&vocab_resource)?;
let _ = download_resource(&weights_resource)?;
@ -41,9 +65,15 @@ fn download_distilbert_sst2() -> failure::Fallible<()> {
fn download_distilbert_qa() -> failure::Fallible<()> {
// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertModelResources::DISTIL_BERT_SQUAD));
let config_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertConfigResources::DISTIL_BERT_SQUAD));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertVocabResources::DISTIL_BERT_SQUAD));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
DistilBertModelResources::DISTIL_BERT_SQUAD,
));
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
DistilBertConfigResources::DISTIL_BERT_SQUAD,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
DistilBertVocabResources::DISTIL_BERT_SQUAD,
));
let _ = download_resource(&config_resource)?;
let _ = download_resource(&vocab_resource)?;
let _ = download_resource(&weights_resource)?;
@ -52,9 +82,15 @@ fn download_distilbert_qa() -> failure::Fallible<()> {
fn download_distilbert() -> failure::Fallible<()> {
// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertModelResources::DISTIL_BERT));
let config_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertConfigResources::DISTIL_BERT));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertVocabResources::DISTIL_BERT));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
DistilBertModelResources::DISTIL_BERT,
));
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
DistilBertConfigResources::DISTIL_BERT,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
DistilBertVocabResources::DISTIL_BERT,
));
let _ = download_resource(&config_resource)?;
let _ = download_resource(&vocab_resource)?;
let _ = download_resource(&weights_resource)?;
@ -63,10 +99,14 @@ fn download_distilbert() -> failure::Fallible<()> {
fn download_gpt2() -> failure::Fallible<()> {
// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2. Modified with conversion to C-array format.
let config_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
let config_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
let vocab_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
let merges_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
let weights_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
let _ = download_resource(&config_resource)?;
let _ = download_resource(&vocab_resource)?;
let _ = download_resource(&merges_resource)?;
@ -76,10 +116,18 @@ fn download_gpt2() -> failure::Fallible<()> {
fn download_gpt() -> failure::Fallible<()> {
// Shared under MIT license by the OpenAI team at https://github.com/openai/finetune-transformer-lm. Modified with conversion to C-array format.
let config_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptConfigResources::GPT));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptVocabResources::GPT));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptMergesResources::GPT));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptModelResources::GPT));
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
OpenAiGptConfigResources::GPT,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
OpenAiGptVocabResources::GPT,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
OpenAiGptMergesResources::GPT,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
OpenAiGptModelResources::GPT,
));
let _ = download_resource(&config_resource)?;
let _ = download_resource(&vocab_resource)?;
let _ = download_resource(&merges_resource)?;
@ -89,10 +137,18 @@ fn download_gpt() -> failure::Fallible<()> {
fn download_roberta() -> failure::Fallible<()> {
// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
let config_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaConfigResources::ROBERTA));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaVocabResources::ROBERTA));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaMergesResources::ROBERTA));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaModelResources::ROBERTA));
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaConfigResources::ROBERTA,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaVocabResources::ROBERTA,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaMergesResources::ROBERTA,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaModelResources::ROBERTA,
));
let _ = download_resource(&config_resource)?;
let _ = download_resource(&vocab_resource)?;
let _ = download_resource(&merges_resource)?;
@ -102,9 +158,12 @@ fn download_roberta() -> failure::Fallible<()> {
fn download_bert() -> failure::Fallible<()> {
// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/bert. Modified with conversion to C-array format.
let config_resource = Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT));
let config_resource =
Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT));
let vocab_resource =
Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
let weights_resource =
Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT));
let _ = download_resource(&config_resource)?;
let _ = download_resource(&vocab_resource)?;
let _ = download_resource(&weights_resource)?;
@ -113,9 +172,15 @@ fn download_bert() -> failure::Fallible<()> {
fn download_bert_ner() -> failure::Fallible<()> {
// Shared under MIT license by the MDZ Digital Library team at the Bavarian State Library at https://github.com/dbmdz/berts. Modified with conversion to C-array format.
let config_resource = Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT_NER));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT_NER));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT_NER));
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
BertConfigResources::BERT_NER,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
BertVocabResources::BERT_NER,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
BertModelResources::BERT_NER,
));
let _ = download_resource(&config_resource)?;
let _ = download_resource(&vocab_resource)?;
let _ = download_resource(&weights_resource)?;
@ -124,10 +189,14 @@ fn download_bert_ner() -> failure::Fallible<()> {
fn download_bart() -> failure::Fallible<()> {
// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
let config_resource = Resource::Remote(RemoteResource::from_pretrained(BartConfigResources::BART));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(BartVocabResources::BART));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(BartMergesResources::BART));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(BartModelResources::BART));
let config_resource =
Resource::Remote(RemoteResource::from_pretrained(BartConfigResources::BART));
let vocab_resource =
Resource::Remote(RemoteResource::from_pretrained(BartVocabResources::BART));
let merges_resource =
Resource::Remote(RemoteResource::from_pretrained(BartMergesResources::BART));
let weights_resource =
Resource::Remote(RemoteResource::from_pretrained(BartModelResources::BART));
let _ = download_resource(&config_resource)?;
let _ = download_resource(&vocab_resource)?;
let _ = download_resource(&merges_resource)?;
@ -137,10 +206,18 @@ fn download_bart() -> failure::Fallible<()> {
fn download_bart_cnn() -> failure::Fallible<()> {
// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
let config_resource = Resource::Remote(RemoteResource::from_pretrained(BartConfigResources::BART_CNN));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(BartVocabResources::BART_CNN));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(BartMergesResources::BART_CNN));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(BartModelResources::BART_CNN));
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
BartConfigResources::BART_CNN,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
BartVocabResources::BART_CNN,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
BartMergesResources::BART_CNN,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
BartModelResources::BART_CNN,
));
let _ = download_resource(&config_resource)?;
let _ = download_resource(&vocab_resource)?;
let _ = download_resource(&merges_resource)?;
@ -150,9 +227,15 @@ fn download_bart_cnn() -> failure::Fallible<()> {
fn download_electra_generator() -> failure::Fallible<()> {
// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/electra. Modified with conversion to C-array format.
let config_resource = Resource::Remote(RemoteResource::from_pretrained(ElectraConfigResources::BASE_GENERATOR));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(ElectraVocabResources::BASE_GENERATOR));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(ElectraModelResources::BASE_GENERATOR));
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
ElectraConfigResources::BASE_GENERATOR,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
ElectraVocabResources::BASE_GENERATOR,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
ElectraModelResources::BASE_GENERATOR,
));
let _ = download_resource(&config_resource)?;
let _ = download_resource(&vocab_resource)?;
let _ = download_resource(&weights_resource)?;
@ -161,9 +244,15 @@ fn download_electra_generator() -> failure::Fallible<()> {
fn download_electra_discriminator() -> failure::Fallible<()> {
// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/electra. Modified with conversion to C-array format.
let config_resource = Resource::Remote(RemoteResource::from_pretrained(ElectraConfigResources::BASE_DISCRIMINATOR));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(ElectraVocabResources::BASE_DISCRIMINATOR));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(ElectraModelResources::BASE_DISCRIMINATOR));
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
ElectraConfigResources::BASE_DISCRIMINATOR,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
ElectraVocabResources::BASE_DISCRIMINATOR,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
ElectraModelResources::BASE_DISCRIMINATOR,
));
let _ = download_resource(&config_resource)?;
let _ = download_resource(&vocab_resource)?;
let _ = download_resource(&weights_resource)?;
@ -172,9 +261,15 @@ fn download_electra_discriminator() -> failure::Fallible<()> {
fn download_albert_base_v2() -> failure::Fallible<()> {
// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/ALBERT. Modified with conversion to C-array format.
let config_resource = Resource::Remote(RemoteResource::from_pretrained(AlbertConfigResources::ALBERT_BASE_V2));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(AlbertVocabResources::ALBERT_BASE_V2));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(AlbertModelResources::ALBERT_BASE_V2));
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
AlbertConfigResources::ALBERT_BASE_V2,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
AlbertVocabResources::ALBERT_BASE_V2,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
AlbertModelResources::ALBERT_BASE_V2,
));
let _ = download_resource(&config_resource)?;
let _ = download_resource(&vocab_resource)?;
let _ = download_resource(&weights_resource)?;
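The doc comment at the top of this file states the intent: fetch every resource once so that concurrently running tests hit the local cache instead of racing on the same download path. All of the functions above repeat one pattern, condensed here as a sketch; the helper name `cache_resource` is hypothetical, while the `Resource` API is the one used throughout this file:

```rust
use rust_bert::resources::{download_resource, RemoteResource, Resource};

// Each resource constant in this file is a (cache-relative name, URL) pair.
// download_resource fetches the file on first use and returns its local
// path; subsequent calls for the same resource reuse the cached copy.
fn cache_resource(name_and_url: (&'static str, &'static str)) -> failure::Fallible<()> {
    let resource = Resource::Remote(RemoteResource::from_pretrained(name_and_url));
    let _local_path = download_resource(&resource)?;
    Ok(())
}
```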

View File

@ -12,18 +12,26 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use rust_bert::resources::{Resource, download_resource, RemoteResource};
use rust_bert::electra::{ElectraConfig, ElectraDiscriminator, ElectraConfigResources, ElectraVocabResources, ElectraModelResources};
use rust_bert::electra::{
ElectraConfig, ElectraConfigResources, ElectraDiscriminator, ElectraModelResources,
ElectraVocabResources,
};
use rust_bert::resources::{download_resource, RemoteResource, Resource};
use rust_bert::Config;
use rust_tokenizers::{BertTokenizer, Tokenizer, TruncationStrategy};
use tch::{Tensor, Device, nn, no_grad};
use tch::{nn, no_grad, Device, Tensor};
fn main() -> failure::Fallible<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(ElectraConfigResources::BASE_DISCRIMINATOR));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(ElectraVocabResources::BASE_DISCRIMINATOR));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(ElectraModelResources::BASE_DISCRIMINATOR));
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
ElectraConfigResources::BASE_DISCRIMINATOR,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
ElectraVocabResources::BASE_DISCRIMINATOR,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
ElectraModelResources::BASE_DISCRIMINATOR,
));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let weights_path = download_resource(&weights_resource)?;
@ -38,43 +46,43 @@ fn main() -> failure::Fallible<()> {
// Define input
let input = ["One Two Three Ten Five Six Seven Eight"];
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
let encoded_input = tokenized_input.
iter().
map(|input| input.token_ids.clone()).
map(|mut input| {
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let encoded_input = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
}).
map(|input|
Tensor::of_slice(&(input))).
collect::<Vec<_>>();
})
.map(|input| Tensor::of_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(encoded_input.as_slice(), 0).to(device);
// Forward pass
let (output, _, _) = no_grad(|| {
electra_model
.forward_t(Some(input_tensor),
None,
None,
None,
None,
false)
});
let (output, _, _) =
no_grad(|| electra_model.forward_t(Some(input_tensor), None, None, None, None, false));
// Print model predictions
for (position, token) in tokenized_input[0].token_ids.iter().enumerate() {
let probability = output.double_value(&[position as i64]);
let generated = if probability > 0.5 { "generated" } else { "original" };
println!("{:?}: {} ({:.1}%)",
tokenizer.decode([*token].to_vec(),
false,
false),
let generated = if probability > 0.5 {
"generated"
} else {
"original"
};
println!(
"{:?}: {} ({:.1}%)",
tokenizer.decode([*token].to_vec(), false, false),
generated,
100f64 * probability)
100f64 * probability
)
}
Ok(())
}

View File

@ -12,18 +12,26 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use rust_bert::resources::{Resource, download_resource, RemoteResource};
use rust_bert::electra::{ElectraConfig, ElectraForMaskedLM, ElectraModelResources, ElectraConfigResources, ElectraVocabResources};
use rust_bert::electra::{
ElectraConfig, ElectraConfigResources, ElectraForMaskedLM, ElectraModelResources,
ElectraVocabResources,
};
use rust_bert::resources::{download_resource, RemoteResource, Resource};
use rust_bert::Config;
use rust_tokenizers::{BertTokenizer, Tokenizer, TruncationStrategy, Vocab};
use tch::{Tensor, Device, nn, no_grad};
use tch::{nn, no_grad, Device, Tensor};
fn main() -> failure::Fallible<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(ElectraConfigResources::BASE_GENERATOR));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(ElectraVocabResources::BASE_GENERATOR));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(ElectraModelResources::BASE_GENERATOR));
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
ElectraConfigResources::BASE_GENERATOR,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
ElectraVocabResources::BASE_GENERATOR,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
ElectraModelResources::BASE_GENERATOR,
));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let weights_path = download_resource(&weights_resource)?;
@ -37,31 +45,31 @@ fn main() -> failure::Fallible<()> {
vs.load(weights_path)?;
// Define input
let input = ["Looks like one [MASK] is missing", "It was a very nice and [MASK] day"];
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
let tokenized_input = tokenized_input.
iter().
map(|input| input.token_ids.clone()).
map(|mut input| {
let input = [
"Looks like one [MASK] is missing",
"It was a very nice and [MASK] day",
];
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let tokenized_input = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
}).
map(|input|
Tensor::of_slice(&(input))).
collect::<Vec<_>>();
})
.map(|input| Tensor::of_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let (output, _, _) = no_grad(|| {
electra_model
.forward_t(Some(input_tensor),
None,
None,
None,
None,
false)
});
let (output, _, _) =
no_grad(|| electra_model.forward_t(Some(input_tensor), None, None, None, None, false));
// Print masked tokens
let index_1 = output.get(0).get(4).argmax(0, false);

View File

@ -12,11 +12,9 @@
extern crate failure;
use rust_bert::pipelines::generation::{GPT2Generator, LanguageGenerator, GenerateConfig};
use rust_bert::pipelines::generation::{GPT2Generator, GenerateConfig, LanguageGenerator};
fn main() -> failure::Fallible<()> {
// Set-up masked LM model
let generate_config = GenerateConfig {
max_length: 30,
@ -30,7 +28,7 @@ fn main() -> failure::Fallible<()> {
let input_context = "The dog";
let second_input_context = "The cat was";
let output = model.generate(Some(vec!(input_context, second_input_context)), None);
let output = model.generate(Some(vec![input_context, second_input_context]), None);
for sentence in output {
println!("{:?}", sentence);

View File

@ -12,20 +12,26 @@
extern crate failure;
use tch::{Device, nn, Tensor};
use rust_tokenizers::{TruncationStrategy, Tokenizer, Gpt2Tokenizer};
use rust_bert::gpt2::{Gpt2Config, GPT2LMHeadModel, Gpt2ConfigResources, Gpt2VocabResources, Gpt2MergesResources, Gpt2ModelResources};
use rust_bert::pipelines::generation::{LMHeadModel, Cache};
use rust_bert::resources::{Resource, download_resource, RemoteResource};
use rust_bert::gpt2::{
GPT2LMHeadModel, Gpt2Config, Gpt2ConfigResources, Gpt2MergesResources, Gpt2ModelResources,
Gpt2VocabResources,
};
use rust_bert::pipelines::generation::{Cache, LMHeadModel};
use rust_bert::resources::{download_resource, RemoteResource, Resource};
use rust_bert::Config;
use rust_tokenizers::{Gpt2Tokenizer, Tokenizer, TruncationStrategy};
use tch::{nn, Device, Tensor};
fn main() -> failure::Fallible<()> {
// Resources set-up
let config_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
let config_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
let vocab_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
let merges_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
let weights_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let merges_path = download_resource(&merges_resource)?;
@ -34,29 +40,38 @@ fn main() -> failure::Fallible<()> {
// Set-up masked LM model
let device = Device::Cpu;
let mut vs = nn::VarStore::new(device);
let tokenizer: Gpt2Tokenizer = Gpt2Tokenizer::from_file(vocab_path.to_str().unwrap(), merges_path.to_str().unwrap(), false);
let tokenizer: Gpt2Tokenizer = Gpt2Tokenizer::from_file(
vocab_path.to_str().unwrap(),
merges_path.to_str().unwrap(),
false,
);
let config = Gpt2Config::from_file(config_path);
let gpt2_model = GPT2LMHeadModel::new(&vs.root(), &config);
vs.load(weights_path)?;
// Define input
let input = ["One two three four five six seven eight nine ten eleven"];
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
let tokenized_input = tokenized_input.
iter().
map(|input| input.token_ids.clone()).
map(|mut input| {
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let tokenized_input = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
}).
map(|input|
Tensor::of_slice(&(input))).
collect::<Vec<_>>();
})
.map(|input| Tensor::of_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let (output, _, _, _, _) = gpt2_model.forward_t(
let (output, _, _, _, _) = gpt2_model
.forward_t(
&Some(input_tensor),
Cache::None,
&None,
@ -65,10 +80,12 @@ fn main() -> failure::Fallible<()> {
&None,
None,
&None,
false).unwrap();
false,
)
.unwrap();
let next_word_id = output.get(0).get(-1).argmax(-1, true).int64_value(&[0]);
let next_word = tokenizer.decode(vec!(next_word_id), true, true);
let next_word = tokenizer.decode(vec![next_word_id], true, true);
println!("Provided input: {}", input[0]);
println!("Next word: {}", next_word);

View File

@ -21,7 +21,7 @@ fn main() -> failure::Fallible<()> {
// Define input
let input = [
"My name is Amélie. I live in Москва.",
"Chongqing is a city in China."
"Chongqing is a city in China.",
];
// Run model

View File

@ -12,21 +12,31 @@
extern crate failure;
use tch::{Device, nn, Tensor};
use rust_tokenizers::{TruncationStrategy, Tokenizer, OpenAiGptTokenizer};
use rust_bert::gpt2::Gpt2Config;
use rust_bert::openai_gpt::{OpenAIGPTLMHeadModel, OpenAiGptConfigResources, OpenAiGptVocabResources, OpenAiGptMergesResources, OpenAiGptModelResources};
use rust_bert::pipelines::generation::{LMHeadModel, Cache};
use rust_bert::resources::{Resource, download_resource, RemoteResource};
use rust_bert::openai_gpt::{
OpenAIGPTLMHeadModel, OpenAiGptConfigResources, OpenAiGptMergesResources,
OpenAiGptModelResources, OpenAiGptVocabResources,
};
use rust_bert::pipelines::generation::{Cache, LMHeadModel};
use rust_bert::resources::{download_resource, RemoteResource, Resource};
use rust_bert::Config;
use rust_tokenizers::{OpenAiGptTokenizer, Tokenizer, TruncationStrategy};
use tch::{nn, Device, Tensor};
fn main() -> failure::Fallible<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptConfigResources::GPT));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptVocabResources::GPT));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptMergesResources::GPT));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptModelResources::GPT));
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
OpenAiGptConfigResources::GPT,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
OpenAiGptVocabResources::GPT,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
OpenAiGptMergesResources::GPT,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
OpenAiGptModelResources::GPT,
));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let merges_path = download_resource(&merges_resource)?;
@ -35,29 +45,38 @@ fn main() -> failure::Fallible<()> {
// Set-up masked LM model
let device = Device::Cpu;
let mut vs = nn::VarStore::new(device);
let tokenizer = OpenAiGptTokenizer::from_file(vocab_path.to_str().unwrap(), merges_path.to_str().unwrap(), true);
let tokenizer = OpenAiGptTokenizer::from_file(
vocab_path.to_str().unwrap(),
merges_path.to_str().unwrap(),
true,
);
let config = Gpt2Config::from_file(config_path);
let openai_gpt = OpenAIGPTLMHeadModel::new(&vs.root(), &config);
vs.load(weights_path)?;
// Define input
let input = ["Wondering what the next word will"];
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
let tokenized_input = tokenized_input.
iter().
map(|input| input.token_ids.clone()).
map(|mut input| {
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let tokenized_input = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
}).
map(|input|
Tensor::of_slice(&(input))).
collect::<Vec<_>>();
})
.map(|input| Tensor::of_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let (output, _, _, _, _) = openai_gpt.forward_t(
let (output, _, _, _, _) = openai_gpt
.forward_t(
&Some(input_tensor),
Cache::None,
&None,
@ -66,10 +85,12 @@ fn main() -> failure::Fallible<()> {
&None,
None,
&None,
false).unwrap();
false,
)
.unwrap();
let next_word_id = output.get(0).get(-1).argmax(-1, true).int64_value(&[0]);
let next_word = tokenizer.decode(vec!(next_word_id), true, true);
let next_word = tokenizer.decode(vec![next_word_id], true, true);
println!("Provided input: {}", input[0]);
println!("Next word: {}", next_word);

View File

@ -12,8 +12,7 @@
extern crate failure;
use rust_bert::pipelines::question_answering::{QuestionAnsweringModel, QaInput};
use rust_bert::pipelines::question_answering::{QaInput, QuestionAnsweringModel};
fn main() -> failure::Fallible<()> {
// Set-up Question Answering model
@ -24,11 +23,17 @@ fn main() -> failure::Fallible<()> {
let context_1 = String::from("Amy lives in Amsterdam");
let question_2 = String::from("Where does Eric live");
let context_2 = String::from("While Amy lives in Amsterdam, Eric is in The Hague.");
let qa_input_1 = QaInput { question: question_1, context: context_1 };
let qa_input_2 = QaInput { question: question_2, context: context_2 };
let qa_input_1 = QaInput {
question: question_1,
context: context_1,
};
let qa_input_2 = QaInput {
question: question_2,
context: context_2,
};
// Get answer
let answers = qa_model.predict(&vec!(qa_input_1, qa_input_2), 1, 32);
let answers = qa_model.predict(&vec![qa_input_1, qa_input_2], 1, 32);
println!("{:?}", answers);
Ok(())
}

View File

@ -12,20 +12,30 @@
extern crate failure;
use tch::{Device, nn, Tensor, no_grad};
use rust_tokenizers::{TruncationStrategy, Tokenizer, Vocab, RobertaTokenizer};
use rust_bert::Config;
use rust_bert::bert::BertConfig;
use rust_bert::roberta::{RobertaForMaskedLM, RobertaVocabResources, RobertaConfigResources, RobertaMergesResources, RobertaModelResources};
use rust_bert::resources::{Resource, download_resource, RemoteResource};
use rust_bert::resources::{download_resource, RemoteResource, Resource};
use rust_bert::roberta::{
RobertaConfigResources, RobertaForMaskedLM, RobertaMergesResources, RobertaModelResources,
RobertaVocabResources,
};
use rust_bert::Config;
use rust_tokenizers::{RobertaTokenizer, Tokenizer, TruncationStrategy, Vocab};
use tch::{nn, no_grad, Device, Tensor};
fn main() -> failure::Fallible<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaConfigResources::ROBERTA));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaVocabResources::ROBERTA));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaMergesResources::ROBERTA));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaModelResources::ROBERTA));
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaConfigResources::ROBERTA,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaVocabResources::ROBERTA,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaMergesResources::ROBERTA,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaModelResources::ROBERTA,
));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let merges_path = download_resource(&merges_resource)?;
@ -34,45 +44,57 @@ fn main() -> failure::Fallible<()> {
// Set-up masked LM model
let device = Device::Cpu;
let mut vs = nn::VarStore::new(device);
let tokenizer: RobertaTokenizer = RobertaTokenizer::from_file(vocab_path.to_str().unwrap(), merges_path.to_str().unwrap(), true);
let tokenizer: RobertaTokenizer = RobertaTokenizer::from_file(
vocab_path.to_str().unwrap(),
merges_path.to_str().unwrap(),
true,
);
let config = BertConfig::from_file(config_path);
let bert_model = RobertaForMaskedLM::new(&vs.root(), &config);
vs.load(weights_path)?;
// Define input
let input = ["<pad> Looks like one thing is missing", "It\'s like comparing oranges to apples"];
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
let mut tokenized_input = tokenized_input.
iter().
map(|input| input.token_ids.clone()).
map(|mut input| {
let input = [
"<pad> Looks like one thing is missing",
"It\'s like comparing oranges to apples",
];
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let mut tokenized_input = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
}).
collect::<Vec<_>>();
})
.collect::<Vec<_>>();
// Masking the token [thing] of sentence 1 and [oranges] of sentence 2
tokenized_input[0][4] = 103;
tokenized_input[1][5] = 103;
let tokenized_input = tokenized_input.
iter().
map(|input|
Tensor::of_slice(&(input))).
collect::<Vec<_>>();
let tokenized_input = tokenized_input
.iter()
.map(|input| Tensor::of_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let (output, _, _) = no_grad(|| {
bert_model
.forward_t(Some(input_tensor),
bert_model.forward_t(
Some(input_tensor),
None,
None,
None,
None,
&None,
&None,
false)
false,
)
});
// Print masked tokens

View File

@ -14,7 +14,6 @@ extern crate failure;
use rust_bert::pipelines::sentiment::SentimentModel;
fn main() -> failure::Fallible<()> {
// Set-up classifier
let sentiment_classifier = SentimentModel::new(Default::default())?;

View File

@ -12,10 +12,9 @@
extern crate failure;
use std::path::PathBuf;
use rust_bert::pipelines::question_answering::{squad_processor, QuestionAnsweringModel};
use std::env;
use rust_bert::pipelines::question_answering::{QuestionAnsweringModel, squad_processor};
use std::path::PathBuf;
fn main() -> failure::Fallible<()> {
// Set-up Question Answering model

View File

@ -10,13 +10,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.
extern crate failure;
extern crate dirs;
extern crate failure;
use std::path::PathBuf;
use rust_bert::pipelines::sentiment::{SentimentModel, ss2_processor};
use rust_bert::pipelines::sentiment::{ss2_processor, SentimentModel};
use std::env;
use std::path::PathBuf;
fn main() -> failure::Fallible<()> {
// Set-up classifier
@ -30,11 +29,19 @@ fn main() -> failure::Fallible<()> {
// Run model
let batch_size = 64;
let mut output = vec!();
let mut output = vec![];
for batch in inputs.chunks(batch_size) {
output.push(sentiment_classifier.predict(batch.iter().map(|v| v.as_str()).collect::<Vec<&str>>().as_slice()));
output.push(
sentiment_classifier.predict(
batch
.iter()
.map(|v| v.as_str())
.collect::<Vec<&str>>()
.as_slice(),
),
);
}
let mut flat_outputs = vec!();
let mut flat_outputs = vec![];
for batch_output in output.iter_mut() {
flat_outputs.append(batch_output);
}

View File

@ -14,7 +14,6 @@ extern crate failure;
use rust_bert::pipelines::summarization::SummarizationModel;
fn main() -> failure::Fallible<()> {
let summarization_model = SummarizationModel::new(Default::default())?;
@ -44,7 +43,7 @@ about exoplanets like K2-18b."];
let _output = summarization_model.summarize(&input);
for sentence in _output {
println!("{:?}", sentence);
};
}
Ok(())
}

View File

@ -10,18 +10,26 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use rust_bert::pipelines::token_classification::{TokenClassificationModel, TokenClassificationConfig, LabelAggregationOption};
use rust_bert::resources::{Resource, RemoteResource};
use rust_bert::bert::{BertModelResources, BertVocabResources, BertConfigResources};
use rust_bert::bert::{BertConfigResources, BertModelResources, BertVocabResources};
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::token_classification::{
LabelAggregationOption, TokenClassificationConfig, TokenClassificationModel,
};
use rust_bert::resources::{RemoteResource, Resource};
fn main() -> failure::Fallible<()> {
// Load a configuration
let config = TokenClassificationConfig::new(ModelType::Bert,
Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT_NER)),
Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT_NER)),
Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT_NER)),
let config = TokenClassificationConfig::new(
ModelType::Bert,
Resource::Remote(RemoteResource::from_pretrained(
BertModelResources::BERT_NER,
)),
Resource::Remote(RemoteResource::from_pretrained(
BertConfigResources::BERT_NER,
)),
Resource::Remote(RemoteResource::from_pretrained(
BertVocabResources::BERT_NER,
)),
None, //merges resource only relevant with ModelType::Roberta
false, //lowercase
LabelAggregationOption::Mode,
@ -31,7 +39,7 @@ fn main() -> failure::Fallible<()> {
let token_classification_model = TokenClassificationModel::new(config)?;
let input = [
"My name is Amélie. I live in Москва.",
"Chongqing is a city in China."
"Chongqing is a city in China.",
];
let token_outputs = token_classification_model.predict(&input, true, false); //ignore_first_label = true (only returns the NER parts, ignoring first label O)

View File

@ -13,12 +13,12 @@
extern crate failure;
use rust_bert::pipelines::translation::{TranslationConfig, TranslationModel, Language};
use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
use tch::Device;
fn main() -> failure::Fallible<()> {
let translation_config = TranslationConfig::new(Language::EnglishToGerman, Device::cuda_if_available());
let translation_config =
TranslationConfig::new(Language::EnglishToGerman, Device::cuda_if_available());
let model = TranslationModel::new(translation_config)?;
let input_context_1 = "The quick brown fox jumps over the lazy dog";

rustfmt.toml (new file, +1 line)
View File

@ -0,0 +1 @@
format_code_in_doc_comments = true
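This new one-line configuration file enables a rustfmt option (nightly-only at the time) that formats Rust code inside doc comments as well, which is why the `no_run` example in the ALBERT hunk below has its `use` lines reordered. A small illustration on a hypothetical function:

```rust
/// Doubles a value. (Hypothetical function, for illustration only.)
///
/// Without the option, rustfmt leaves doc-comment code untouched; with
/// format_code_in_doc_comments = true it reformats examples like ordinary
/// code, e.g. rewriting `let x=double( 2 );` as `let x = double(2);`.
pub fn double(n: i32) -> i32 {
    n * 2
}
```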

View File

@ -11,16 +11,15 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::collections::HashMap;
use crate::Config;
use serde::{Deserialize, Serialize};
use crate::albert::embeddings::AlbertEmbeddings;
use crate::albert::encoder::AlbertTransformer;
use tch::{nn, Tensor, Kind};
use crate::common::activations::{_tanh, _gelu_new, _gelu, _relu, _mish};
use tch::nn::Module;
use crate::common::activations::{_gelu, _gelu_new, _mish, _relu, _tanh};
use crate::common::dropout::Dropout;
use crate::Config;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use tch::nn::Module;
use tch::{nn, Kind, Tensor};
/// # ALBERT Pretrained model weight files
pub struct AlbertModelResources;
@ -33,20 +32,28 @@ pub struct AlbertVocabResources;
impl AlbertModelResources {
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/ALBERT. Modified with conversion to C-array format.
pub const ALBERT_BASE_V2: (&'static str, &'static str) = ("albert-base-v2/model.ot", "https://cdn.huggingface.co/albert-base-v2/rust_model.ot");
pub const ALBERT_BASE_V2: (&'static str, &'static str) = (
"albert-base-v2/model.ot",
"https://cdn.huggingface.co/albert-base-v2/rust_model.ot",
);
}
impl AlbertConfigResources {
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/ALBERT. Modified with conversion to C-array format.
pub const ALBERT_BASE_V2: (&'static str, &'static str) = ("albert-base-v2/config.json", "https://cdn.huggingface.co/albert-base-v2-config.json");
pub const ALBERT_BASE_V2: (&'static str, &'static str) = (
"albert-base-v2/config.json",
"https://cdn.huggingface.co/albert-base-v2-config.json",
);
}
impl AlbertVocabResources {
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/ALBERT. Modified with conversion to C-array format.
pub const ALBERT_BASE_V2: (&'static str, &'static str) = ("albert-base-v2/spiece.model", "https://cdn.huggingface.co/albert-base-v2-spiece.model");
pub const ALBERT_BASE_V2: (&'static str, &'static str) = (
"albert-base-v2/spiece.model",
"https://cdn.huggingface.co/albert-base-v2-spiece.model",
);
}
#[allow(non_camel_case_types)]
#[derive(Clone, Debug, Serialize, Deserialize)]
/// # Activation function used in the attention layer and masked language model head
@ -61,7 +68,6 @@ pub enum Activation {
mish,
}
#[derive(Debug, Serialize, Deserialize)]
/// # ALBERT model configuration
/// Defines the ALBERT model architecture (e.g. number of layers, hidden layer size, label mapping...)
@ -123,10 +129,10 @@ impl AlbertModel {
/// # Example
///
/// ```no_run
/// use tch::{nn, Device};
/// use rust_bert::albert::{AlbertConfig, AlbertModel};
/// use rust_bert::Config;
/// use std::path::Path;
/// use rust_bert::albert::{AlbertConfig, AlbertModel};
/// use tch::{nn, Device};
///
/// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu;
@ -134,14 +140,23 @@ impl AlbertModel {
/// let config = AlbertConfig::from_file(config_path);
/// let albert: AlbertModel = AlbertModel::new(&(&p.root() / "albert"), &config);
/// ```
///
pub fn new(p: &nn::Path, config: &AlbertConfig) -> AlbertModel {
let embeddings = AlbertEmbeddings::new(&(p / "embeddings"), config);
let encoder = AlbertTransformer::new(&(p / "encoder"), config);
let pooler = nn::linear(&(p / "pooler"), config.hidden_size, config.hidden_size, Default::default());
let pooler = nn::linear(
&(p / "pooler"),
config.hidden_size,
config.hidden_size,
Default::default(),
);
let pooler_activation = Box::new(_tanh);
AlbertModel { embeddings, encoder, pooler, pooler_activation }
AlbertModel {
embeddings,
encoder,
pooler,
pooler_activation,
}
}
/// Forward pass through the model
@ -179,61 +194,89 @@ impl AlbertModel {
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true);
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
/// .expand(&[batch_size, sequence_length], true);
///
/// let (output, pooled_output, all_hidden_states, all_attentions) = no_grad(|| {
/// albert_model
/// .forward_t(Some(input_tensor),
/// .forward_t(
/// Some(input_tensor),
/// Some(mask),
/// Some(token_type_ids),
/// Some(position_ids),
/// None,
/// false).unwrap()
/// false,
/// )
/// .unwrap()
/// });
///
/// ```
///
pub fn forward_t(&self,
pub fn forward_t(
&self,
input_ids: Option<Tensor>,
mask: Option<Tensor>,
token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool)
-> Result<(Tensor, Tensor, Option<Vec<Tensor>>, Option<Vec<Vec<Tensor>>>), &'static str> {
train: bool,
) -> Result<
(
Tensor,
Tensor,
Option<Vec<Tensor>>,
Option<Vec<Vec<Tensor>>>,
),
&'static str,
> {
let (input_shape, device) = match &input_ids {
Some(input_value) => match &input_embeds {
Some(_) => { return Err("Only one of input ids or input embeddings may be set"); }
None => (input_value.size(), input_value.device())
Some(_) => {
return Err("Only one of input ids or input embeddings may be set");
}
None => (input_value.size(), input_value.device()),
},
None => match &input_embeds {
Some(embeds) => (vec!(embeds.size()[0], embeds.size()[1]), embeds.device()),
None => { return Err("At least one of input ids or input embeddings must be set"); }
Some(embeds) => (vec![embeds.size()[0], embeds.size()[1]], embeds.device()),
None => {
return Err("At least one of input ids or input embeddings must be set");
}
},
};
let mask = match mask {
Some(value) => value,
None => Tensor::ones(&input_shape, (Kind::Int64, device))
None => Tensor::ones(&input_shape, (Kind::Int64, device)),
};
let extended_attention_mask = mask.unsqueeze(1).unsqueeze(2);
let extended_attention_mask: Tensor = (extended_attention_mask.ones_like() - extended_attention_mask) * -10000.0;
let extended_attention_mask: Tensor =
(extended_attention_mask.ones_like() - extended_attention_mask) * -10000.0;
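// The 0/1 mask, broadcast to [batch, 1, 1, seq_len] by the unsqueeze calls,
// becomes an additive bias: attended positions contribute 0.0 and masked
// positions -10000.0, driving their post-softmax attention weights to ~0.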
let embedding_output = match self.embeddings.forward_t(input_ids, token_type_ids, position_ids, input_embeds, train) {
let embedding_output = match self.embeddings.forward_t(
input_ids,
token_type_ids,
position_ids,
input_embeds,
train,
) {
Ok(value) => value,
Err(e) => { return Err(e); }
Err(e) => {
return Err(e);
}
};
let (hidden_state, all_hidden_states, all_attentions) =
self.encoder.forward_t(&embedding_output,
Some(extended_attention_mask),
train);
self.encoder
.forward_t(&embedding_output, Some(extended_attention_mask), train);
let pooled_output = self.pooler.forward(&hidden_state.select(1, 0));
let pooled_output = (self.pooler_activation)(&pooled_output);
Ok((hidden_state, pooled_output, all_hidden_states, all_attentions))
Ok((
hidden_state,
pooled_output,
all_hidden_states,
all_attentions,
))
}
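// Shape summary: `hidden_state` is [batch, seq_len, hidden_size];
// `pooled_output` applies the linear pooler and tanh to the first token's
// hidden state (`select(1, 0)`), yielding [batch, hidden_size].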
}
@ -248,21 +291,43 @@ impl AlbertMLMHead {
pub fn new(p: &nn::Path, config: &AlbertConfig) -> AlbertMLMHead {
let layer_norm_eps = match config.layer_norm_eps {
Some(value) => value,
None => 1e-12
None => 1e-12,
};
let layer_norm_config = nn::LayerNormConfig { eps: layer_norm_eps, ..Default::default() };
let layer_norm = nn::layer_norm(&(p / "LayerNorm"), vec![config.embedding_size], layer_norm_config);
let dense = nn::linear(&(p / "dense"), config.hidden_size, config.embedding_size, Default::default());
let decoder = nn::linear(&(p / "decoder"), config.embedding_size, config.vocab_size, Default::default());
let layer_norm_config = nn::LayerNormConfig {
eps: layer_norm_eps,
..Default::default()
};
let layer_norm = nn::layer_norm(
&(p / "LayerNorm"),
vec![config.embedding_size],
layer_norm_config,
);
let dense = nn::linear(
&(p / "dense"),
config.hidden_size,
config.embedding_size,
Default::default(),
);
let decoder = nn::linear(
&(p / "decoder"),
config.embedding_size,
config.vocab_size,
Default::default(),
);
let activation = Box::new(match &config.hidden_act {
Activation::gelu_new => _gelu_new,
Activation::gelu => _gelu,
Activation::relu => _relu,
Activation::mish => _mish
Activation::mish => _mish,
});
AlbertMLMHead { layer_norm, dense, decoder, activation }
AlbertMLMHead {
layer_norm,
dense,
decoder,
activation,
}
}
pub fn forward(&self, hidden_states: &Tensor) -> Tensor {
@ -292,10 +357,10 @@ impl AlbertForMaskedLM {
/// # Example
///
/// ```no_run
/// use tch::{nn, Device};
/// use rust_bert::albert::{AlbertConfig, AlbertForMaskedLM};
/// use rust_bert::Config;
/// use std::path::Path;
/// use rust_bert::albert::{AlbertConfig, AlbertForMaskedLM};
/// use tch::{nn, Device};
///
/// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu;
@ -303,12 +368,14 @@ impl AlbertForMaskedLM {
/// let config = AlbertConfig::from_file(config_path);
/// let albert: AlbertForMaskedLM = AlbertForMaskedLM::new(&p.root(), &config);
/// ```
///
pub fn new(p: &nn::Path, config: &AlbertConfig) -> AlbertForMaskedLM {
let albert = AlbertModel::new(&(p / "albert"), config);
let predictions = AlbertMLMHead::new(&(p / "predictions"), config);
AlbertForMaskedLM { albert, predictions }
AlbertForMaskedLM {
albert,
predictions,
}
}
/// Forward pass through the model
@ -345,28 +412,40 @@ impl AlbertForMaskedLM {
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true);
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
/// .expand(&[batch_size, sequence_length], true);
///
/// let (output, all_hidden_states, all_attentions) = no_grad(|| {
/// albert_model
/// .forward_t(Some(input_tensor),
/// albert_model.forward_t(
/// Some(input_tensor),
/// Some(mask),
/// Some(token_type_ids),
/// Some(position_ids),
/// None,
/// false)
/// false,
/// )
/// });
///
/// ```
///
pub fn forward_t(&self,
pub fn forward_t(
&self,
input_ids: Option<Tensor>,
mask: Option<Tensor>,
token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Vec<Tensor>>>) {
let (hidden_state, _, all_hidden_states, all_attentions) = self.albert.forward_t(input_ids, mask, token_type_ids, position_ids, input_embeds, train).unwrap();
train: bool,
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Vec<Tensor>>>) {
let (hidden_state, _, all_hidden_states, all_attentions) = self
.albert
.forward_t(
input_ids,
mask,
token_type_ids,
position_ids,
input_embeds,
train,
)
.unwrap();
let prediction_scores = self.predictions.forward(&hidden_state);
(prediction_scores, all_hidden_states, all_attentions)
}
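// The returned prediction scores have shape [batch, seq_len, vocab_size]. A
// minimal decoding sketch for one masked position (`scores` and
// `masked_index` are hypothetical):
// let token_id = scores.get(0).get(masked_index).argmax(0, false);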
@ -395,29 +474,42 @@ impl AlbertForSequenceClassification {
/// # Example
///
/// ```no_run
/// use tch::{nn, Device};
/// use rust_bert::albert::{AlbertConfig, AlbertForSequenceClassification};
/// use rust_bert::Config;
/// use std::path::Path;
/// use rust_bert::albert::{AlbertConfig, AlbertForSequenceClassification};
/// use tch::{nn, Device};
///
/// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu;
/// let p = nn::VarStore::new(device);
/// let config = AlbertConfig::from_file(config_path);
/// let albert: AlbertForSequenceClassification = AlbertForSequenceClassification::new(&p.root(), &config);
/// let albert: AlbertForSequenceClassification =
/// AlbertForSequenceClassification::new(&p.root(), &config);
/// ```
///
pub fn new(p: &nn::Path, config: &AlbertConfig) -> AlbertForSequenceClassification {
let albert = AlbertModel::new(&(p / "albert"), config);
let classifier_dropout_prob = match config.classifier_dropout_prob {
Some(value) => value,
None => 0.1
None => 0.1,
};
let dropout = Dropout::new(classifier_dropout_prob);
let num_labels = config.id2label.as_ref().expect("num_labels not provided in configuration").len() as i64;
let classifier = nn::linear(&(p / "classifier"), config.hidden_size, num_labels, Default::default());
let num_labels = config
.id2label
.as_ref()
.expect("num_labels not provided in configuration")
.len() as i64;
let classifier = nn::linear(
&(p / "classifier"),
config.hidden_size,
num_labels,
Default::default(),
);
AlbertForSequenceClassification { albert, dropout, classifier }
AlbertForSequenceClassification {
albert,
dropout,
classifier,
}
}
/// Forward pass through the model
@ -465,18 +557,30 @@ impl AlbertForSequenceClassification {
/// None,
/// false)
/// });
///
/// ```
///
pub fn forward_t(&self,
pub fn forward_t(
&self,
input_ids: Option<Tensor>,
mask: Option<Tensor>,
token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Vec<Tensor>>>) {
let (_, pooled_output, all_hidden_states, all_attentions) = self.albert.forward_t(input_ids, mask, token_type_ids, position_ids, input_embeds, train).unwrap();
let logits = pooled_output.apply_t(&self.dropout, train).apply(&self.classifier);
train: bool,
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Vec<Tensor>>>) {
let (_, pooled_output, all_hidden_states, all_attentions) = self
.albert
.forward_t(
input_ids,
mask,
token_type_ids,
position_ids,
input_embeds,
train,
)
.unwrap();
let logits = pooled_output
.apply_t(&self.dropout, train)
.apply(&self.classifier);
(logits, all_hidden_states, all_attentions)
}
}
@ -505,25 +609,38 @@ impl AlbertForTokenClassification {
/// # Example
///
/// ```no_run
/// use tch::{nn, Device};
/// use rust_bert::albert::{AlbertConfig, AlbertForTokenClassification};
/// use rust_bert::Config;
/// use std::path::Path;
/// use rust_bert::albert::{AlbertConfig, AlbertForTokenClassification};
/// use tch::{nn, Device};
///
/// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu;
/// let p = nn::VarStore::new(device);
/// let config = AlbertConfig::from_file(config_path);
/// let albert: AlbertForTokenClassification = AlbertForTokenClassification::new(&p.root(), &config);
/// let albert: AlbertForTokenClassification =
/// AlbertForTokenClassification::new(&p.root(), &config);
/// ```
///
pub fn new(p: &nn::Path, config: &AlbertConfig) -> AlbertForTokenClassification {
let albert = AlbertModel::new(&(p / "albert"), config);
let dropout = Dropout::new(config.hidden_dropout_prob);
let num_labels = config.id2label.as_ref().expect("num_labels not provided in configuration").len() as i64;
let classifier = nn::linear(&(p / "classifier"), config.hidden_size, num_labels, Default::default());
let num_labels = config
.id2label
.as_ref()
.expect("num_labels not provided in configuration")
.len() as i64;
let classifier = nn::linear(
&(p / "classifier"),
config.hidden_size,
num_labels,
Default::default(),
);
AlbertForTokenClassification { albert, dropout, classifier }
AlbertForTokenClassification {
albert,
dropout,
classifier,
}
}
/// Forward pass through the model
@ -571,18 +688,30 @@ impl AlbertForTokenClassification {
/// None,
/// false)
/// });
///
/// ```
///
pub fn forward_t(&self,
pub fn forward_t(
&self,
input_ids: Option<Tensor>,
mask: Option<Tensor>,
token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Vec<Tensor>>>) {
let (sequence_output, _, all_hidden_states, all_attentions) = self.albert.forward_t(input_ids, mask, token_type_ids, position_ids, input_embeds, train).unwrap();
let logits = sequence_output.apply_t(&self.dropout, train).apply(&self.classifier);
train: bool,
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Vec<Tensor>>>) {
let (sequence_output, _, all_hidden_states, all_attentions) = self
.albert
.forward_t(
input_ids,
mask,
token_type_ids,
position_ids,
input_embeds,
train,
)
.unwrap();
let logits = sequence_output
.apply_t(&self.dropout, train)
.apply(&self.classifier);
(logits, all_hidden_states, all_attentions)
}
}
@ -610,10 +739,10 @@ impl AlbertForQuestionAnswering {
/// # Example
///
/// ```no_run
/// use tch::{nn, Device};
/// use rust_bert::albert::{AlbertConfig, AlbertForQuestionAnswering};
/// use rust_bert::Config;
/// use std::path::Path;
/// use rust_bert::albert::{AlbertConfig, AlbertForQuestionAnswering};
/// use tch::{nn, Device};
///
/// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu;
@ -621,11 +750,15 @@ impl AlbertForQuestionAnswering {
/// let config = AlbertConfig::from_file(config_path);
/// let albert: AlbertForQuestionAnswering = AlbertForQuestionAnswering::new(&p.root(), &config);
/// ```
///
pub fn new(p: &nn::Path, config: &AlbertConfig) -> AlbertForQuestionAnswering {
let albert = AlbertModel::new(&(p / "albert"), config);
let num_labels = 2;
let qa_outputs = nn::linear(&(p / "qa_outputs"), config.hidden_size, num_labels, Default::default());
let qa_outputs = nn::linear(
&(p / "qa_outputs"),
config.hidden_size,
num_labels,
Default::default(),
);
AlbertForQuestionAnswering { albert, qa_outputs }
}
@ -676,17 +809,32 @@ impl AlbertForQuestionAnswering {
/// None,
/// false)
/// });
///
/// ```
///
pub fn forward_t(&self,
pub fn forward_t(
&self,
input_ids: Option<Tensor>,
mask: Option<Tensor>,
token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool) -> (Tensor, Tensor, Option<Vec<Tensor>>, Option<Vec<Vec<Tensor>>>) {
let (sequence_output, _, all_hidden_states, all_attentions) = self.albert.forward_t(input_ids, mask, token_type_ids, position_ids, input_embeds, train).unwrap();
train: bool,
) -> (
Tensor,
Tensor,
Option<Vec<Tensor>>,
Option<Vec<Vec<Tensor>>>,
) {
let (sequence_output, _, all_hidden_states, all_attentions) = self
.albert
.forward_t(
input_ids,
mask,
token_type_ids,
position_ids,
input_embeds,
train,
)
.unwrap();
let logits = sequence_output.apply(&self.qa_outputs).split(1, -1);
let (start_logits, end_logits) = (&logits[0], &logits[1]);
let start_logits = start_logits.squeeze1(-1);
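// `qa_outputs` emits two scores per token; `split(1, -1)` separates them
// into start and end logits of shape [batch, seq_len, 1], and `squeeze1(-1)`
// drops the trailing dimension so each becomes [batch, seq_len] for span
// selection.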
@ -721,10 +869,10 @@ impl AlbertForMultipleChoice {
/// # Example
///
/// ```no_run
/// use tch::{nn, Device};
/// use rust_bert::albert::{AlbertConfig, AlbertForMultipleChoice};
/// use rust_bert::Config;
/// use std::path::Path;
/// use rust_bert::albert::{AlbertConfig, AlbertForMultipleChoice};
/// use tch::{nn, Device};
///
/// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu;
@ -732,14 +880,22 @@ impl AlbertForMultipleChoice {
/// let config = AlbertConfig::from_file(config_path);
/// let albert: AlbertForMultipleChoice = AlbertForMultipleChoice::new(&p.root(), &config);
/// ```
///
pub fn new(p: &nn::Path, config: &AlbertConfig) -> AlbertForMultipleChoice {
let albert = AlbertModel::new(&(p / "albert"), config);
let dropout = Dropout::new(config.hidden_dropout_prob);
let num_labels = 1;
let classifier = nn::linear(&(p / "classifier"), config.hidden_size, num_labels, Default::default());
let classifier = nn::linear(
&(p / "classifier"),
config.hidden_size,
num_labels,
Default::default(),
);
AlbertForMultipleChoice { albert, dropout, classifier }
AlbertForMultipleChoice {
albert,
dropout,
classifier,
}
}
/// Forward pass through the model
@ -787,43 +943,67 @@ impl AlbertForMultipleChoice {
/// None,
/// false).unwrap()
/// });
///
/// ```
///
pub fn forward_t(&self,
pub fn forward_t(
&self,
input_ids: Option<Tensor>,
mask: Option<Tensor>,
token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool) -> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Vec<Tensor>>>), &'static str> {
train: bool,
) -> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Vec<Tensor>>>), &'static str> {
let (input_ids, input_embeds, num_choices) = match &input_ids {
Some(input_value) => match &input_embeds {
Some(_) => { return Err("Only one of input ids or input embeddings may be set"); }
None => (Some(input_value.view((-1, *input_value.size().last().unwrap()))), None, input_value.size()[1])
Some(_) => {
return Err("Only one of input ids or input embeddings may be set");
}
None => (
Some(input_value.view((-1, *input_value.size().last().unwrap()))),
None,
input_value.size()[1],
),
},
None => match &input_embeds {
Some(embeds) => (None, Some(embeds.view((-1, embeds.size()[1], embeds.size()[2]))), embeds.size()[1]),
None => { return Err("At least one of input ids or input embeddings must be set"); }
Some(embeds) => (
None,
Some(embeds.view((-1, embeds.size()[1], embeds.size()[2]))),
embeds.size()[1],
),
None => {
return Err("At least one of input ids or input embeddings must be set");
}
},
};
let mask = match mask {
Some(value) => Some(value.view((-1, *value.size().last().unwrap()))),
None => None
None => None,
};
let token_type_ids = match token_type_ids {
Some(value) => Some(value.view((-1, *value.size().last().unwrap()))),
None => None
None => None,
};
let position_ids = match position_ids {
Some(value) => Some(value.view((-1, *value.size().last().unwrap()))),
None => None
None => None,
};
let (_, pooled_output, all_hidden_states, all_attentions) = self.albert.forward_t(input_ids, mask, token_type_ids, position_ids, input_embeds, train).unwrap();
let logits = pooled_output.apply_t(&self.dropout, train).apply(&self.classifier).view((-1, num_choices));
let (_, pooled_output, all_hidden_states, all_attentions) = self
.albert
.forward_t(
input_ids,
mask,
token_type_ids,
position_ids,
input_embeds,
train,
)
.unwrap();
let logits = pooled_output
.apply_t(&self.dropout, train)
.apply(&self.classifier)
.view((-1, num_choices));
Ok((logits, all_hidden_states, all_attentions))
}
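// Multiple-choice inputs arrive as [batch, num_choices, seq_len] and are
// flattened to [batch * num_choices, seq_len] before the shared ALBERT pass;
// the single-logit classifier output is then viewed back to
// [batch, num_choices], so a softmax over the last dimension ranks the
// choices.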

View File

@ -11,10 +11,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use crate::common::dropout::Dropout;
use tch::{nn, Tensor};
use crate::albert::AlbertConfig;
use crate::common::dropout::Dropout;
use tch::kind::Kind::Float;
use tch::{nn, Tensor};
#[derive(Debug)]
pub struct AlbertSelfAttention {
@ -32,24 +32,55 @@ pub struct AlbertSelfAttention {
impl AlbertSelfAttention {
pub fn new(p: nn::Path, config: &AlbertConfig) -> AlbertSelfAttention {
assert_eq!(config.hidden_size % config.num_attention_heads, 0, "Hidden size not a multiple of the number of attention heads");
assert_eq!(
config.hidden_size % config.num_attention_heads,
0,
"Hidden size not a multiple of the number of attention heads"
);
let query = nn::linear(&p / "query", config.hidden_size, config.hidden_size, Default::default());
let key = nn::linear(&p / "key", config.hidden_size, config.hidden_size, Default::default());
let value = nn::linear(&p / "value", config.hidden_size, config.hidden_size, Default::default());
let dense = nn::linear(&p / "dense", config.hidden_size, config.hidden_size, Default::default());
let query = nn::linear(
&p / "query",
config.hidden_size,
config.hidden_size,
Default::default(),
);
let key = nn::linear(
&p / "key",
config.hidden_size,
config.hidden_size,
Default::default(),
);
let value = nn::linear(
&p / "value",
config.hidden_size,
config.hidden_size,
Default::default(),
);
let dense = nn::linear(
&p / "dense",
config.hidden_size,
config.hidden_size,
Default::default(),
);
let dropout = Dropout::new(config.attention_probs_dropout_prob);
let attention_head_size = config.hidden_size / config.num_attention_heads;
let output_attentions = match config.output_attentions {
Some(value) => value,
None => false
None => false,
};
let layer_norm_eps = match config.layer_norm_eps {
Some(value) => value,
None => 1e-12
None => 1e-12,
};
let layer_norm_config = nn::LayerNormConfig { eps: layer_norm_eps, ..Default::default() };
let layer_norm = nn::layer_norm(&p / "LayerNorm", vec![config.hidden_size], layer_norm_config);
let layer_norm_config = nn::LayerNormConfig {
eps: layer_norm_eps,
..Default::default()
};
let layer_norm = nn::layer_norm(
&p / "LayerNorm",
vec![config.hidden_size],
layer_norm_config,
);
AlbertSelfAttention {
num_attention_heads: config.num_attention_heads,
@ -66,19 +97,23 @@ impl AlbertSelfAttention {
}
fn split_heads(&self, x: Tensor, bs: i64, dim_per_head: i64) -> Tensor {
x.view((bs, -1, self.num_attention_heads, dim_per_head)).transpose(1, 2)
x.view((bs, -1, self.num_attention_heads, dim_per_head))
.transpose(1, 2)
}
pub fn forward_t(&self,
pub fn forward_t(
&self,
input_ids: &Tensor,
mask: &Option<Tensor>,
train: bool) -> (Tensor, Option<Tensor>) {
train: bool,
) -> (Tensor, Option<Tensor>) {
let bs = *input_ids.size().first().unwrap();
let key_layer = self.split_heads(input_ids.apply(&self.key), bs, self.attention_head_size);
let value_layer = self.split_heads(input_ids.apply(&self.value), bs, self.attention_head_size);
let query_layer = self.split_heads(input_ids.apply(&self.query), bs, self.attention_head_size);
let value_layer =
self.split_heads(input_ids.apply(&self.value), bs, self.attention_head_size);
let query_layer =
self.split_heads(input_ids.apply(&self.query), bs, self.attention_head_size);
let query_layer: Tensor = query_layer / (self.attention_head_size as f64).sqrt();
@ -91,9 +126,11 @@ impl AlbertSelfAttention {
let weights = scores.softmax(-1, Float).apply_t(&self.dropout, train);
let context = weights.matmul(&value_layer).transpose(1, 2).contiguous();
let w = self.dense.ws
.transpose(0, 1)
.view((self.num_attention_heads, self.attention_head_size, self.hidden_size));
let w = self.dense.ws.transpose(0, 1).view((
self.num_attention_heads,
self.attention_head_size,
self.hidden_size,
));
let context: Tensor = Tensor::einsum("bfnd,ndh->bfh", &[context, w]) + &self.dense.bs;
let context = (input_ids + context.apply_t(&self.dropout, train)).apply(&self.layer_norm);
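// The einsum fuses "merge heads, then project": `context` is
// [batch, seq, heads, head_dim] and `w` reshapes the dense weights to
// [heads, head_dim, hidden], so "bfnd,ndh->bfh" plus the bias is equivalent
// to concatenating the heads and applying `self.dense`. An unfused sketch:
// let context = context.view((bs, -1, self.hidden_size)).apply(&self.dense);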

View File

@ -11,10 +11,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use tch::{nn, Tensor, Kind};
use crate::common::dropout::Dropout;
use crate::albert::AlbertConfig;
use tch::nn::{EmbeddingConfig, embedding};
use crate::common::dropout::Dropout;
use tch::nn::{embedding, EmbeddingConfig};
use tch::{nn, Kind, Tensor};
/// # Embeddings implementation for Albert model
#[derive(Debug)]
@ -34,49 +34,77 @@ impl AlbertEmbeddings {
..Default::default()
};
let word_embeddings: nn::Embedding = embedding(p / "word_embeddings",
let word_embeddings: nn::Embedding = embedding(
p / "word_embeddings",
config.vocab_size,
config.embedding_size,
embedding_config);
embedding_config,
);
let position_embeddings: nn::Embedding = embedding(p / "position_embeddings",
let position_embeddings: nn::Embedding = embedding(
p / "position_embeddings",
config.max_position_embeddings,
config.embedding_size,
Default::default());
Default::default(),
);
let token_type_embeddings: nn::Embedding = embedding(p / "token_type_embeddings",
let token_type_embeddings: nn::Embedding = embedding(
p / "token_type_embeddings",
config.type_vocab_size,
config.embedding_size,
Default::default());
Default::default(),
);
let layer_norm_eps = match config.layer_norm_eps {
Some(value) => value,
None => 1e-12
None => 1e-12,
};
let layer_norm_config = nn::LayerNormConfig { eps: layer_norm_eps, ..Default::default() };
let layer_norm: nn::LayerNorm = nn::layer_norm(p / "LayerNorm", vec![config.embedding_size], layer_norm_config);
let layer_norm_config = nn::LayerNormConfig {
eps: layer_norm_eps,
..Default::default()
};
let layer_norm: nn::LayerNorm = nn::layer_norm(
p / "LayerNorm",
vec![config.embedding_size],
layer_norm_config,
);
let dropout: Dropout = Dropout::new(config.hidden_dropout_prob);
AlbertEmbeddings { word_embeddings, position_embeddings, token_type_embeddings, layer_norm, dropout}
AlbertEmbeddings {
word_embeddings,
position_embeddings,
token_type_embeddings,
layer_norm,
dropout,
}
}
pub fn forward_t(&self,
pub fn forward_t(
&self,
input_ids: Option<Tensor>,
token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool) -> Result<Tensor, &'static str> {
train: bool,
) -> Result<Tensor, &'static str> {
let (input_embeddings, input_shape) = match input_ids {
Some(input_value) => match input_embeds {
Some(_) => { return Err("Only one of input ids or input embeddings may be set"); }
None => (input_value.apply_t(&self.word_embeddings, train), input_value.size())
Some(_) => {
return Err("Only one of input ids or input embeddings may be set");
}
None => (
input_value.apply_t(&self.word_embeddings, train),
input_value.size(),
),
},
None => match input_embeds {
Some(embeds) => {
let size = vec!(embeds.size()[0], embeds.size()[1]);
let size = vec![embeds.size()[0], embeds.size()[1]];
(embeds, size)
},
None => { return Err("Only one of input ids or input embeddings may be set"); }
}
None => {
return Err("Only one of input ids or input embeddings may be set");
}
},
};
let seq_length = input_embeddings.as_ref().size()[1].to_owned();
@ -84,19 +112,22 @@ impl AlbertEmbeddings {
let position_ids = match position_ids {
Some(value) => value,
None => Tensor::arange(seq_length, (Kind::Int64, input_embeddings.device()))
.unsqueeze(0).
expand(&input_shape, true)
.unsqueeze(0)
.expand(&input_shape, true),
};
let token_type_ids = match token_type_ids {
Some(value) => value,
None => Tensor::zeros(&input_shape, (Kind::Int64, input_embeddings.device()))
None => Tensor::zeros(&input_shape, (Kind::Int64, input_embeddings.device())),
};
let position_embeddings = position_ids.apply(&self.position_embeddings);
let token_type_embeddings = token_type_ids.apply(&self.token_type_embeddings);
let input_embeddings: Tensor = input_embeddings + position_embeddings + token_type_embeddings;
Ok(input_embeddings.apply(&self.layer_norm).apply_t(&self.dropout, train))
let input_embeddings: Tensor =
input_embeddings + position_embeddings + token_type_embeddings;
Ok(input_embeddings
.apply(&self.layer_norm)
.apply_t(&self.dropout, train))
}
}
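// The three lookups (word, position, token type) all produce
// [batch, seq_len, embedding_size] tensors that are summed before LayerNorm
// and dropout; ALBERT only projects `embedding_size` up to `hidden_size`
// later, via the transformer's `embedding_hidden_mapping_in`.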

View File

@ -11,12 +11,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use crate::albert::attention::AlbertSelfAttention;
use tch::{nn, Tensor};
use crate::albert::AlbertConfig;
use crate::albert::albert::Activation;
use crate::common::activations::{_gelu_new, _gelu, _relu, _mish};
use crate::albert::attention::AlbertSelfAttention;
use crate::albert::AlbertConfig;
use crate::common::activations::{_gelu, _gelu_new, _mish, _relu};
use std::borrow::BorrowMut;
use tch::{nn, Tensor};
pub struct AlbertLayer {
attention: AlbertSelfAttention,
@ -32,29 +32,55 @@ impl AlbertLayer {
let layer_norm_eps = match config.layer_norm_eps {
Some(value) => value,
None => 1e-12
None => 1e-12,
};
let layer_norm_config = nn::LayerNormConfig { eps: layer_norm_eps, ..Default::default() };
let full_layer_layer_norm = nn::layer_norm(&(p / "full_layer_layer_norm"), vec![config.hidden_size], layer_norm_config);
let layer_norm_config = nn::LayerNormConfig {
eps: layer_norm_eps,
..Default::default()
};
let full_layer_layer_norm = nn::layer_norm(
&(p / "full_layer_layer_norm"),
vec![config.hidden_size],
layer_norm_config,
);
let ffn = nn::linear(&(p / "ffn"), config.hidden_size, config.intermediate_size, Default::default());
let ffn_output = nn::linear(&(p / "ffn_output"), config.intermediate_size, config.hidden_size, Default::default());
let ffn = nn::linear(
&(p / "ffn"),
config.hidden_size,
config.intermediate_size,
Default::default(),
);
let ffn_output = nn::linear(
&(p / "ffn_output"),
config.intermediate_size,
config.hidden_size,
Default::default(),
);
let activation = Box::new(match &config.hidden_act {
Activation::gelu_new => _gelu_new,
Activation::gelu => _gelu,
Activation::relu => _relu,
Activation::mish => _mish
Activation::mish => _mish,
});
AlbertLayer { attention, full_layer_layer_norm, ffn, ffn_output, activation }
AlbertLayer {
attention,
full_layer_layer_norm,
ffn,
ffn_output,
activation,
}
}
pub fn forward_t(&self,
pub fn forward_t(
&self,
hidden_states: &Tensor,
mask: &Option<Tensor>,
train: bool) -> (Tensor, Option<Tensor>) {
let (attention_output, attention_weights) = self.attention.forward_t(hidden_states, mask, train);
train: bool,
) -> (Tensor, Option<Tensor>) {
let (attention_output, attention_weights) =
self.attention.forward_t(hidden_states, mask, train);
let ffn_output = attention_output.apply(&self.ffn);
let ffn_output: Tensor = (self.activation)(&ffn_output);
let ffn_output = ffn_output.apply(&self.ffn_output);
@ -76,29 +102,42 @@ impl AlbertLayerGroup {
let output_attentions = match config.output_attentions {
Some(value) => value,
None => false
None => false,
};
let output_hidden_states = match config.output_hidden_states {
Some(value) => value,
None => false
None => false,
};
let mut layers: Vec<AlbertLayer> = vec!();
let mut layers: Vec<AlbertLayer> = vec![];
for layer_index in 0..config.inner_group_num {
layers.push(AlbertLayer::new(&(p / layer_index), config));
};
AlbertLayerGroup { output_hidden_states, output_attentions, layers }
}
pub fn forward_t(&self,
AlbertLayerGroup {
output_hidden_states,
output_attentions,
layers,
}
}
pub fn forward_t(
&self,
hidden_states: &Tensor,
mask: &Option<Tensor>,
train: bool)
-> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states { Some(vec!()) } else { None };
let mut all_attentions: Option<Vec<Tensor>> = if self.output_attentions { Some(vec!()) } else { None };
train: bool,
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states {
Some(vec![])
} else {
None
};
let mut all_attentions: Option<Vec<Tensor>> = if self.output_attentions {
Some(vec![])
} else {
None
};
let mut hidden_state = hidden_states.copy();
let mut attention_weights: Option<Tensor>;
@ -117,9 +156,9 @@ impl AlbertLayerGroup {
attentions.push(attention_weights.as_ref().unwrap().copy());
};
}
None => break
};
None => break,
};
}
(hidden_state, all_hidden_states, all_attentions)
}
@ -140,20 +179,25 @@ impl AlbertTransformer {
let output_attentions = match config.output_attentions {
Some(value) => value,
None => false
None => false,
};
let output_hidden_states = match config.output_hidden_states {
Some(value) => value,
None => false
None => false,
};
let embedding_hidden_mapping_in = nn::linear(&(p / "embedding_hidden_mapping_in"), config.embedding_size, config.hidden_size, Default::default());
let embedding_hidden_mapping_in = nn::linear(
&(p / "embedding_hidden_mapping_in"),
config.embedding_size,
config.hidden_size,
Default::default(),
);
let mut layers: Vec<AlbertLayerGroup> = vec!();
let mut layers: Vec<AlbertLayerGroup> = vec![];
for layer_index in 0..config.inner_group_num {
layers.push(AlbertLayerGroup::new(&(p_layers / layer_index), config));
};
}
AlbertTransformer {
output_hidden_states,
@ -165,16 +209,24 @@ impl AlbertTransformer {
}
}
pub fn forward_t(&self,
pub fn forward_t(
&self,
hidden_states: &Tensor,
mask: Option<Tensor>,
train: bool)
-> (Tensor, Option<Vec<Tensor>>, Option<Vec<Vec<Tensor>>>) {
train: bool,
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Vec<Tensor>>>) {
let mut hidden_state = hidden_states.apply(&self.embedding_hidden_mapping_in);
let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states { Some(vec!()) } else { None };
let mut all_attentions: Option<Vec<Vec<Tensor>>> = if self.output_attentions { Some(vec!()) } else { None };
let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states {
Some(vec![])
} else {
None
};
let mut all_attentions: Option<Vec<Vec<Tensor>>> = if self.output_attentions {
Some(vec![])
} else {
None
};
for i in 0..self.num_hidden_layers {
let group_idx = i / (self.num_hidden_layers / self.num_hidden_groups);
@ -190,9 +242,8 @@ impl AlbertTransformer {
if let Some(attentions) = all_attentions.borrow_mut() {
attentions.push(attention_weights.unwrap());
};
};
}
(hidden_state, all_hidden_states, all_attentions)
}
}
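// ALBERT's cross-layer parameter sharing: every hidden layer i is routed
// through the group at `group_idx = i / (num_hidden_layers /
// num_hidden_groups)`, so several of the `num_hidden_layers` forward passes
// reuse the same `AlbertLayerGroup` weights. With the common configuration
// of 12 hidden layers and one hidden group, all 12 passes share one group.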

View File

@ -25,19 +25,26 @@
//! use rust_tokenizers::AlbertTokenizer;
//! use tch::{nn, Device};
//! # use std::path::PathBuf;
//! use rust_bert::albert::{AlbertForMaskedLM, AlbertConfig};
//! use rust_bert::albert::{AlbertConfig, AlbertForMaskedLM};
//! use rust_bert::resources::{download_resource, LocalResource, Resource};
//! use rust_bert::Config;
//! use rust_bert::resources::{Resource, download_resource, LocalResource};
//!
//! let config_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/config.json")});
//! let vocab_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/vocab.txt")});
//! let weights_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/model.ot")});
//! let config_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/config.json"),
//! });
//! let vocab_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/vocab.txt"),
//! });
//! let weights_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/model.ot"),
//! });
//! let config_path = download_resource(&config_resource)?;
//! let vocab_path = download_resource(&vocab_resource)?;
//! let weights_path = download_resource(&weights_resource)?;
//! let device = Device::cuda_if_available();
//! let mut vs = nn::VarStore::new(device);
//! let tokenizer: AlbertTokenizer = AlbertTokenizer::from_file(vocab_path.to_str().unwrap(), true, true);
//! let tokenizer: AlbertTokenizer =
//! AlbertTokenizer::from_file(vocab_path.to_str().unwrap(), true, true);
//! let config = AlbertConfig::from_file(config_path);
//! let bert_model = AlbertForMaskedLM::new(&vs.root(), &config);
//! vs.load(weights_path)?;
@ -46,11 +53,13 @@
//! # }
//! ```
mod encoder;
mod albert;
mod attention;
mod embeddings;
mod albert;
mod encoder;
pub use albert::{AlbertConfig, AlbertModelResources, AlbertConfigResources, AlbertVocabResources, AlbertModel, AlbertForMaskedLM, AlbertForSequenceClassification, AlbertForTokenClassification, AlbertForQuestionAnswering, AlbertForMultipleChoice};
pub use albert::{
AlbertConfig, AlbertConfigResources, AlbertForMaskedLM, AlbertForMultipleChoice,
AlbertForQuestionAnswering, AlbertForSequenceClassification, AlbertForTokenClassification,
AlbertModel, AlbertModelResources, AlbertVocabResources,
};

View File

@ -12,8 +12,8 @@
// limitations under the License.
use crate::common::dropout::Dropout;
use tch::{nn, Tensor};
use tch::kind::Kind::Float;
use tch::{nn, Tensor};
#[derive(Debug)]
/// # Cache for BART attention layers
@ -31,7 +31,7 @@ impl Clone for LayerState {
fn clone(&self) -> Self {
let prev_key_padding_mask = match &self.prev_key_padding_mask {
Some(key_padding_mask) => Some(key_padding_mask.copy()),
None => None
None => None,
};
LayerState {
prev_key: self.prev_key.copy(),
@ -46,12 +46,16 @@ impl LayerState {
self.prev_key = self.prev_key.index_select(0, new_indices);
self.prev_value = self.prev_value.index_select(0, new_indices);
if self.prev_key_padding_mask.is_some() {
self.prev_key_padding_mask = Some(self.prev_key_padding_mask.as_ref().unwrap().index_select(0, new_indices));
self.prev_key_padding_mask = Some(
self.prev_key_padding_mask
.as_ref()
.unwrap()
.index_select(0, new_indices),
);
}
}
}
#[derive(Debug)]
pub struct SelfAttention {
num_heads: i64,
@ -68,8 +72,15 @@ pub struct SelfAttention {
}
impl SelfAttention {
pub fn new(p: nn::Path, embed_dim: i64, num_heads: i64, dropout: f64,
encoder_decoder_attention: bool, store_cache: bool, output_attentions: bool) -> SelfAttention {
pub fn new(
p: nn::Path,
embed_dim: i64,
num_heads: i64,
dropout: f64,
encoder_decoder_attention: bool,
store_cache: bool,
output_attentions: bool,
) -> SelfAttention {
let k_proj = nn::linear(&p / "k_proj", embed_dim, embed_dim, Default::default());
let v_proj = nn::linear(&p / "v_proj", embed_dim, embed_dim, Default::default());
let q_proj = nn::linear(&p / "q_proj", embed_dim, embed_dim, Default::default());
@ -95,58 +106,90 @@ impl SelfAttention {
}
fn flatten(&self, x: Tensor, dim_0: i64, bs: i64) -> Tensor {
x.contiguous().view((dim_0, bs * self.num_heads, self.head_dim)).transpose(0, 1)
x.contiguous()
.view((dim_0, bs * self.num_heads, self.head_dim))
.transpose(0, 1)
}
pub fn forward_t(&self, query: &Tensor,
pub fn forward_t(
&self,
query: &Tensor,
key: Option<&Tensor>,
key_padding_mask: Option<&Tensor>,
attention_mask: Option<&Tensor>,
mut layer_state: Option<LayerState>,
train: bool) -> (Tensor, Option<Tensor>, Option<LayerState>) {
train: bool,
) -> (Tensor, Option<Tensor>, Option<LayerState>) {
let query_size = query.size();
let (target_sequence_length, bs) = (query_size[0], query_size[1]);
let q: Tensor = self.flatten(query.as_ref().apply(&self.q_proj) * self.scaling, target_sequence_length, bs);
let q: Tensor = self.flatten(
query.as_ref().apply(&self.q_proj) * self.scaling,
target_sequence_length,
bs,
);
let key = match &layer_state {
Some(_) => { if self.encoder_decoder_attention { None } else { key } }
None => key
Some(_) => {
if self.encoder_decoder_attention {
None
} else {
key
}
}
None => key,
};
let (k, v) = if self.encoder_decoder_attention {
match key {
Some(key) => {
(Some(self.flatten(key.apply(&self.k_proj), -1, bs)),
Some(self.flatten(key.apply(&self.v_proj), -1, bs))
)
}
None => (None, None)
Some(key) => (
Some(self.flatten(key.apply(&self.k_proj), -1, bs)),
Some(self.flatten(key.apply(&self.v_proj), -1, bs)),
),
None => (None, None),
}
} else {
(Some(self.flatten(query.apply(&self.k_proj), -1, bs)),
Some(self.flatten(query.apply(&self.v_proj), -1, bs))
(
Some(self.flatten(query.apply(&self.k_proj), -1, bs)),
Some(self.flatten(query.apply(&self.v_proj), -1, bs)),
)
};
let (k, v, key_padding_mask) = self.use_saved_state(&layer_state, k, v, key_padding_mask, bs);
let (k, v, key_padding_mask) =
self.use_saved_state(&layer_state, k, v, key_padding_mask, bs);
let source_sequence_length = k.size()[1];
let attention_weights = q.bmm(&k.transpose(1, 2));
let attention_weights = match attention_mask {
Some(mask) => {
let attention_weights = attention_weights.view((bs, self.num_heads, target_sequence_length, source_sequence_length)) + mask;
attention_weights.view((bs * self.num_heads, target_sequence_length, source_sequence_length))
let attention_weights = attention_weights.view((
bs,
self.num_heads,
target_sequence_length,
source_sequence_length,
)) + mask;
attention_weights.view((
bs * self.num_heads,
target_sequence_length,
source_sequence_length,
))
}
None => attention_weights
None => attention_weights,
};
let attention_weights = match key_padding_mask.as_ref() {
Some(mask) => {
attention_weights
.view((bs, self.num_heads, target_sequence_length, source_sequence_length))
Some(mask) => attention_weights
.view((
bs,
self.num_heads,
target_sequence_length,
source_sequence_length,
))
.masked_fill(&mask.unsqueeze(1).unsqueeze(2), std::f64::NEG_INFINITY)
.view((bs * self.num_heads, target_sequence_length, source_sequence_length))
}
None => attention_weights
.view((
bs * self.num_heads,
target_sequence_length,
source_sequence_length,
)),
None => attention_weights,
};
let attention_weights = attention_weights.softmax(-1, Float);
@ -159,16 +202,25 @@ impl SelfAttention {
.apply(&self.out_proj);
let attention_weights = if self.output_attentions {
Some(attention_weights.view((bs, self.num_heads, target_sequence_length, source_sequence_length)))
} else { None };
Some(attention_weights.view((
bs,
self.num_heads,
target_sequence_length,
source_sequence_length,
)))
} else {
None
};
if self.store_cache {
if layer_state.is_some() {
layer_state.as_mut().unwrap().prev_key = k.view((bs, self.num_heads, -1, self.head_dim));
layer_state.as_mut().unwrap().prev_value = v.view((bs, self.num_heads, -1, self.head_dim));
layer_state.as_mut().unwrap().prev_key =
k.view((bs, self.num_heads, -1, self.head_dim));
layer_state.as_mut().unwrap().prev_value =
v.view((bs, self.num_heads, -1, self.head_dim));
layer_state.as_mut().unwrap().prev_key_padding_mask = match key_padding_mask {
Some(tensor) => Some(tensor),
None => None
None => None,
};
} else {
layer_state = Some(LayerState {
@ -176,7 +228,7 @@ impl SelfAttention {
prev_value: v.view((bs, self.num_heads, -1, self.head_dim)),
prev_key_padding_mask: match key_padding_mask {
Some(tensor) => Some(tensor),
None => None
None => None,
},
})
};
@ -185,17 +237,23 @@ impl SelfAttention {
(output, attention_weights, layer_state)
}
fn use_saved_state(&self,
fn use_saved_state(
&self,
layer_state: &Option<LayerState>,
k: Option<Tensor>,
v: Option<Tensor>,
key_padding_mask: Option<&Tensor>,
bs: i64)
-> (Tensor, Tensor, Option<Tensor>) {
bs: i64,
) -> (Tensor, Tensor, Option<Tensor>) {
match &layer_state {
Some(prev_state) => {
let prev_key = prev_state.prev_key.view((bs * self.num_heads, -1, self.head_dim));
let prev_value = prev_state.prev_value.view((bs * self.num_heads, -1, self.head_dim));
let prev_key = prev_state
.prev_key
.view((bs * self.num_heads, -1, self.head_dim));
let prev_value =
prev_state
.prev_value
.view((bs * self.num_heads, -1, self.head_dim));
let k = if self.encoder_decoder_attention {
prev_key
} else {
@ -207,38 +265,53 @@ impl SelfAttention {
Tensor::cat(&[prev_value, v.unwrap()], 1)
};
let key_padding_mask = self.use_saved_key_padding_mask(key_padding_mask,
let key_padding_mask = self.use_saved_key_padding_mask(
key_padding_mask,
&prev_state.prev_key_padding_mask,
bs,
k.size()[1]);
k.size()[1],
);
(k, v, key_padding_mask)
}
None => {
let key_padding_mask = match key_padding_mask {
Some(value) => Some(value.copy()),
None => None
None => None,
};
(k.unwrap(), v.unwrap(), key_padding_mask)
}
}
}
fn use_saved_key_padding_mask(&self, key_padding_mask: Option<&Tensor>, prev_key_padding_mask: &Option<Tensor>,
bs: i64, sequence_length: i64) -> Option<Tensor> {
fn use_saved_key_padding_mask(
&self,
key_padding_mask: Option<&Tensor>,
prev_key_padding_mask: &Option<Tensor>,
bs: i64,
sequence_length: i64,
) -> Option<Tensor> {
if prev_key_padding_mask.is_some() {
if self.encoder_decoder_attention {
Some(prev_key_padding_mask.as_ref().unwrap().copy())
} else {
Some(Tensor::cat(&[prev_key_padding_mask.as_ref().unwrap(), key_padding_mask.as_ref().unwrap()], 1))
Some(Tensor::cat(
&[
prev_key_padding_mask.as_ref().unwrap(),
key_padding_mask.as_ref().unwrap(),
],
1,
))
}
} else {
match key_padding_mask {
Some(key_padding_mask) => {
let filler = Tensor::zeros(&[bs, sequence_length - key_padding_mask.size()[1]],
(key_padding_mask.kind(), key_padding_mask.device()));
let filler = Tensor::zeros(
&[bs, sequence_length - key_padding_mask.size()[1]],
(key_padding_mask.kind(), key_padding_mask.device()),
);
Some(Tensor::cat(&[filler, key_padding_mask.copy()], 1))
}
None => None
None => None,
}
}
}
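// `LayerState` is the incremental-decoding cache: `use_saved_state`
// concatenates freshly projected key/value tensors onto the cached ones (or,
// for encoder-decoder attention, reuses the cached encoder projections
// outright), so each generation step only projects the newest token instead
// of re-encoding the whole prefix.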

View File

@ -11,18 +11,18 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use crate::Config;
use tch::{Tensor, nn};
use tch::kind::Kind::{Int64, Float};
use crate::bart::encoder::BartEncoder;
use crate::bart::decoder::BartDecoder;
use tch::nn::{embedding, EmbeddingConfig};
use crate::bart::attention::LayerState;
use std::borrow::BorrowMut;
use crate::bart::decoder::BartDecoder;
use crate::bart::encoder::BartEncoder;
use crate::common::dropout::Dropout;
use crate::pipelines::generation::{Cache, LMHeadModel};
use crate::Config;
use serde::{Deserialize, Serialize};
use std::borrow::BorrowMut;
use std::collections::HashMap;
use tch::kind::Kind::{Float, Int64};
use tch::nn::{embedding, EmbeddingConfig};
use tch::{nn, Tensor};
/// # BART Pretrained model weight files
pub struct BartModelResources;
@ -38,38 +38,74 @@ pub struct BartMergesResources;
impl BartModelResources {
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
pub const BART: (&'static str, &'static str) = ("bart/model.ot", "https://cdn.huggingface.co/facebook/bart-large/rust_model.ot");
pub const BART: (&'static str, &'static str) = (
"bart/model.ot",
"https://cdn.huggingface.co/facebook/bart-large/rust_model.ot",
);
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
pub const BART_CNN: (&'static str, &'static str) = ("bart-cnn/model.ot", "https://cdn.huggingface.co/facebook/bart-large-cnn/rust_model.ot");
pub const BART_CNN: (&'static str, &'static str) = (
"bart-cnn/model.ot",
"https://cdn.huggingface.co/facebook/bart-large-cnn/rust_model.ot",
);
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
pub const BART_XSUM: (&'static str, &'static str) = ("bart-xsum/model.ot", "https://cdn.huggingface.co/facebook/bart-large-xsum/rust_model.ot");
pub const BART_XSUM: (&'static str, &'static str) = (
"bart-xsum/model.ot",
"https://cdn.huggingface.co/facebook/bart-large-xsum/rust_model.ot",
);
}
impl BartConfigResources {
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
pub const BART: (&'static str, &'static str) = ("bart/config.json", "https://cdn.huggingface.co/facebook/bart-large/config.json");
pub const BART: (&'static str, &'static str) = (
"bart/config.json",
"https://cdn.huggingface.co/facebook/bart-large/config.json",
);
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
pub const BART_CNN: (&'static str, &'static str) = ("bart-cnn/config.json", "https://cdn.huggingface.co/facebook/bart-large-cnn/config.json");
pub const BART_CNN: (&'static str, &'static str) = (
"bart-cnn/config.json",
"https://cdn.huggingface.co/facebook/bart-large-cnn/config.json",
);
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
pub const BART_XSUM: (&'static str, &'static str) = ("bart-xsum/config.json", "https://cdn.huggingface.co/facebook/bart-large-xsum/config.json");
pub const BART_XSUM: (&'static str, &'static str) = (
"bart-xsum/config.json",
"https://cdn.huggingface.co/facebook/bart-large-xsum/config.json",
);
}
impl BartVocabResources {
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
pub const BART: (&'static str, &'static str) = ("bart/vocab.txt", "https://cdn.huggingface.co/roberta-large-vocab.json");
pub const BART: (&'static str, &'static str) = (
"bart/vocab.txt",
"https://cdn.huggingface.co/roberta-large-vocab.json",
);
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
pub const BART_CNN: (&'static str, &'static str) = ("bart-cnn/vocab.txt", "https://cdn.huggingface.co/roberta-large-vocab.json");
pub const BART_CNN: (&'static str, &'static str) = (
"bart-cnn/vocab.txt",
"https://cdn.huggingface.co/roberta-large-vocab.json",
);
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
pub const BART_XSUM: (&'static str, &'static str) = ("bart-xsum/vocab.txt", "https://cdn.huggingface.co/roberta-large-vocab.json");
pub const BART_XSUM: (&'static str, &'static str) = (
"bart-xsum/vocab.txt",
"https://cdn.huggingface.co/roberta-large-vocab.json",
);
}
impl BartMergesResources {
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
pub const BART: (&'static str, &'static str) = ("bart/merges.txt", "https://cdn.huggingface.co/roberta-large-merges.txt");
pub const BART: (&'static str, &'static str) = (
"bart/merges.txt",
"https://cdn.huggingface.co/roberta-large-merges.txt",
);
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
pub const BART_CNN: (&'static str, &'static str) = ("bart-cnn/merges.txt", "https://cdn.huggingface.co/roberta-large-merges.txt");
pub const BART_CNN: (&'static str, &'static str) = (
"bart-cnn/merges.txt",
"https://cdn.huggingface.co/roberta-large-merges.txt",
);
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
pub const BART_XSUM: (&'static str, &'static str) = ("bart-xsum/merges.txt", "https://cdn.huggingface.co/roberta-large-merges.txt");
pub const BART_XSUM: (&'static str, &'static str) = (
"bart-xsum/merges.txt",
"https://cdn.huggingface.co/roberta-large-merges.txt",
);
}
#[allow(non_camel_case_types)]
@ -130,14 +166,15 @@ pub struct BartConfig {
impl Config<BartConfig> for BartConfig {}
fn _prepare_bart_decoder_inputs(pad_token_id: i64,
fn _prepare_bart_decoder_inputs(
pad_token_id: i64,
input_ids: &Tensor,
decoder_input_ids: Option<&Tensor>,
decoder_padding_mask: Option<&Tensor>)
-> (Tensor, Option<Tensor>, Option<Tensor>) {
decoder_padding_mask: Option<&Tensor>,
) -> (Tensor, Option<Tensor>, Option<Tensor>) {
let decoder_input_ids = match decoder_input_ids {
Some(value) => value.copy(),
None => _shift_tokens_right(input_ids, pad_token_id)
None => _shift_tokens_right(input_ids, pad_token_id),
};
let decoder_padding_mask = match decoder_padding_mask {
@ -159,7 +196,6 @@ fn _prepare_bart_decoder_inputs(pad_token_id: i64,
(decoder_input_ids, decoder_padding_mask, Some(causal_mask))
}
fn _shift_tokens_right(input_ids: &Tensor, pad_token_id: i64) -> Tensor {
let index_eos: Tensor = input_ids.ne(pad_token_id).sum1(&[-1], true, Int64) - 1;
let output = input_ids.empty_like().to_kind(Int64);
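// `_shift_tokens_right` builds teacher-forcing decoder inputs: the EOS token
// located by `index_eos` is moved to position 0 and the remaining tokens are
// shifted one step right, so the decoder predicts token t from tokens < t.
// Worked sketch: [5, 7, 9, EOS] becomes [EOS, 5, 7, 9].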
@ -200,10 +236,10 @@ impl BartModel {
/// # Example
///
/// ```no_run
/// use tch::{nn, Device};
/// use rust_bert::bart::{BartConfig, BartModel};
/// use rust_bert::Config;
/// use std::path::Path;
/// use rust_bert::bart::{BartConfig, BartModel};
/// use tch::{nn, Device};
///
/// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu;
@ -212,22 +248,32 @@ impl BartModel {
/// let generation_mode = true;
/// let bart: BartModel = BartModel::new(&(&p.root() / "bart"), &config, generation_mode);
/// ```
///
pub fn new(p: &nn::Path, config: &BartConfig, generation_mode: bool) -> BartModel {
let pad_token_id = match config.pad_token_id {
Some(value) => value,
None => 1
None => 1,
};
let embedding_config = EmbeddingConfig { padding_idx: pad_token_id, ..Default::default() };
let embeddings: nn::Embedding = embedding(p / "shared",
let embedding_config = EmbeddingConfig {
padding_idx: pad_token_id,
..Default::default()
};
let embeddings: nn::Embedding = embedding(
p / "shared",
config.vocab_size,
config.d_model,
embedding_config);
embedding_config,
);
let encoder = BartEncoder::new(p / "encoder", config);
let decoder = BartDecoder::new(p / "decoder", config, generation_mode);
BartModel { encoder, decoder, generation_mode, pad_token_id, embeddings }
BartModel {
encoder,
decoder,
generation_mode,
pad_token_id,
embeddings,
}
}
/// Forward pass through the model
@ -271,68 +317,102 @@ impl BartModel {
/// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56);
/// let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device));
/// let target_tensor = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device));
/// let encoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
/// let decoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
/// let encoder_attention_mask =
/// Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
/// let decoder_attention_mask =
/// Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
///
/// let (decoder_output, encoder_hidden_states, decoder_cache,
/// all_encoder_hidden_states, all_encoder_attentions,
/// all_decoder_hidden_states, all_decoder_attentions) = no_grad(|| {
/// bart_model
/// .forward_t(Some(&input_tensor),
/// let (
/// decoder_output,
/// encoder_hidden_states,
/// decoder_cache,
/// all_encoder_hidden_states,
/// all_encoder_attentions,
/// all_decoder_hidden_states,
/// all_decoder_attentions,
/// ) = no_grad(|| {
/// bart_model.forward_t(
/// Some(&input_tensor),
/// Some(&encoder_attention_mask),
/// Some(&target_tensor),
/// None,
/// Some(&decoder_attention_mask),
/// None,
/// false)
/// false,
/// )
/// });
///
/// ```
///
pub fn forward_t(&self,
pub fn forward_t(
&self,
input_ids: Option<&Tensor>,
attention_mask: Option<&Tensor>,
decoder_input_ids: Option<&Tensor>,
encoder_outputs: Option<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>)>,
decoder_attention_mask: Option<&Tensor>,
layer_states: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
train: bool) ->
(Tensor, Tensor, Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
Option<Vec<Tensor>>, Option<Vec<Tensor>>,
Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
train: bool,
) -> (
Tensor,
Tensor,
Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
) {
let (decoder_input_ids, decoder_padding_mask, causal_mask) = if self.generation_mode {
(decoder_input_ids.unwrap().copy(), None, None)
} else {
assert!(input_ids.is_some(), "input_ids must be provided when not in generation mode");
_prepare_bart_decoder_inputs(self.pad_token_id, input_ids.unwrap(), decoder_input_ids, decoder_attention_mask)
assert!(
input_ids.is_some(),
"input_ids must be provided when not in generation mode"
);
_prepare_bart_decoder_inputs(
self.pad_token_id,
input_ids.unwrap(),
decoder_input_ids,
decoder_attention_mask,
)
};
let (encoder_hidden_states,
all_encoder_hidden_states,
all_encoder_attentions) = match encoder_outputs {
let (encoder_hidden_states, all_encoder_hidden_states, all_encoder_attentions) =
match encoder_outputs {
Some(value) => value,
None => {
assert!(input_ids.is_some(), "input_ids must be provided when encoder output is not pre-computed");
self.encoder.forward_t(input_ids.unwrap(), attention_mask, &self.embeddings, train)
assert!(
input_ids.is_some(),
"input_ids must be provided when encoder output is not pre-computed"
);
self.encoder.forward_t(
input_ids.unwrap(),
attention_mask,
&self.embeddings,
train,
)
}
};
let (decoder_outputs,
decoder_cache,
all_decoder_hidden_states,
all_decoder_attentions) = self.decoder.forward_t(&decoder_input_ids,
let (decoder_outputs, decoder_cache, all_decoder_hidden_states, all_decoder_attentions) =
self.decoder.forward_t(
&decoder_input_ids,
&encoder_hidden_states,
attention_mask,
decoder_padding_mask.as_ref(),
causal_mask.as_ref(),
&self.embeddings,
layer_states,
train);
(decoder_outputs, encoder_hidden_states, decoder_cache.1,
all_decoder_hidden_states, all_decoder_attentions,
all_encoder_hidden_states, all_encoder_attentions)
train,
);
(
decoder_outputs,
encoder_hidden_states,
decoder_cache.1,
all_decoder_hidden_states,
all_decoder_attentions,
all_encoder_hidden_states,
all_encoder_attentions,
)
}
}
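When not in generation mode, `_prepare_bart_decoder_inputs` supplies the decoder with a causal mask. A minimal sketch of the lower-triangular pattern it enforces (not the library's exact construction): position i may only attend to positions j <= i.

use tch::{Device, Kind, Tensor};

let seq_len = 3i64;
// Lower-triangular matrix of ones: row i has ones in columns 0..=i.
let causal = Tensor::ones(&[seq_len, seq_len], (Kind::Float, Device::Cpu)).tril(0);
// causal == [[1, 0, 0], [1, 1, 0], [1, 1, 1]]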
/// # BART Model for conditional generation
@ -356,20 +436,24 @@ impl BartForConditionalGeneration {
/// # Example
///
/// ```no_run
/// use tch::{nn, Device};
/// use rust_bert::bart::{BartConfig, BartForConditionalGeneration};
/// use rust_bert::Config;
/// use std::path::Path;
/// use rust_bert::bart::{BartConfig, BartForConditionalGeneration};
/// use tch::{nn, Device};
///
/// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu;
/// let p = nn::VarStore::new(device);
/// let config = BartConfig::from_file(config_path);
/// let generation_mode = true;
/// let bart: BartForConditionalGeneration = BartForConditionalGeneration::new(&(&p.root() / "bart"), &config, generation_mode);
/// let bart: BartForConditionalGeneration =
/// BartForConditionalGeneration::new(&(&p.root() / "bart"), &config, generation_mode);
/// ```
///
pub fn new(p: &nn::Path, config: &BartConfig, generation_mode: bool) -> BartForConditionalGeneration {
pub fn new(
p: &nn::Path,
config: &BartConfig,
generation_mode: bool,
) -> BartForConditionalGeneration {
let base_model = BartModel::new(&(p / "model"), config, generation_mode);
BartForConditionalGeneration { base_model }
}
@ -427,37 +511,64 @@ impl BartForConditionalGeneration {
/// None,
/// false)
/// });
///
/// ```
///
pub fn forward_t(&self,
pub fn forward_t(
&self,
input_ids: Option<&Tensor>,
attention_mask: Option<&Tensor>,
encoder_outputs: Option<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>)>,
decoder_input_ids: Option<&Tensor>,
decoder_attention_mask: Option<&Tensor>,
old_layer_states: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
train: bool)
-> (Tensor, Tensor, Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
Option<Vec<Tensor>>, Option<Vec<Tensor>>,
Option<Vec<Tensor>>, Option<Vec<Tensor>>)
{
let (decoder_outputs, encoder_hidden_states, decoder_cache,
all_decoder_hidden_states, all_decoder_attentions,
all_encoder_hidden_states, all_encoder_attentions) =
self.base_model.forward_t(input_ids, attention_mask, decoder_input_ids, encoder_outputs, decoder_attention_mask, old_layer_states, train);
train: bool,
) -> (
Tensor,
Tensor,
Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
) {
let (
decoder_outputs,
encoder_hidden_states,
decoder_cache,
all_decoder_hidden_states,
all_decoder_attentions,
all_encoder_hidden_states,
all_encoder_attentions,
) = self.base_model.forward_t(
input_ids,
attention_mask,
decoder_input_ids,
encoder_outputs,
decoder_attention_mask,
old_layer_states,
train,
);
let lm_logits = decoder_outputs.linear::<Tensor>(&self.base_model.embeddings.ws, None);
(lm_logits, encoder_hidden_states, decoder_cache,
all_decoder_hidden_states, all_decoder_attentions,
all_encoder_hidden_states, all_encoder_attentions)
(
lm_logits,
encoder_hidden_states,
decoder_cache,
all_decoder_hidden_states,
all_decoder_attentions,
all_encoder_hidden_states,
all_encoder_attentions,
)
}
pub fn encode(&self, input_ids: &Tensor, attention_mask: Option<&Tensor>) -> Tensor {
let (encoder_hidden_states, _, _) = self.base_model.encoder.forward_t(input_ids, attention_mask, &self.base_model.embeddings, false);
let (encoder_hidden_states, _, _) = self.base_model.encoder.forward_t(
input_ids,
attention_mask,
&self.base_model.embeddings,
false,
);
encoder_hidden_states
}
}
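The `lm_logits` projection above ties the output layer to the shared embeddings: `linear` is applied with the embedding weight matrix `ws` and no bias, so the logits are a product with the transposed embedding table and no separate output weights are stored. A toy shape check (sizes are illustrative):

use tch::{Device, Kind, Tensor};

let (vocab_size, d_model) = (6i64, 4i64);
// Stand-in for `embeddings.ws`, shape (vocab_size, d_model).
let ws = Tensor::randn(&[vocab_size, d_model], (Kind::Float, Device::Cpu));
let decoder_output = Tensor::randn(&[1, 2, d_model], (Kind::Float, Device::Cpu));
let lm_logits = decoder_output.linear::<Tensor>(&ws, None);
assert_eq!(lm_logits.size(), &[1, 2, vocab_size]);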
pub struct BartClassificationHead {
@ -468,16 +579,29 @@ pub struct BartClassificationHead {
impl BartClassificationHead {
pub fn new(p: &nn::Path, config: &BartConfig) -> BartClassificationHead {
let dense = nn::linear(&(p / "dense"), config.d_model, config.d_model, Default::default());
let dense = nn::linear(
&(p / "dense"),
config.d_model,
config.d_model,
Default::default(),
);
let dropout = Dropout::new(config.classif_dropout);
let out_proj = nn::linear(&(p / "out_proj"), config.d_model, config.num_labels.unwrap(), Default::default());
let out_proj = nn::linear(
&(p / "out_proj"),
config.d_model,
config.num_labels.unwrap(),
Default::default(),
);
BartClassificationHead { dense, dropout, out_proj }
BartClassificationHead {
dense,
dropout,
out_proj,
}
}
pub fn forward_t(&self, x: &Tensor, train: bool) -> Tensor {
x
.apply_t(&self.dropout, train)
x.apply_t(&self.dropout, train)
.apply(&self.dense)
.tanh()
.apply_t(&self.dropout, train)
@ -497,7 +621,6 @@ pub struct BartForSequenceClassification {
eos_token_id: i64,
}
impl BartForSequenceClassification {
/// Build a new `BartForSequenceClassification`
///
@ -509,27 +632,31 @@ impl BartForSequenceClassification {
/// # Example
///
/// ```no_run
/// use tch::{nn, Device};
/// use rust_bert::bart::{BartConfig, BartForSequenceClassification};
/// use rust_bert::Config;
/// use std::path::Path;
/// use rust_bert::bart::{BartConfig, BartForSequenceClassification};
/// use tch::{nn, Device};
///
/// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu;
/// let p = nn::VarStore::new(device);
/// let config = BartConfig::from_file(config_path);
/// let generation_mode = true;
/// let bart: BartForSequenceClassification = BartForSequenceClassification::new(&(&p.root() / "bart"), &config);
/// let bart: BartForSequenceClassification =
/// BartForSequenceClassification::new(&(&p.root() / "bart"), &config);
/// ```
///
pub fn new(p: &nn::Path, config: &BartConfig) -> BartForSequenceClassification {
let base_model = BartModel::new(&(p / "model"), config, false);
let classification_head = BartClassificationHead::new(&(p / "classification_head"), config);
let eos_token_id = match config.eos_token_id {
Some(value) => value,
None => 3
None => 3,
};
BartForSequenceClassification { base_model, classification_head, eos_token_id }
BartForSequenceClassification {
base_model,
classification_head,
eos_token_id,
}
}
/// Forward pass through the model
@ -585,36 +712,63 @@ impl BartForSequenceClassification {
/// None,
/// false)
/// });
///
/// ```
///
pub fn forward_t(&mut self,
pub fn forward_t(
&mut self,
input_ids: &Tensor,
attention_mask: Option<&Tensor>,
encoder_outputs: Option<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>)>,
decoder_input_ids: Option<&Tensor>,
decoder_attention_mask: Option<&Tensor>,
train: bool)
-> (Tensor, Tensor,
Option<Vec<Tensor>>, Option<Vec<Tensor>>,
Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
let (decoder_outputs, encoder_hidden_states, _,
all_decoder_hidden_states, all_decoder_attentions,
all_encoder_hidden_states, all_encoder_attentions) =
self.borrow_mut().base_model.forward_t(Some(input_ids), attention_mask, decoder_input_ids, encoder_outputs, decoder_attention_mask, None, train);
train: bool,
) -> (
Tensor,
Tensor,
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
) {
let (
decoder_outputs,
encoder_hidden_states,
_,
all_decoder_hidden_states,
all_decoder_attentions,
all_encoder_hidden_states,
all_encoder_attentions,
) = self.borrow_mut().base_model.forward_t(
Some(input_ids),
attention_mask,
decoder_input_ids,
encoder_outputs,
decoder_attention_mask,
None,
train,
);
let eos_mask = input_ids.eq(self.eos_token_id);
let sentence_representation = decoder_outputs
.index_select(0, &eos_mask)
.view((decoder_outputs.size()[0], -1, *decoder_outputs.size().last().unwrap()))
.view((
decoder_outputs.size()[0],
-1,
*decoder_outputs.size().last().unwrap(),
))
.select(1, -1);
let logits = self.classification_head.forward_t(&sentence_representation, train);
(logits, encoder_hidden_states,
all_decoder_hidden_states, all_decoder_attentions,
all_encoder_hidden_states, all_encoder_attentions)
let logits = self
.classification_head
.forward_t(&sentence_representation, train);
(
logits,
encoder_hidden_states,
all_decoder_hidden_states,
all_decoder_attentions,
all_encoder_hidden_states,
all_encoder_attentions,
)
}
}
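A sketch of the EOS-pooling idea used above, reduced to a single unbatched sequence (toy values; the library version works on a batch and keeps the hidden state at the last EOS position per row):

use tch::{Device, Kind, Tensor};

let hidden = Tensor::randn(&[5, 8], (Kind::Float, Device::Cpu)); // (seq_len, d_model)
let input_ids = Tensor::of_slice(&[0i64, 4, 7, 2, 2]); // 2 = EOS
// Indices of all EOS tokens, then pick the last one.
let eos_positions = input_ids.eq(2).nonzero().squeeze1(-1);
let last_eos = eos_positions.int64_value(&[eos_positions.size()[0] - 1]);
let sentence_representation = hidden.select(0, last_eos);
assert_eq!(sentence_representation.size(), &[8]);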
impl LMHeadModel for BartForConditionalGeneration {
@ -675,10 +829,9 @@ impl LMHeadModel for BartForConditionalGeneration {
/// None,
/// false)
/// });
///
/// ```
///
fn forward_t(&self,
fn forward_t(
&self,
input_ids: &Option<Tensor>,
cache: Cache,
attention_mask: &Option<Tensor>,
@ -687,27 +840,47 @@ impl LMHeadModel for BartForConditionalGeneration {
_input_embeds: &Option<Tensor>,
encoder_outputs: Option<&Tensor>,
decoder_input_ids: &Option<Tensor>,
train: bool) -> Result<(Tensor, Option<Tensor>, Cache, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
train: bool,
) -> Result<
(
Tensor,
Option<Tensor>,
Cache,
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
),
&'static str,
> {
let (decoder_output, encoder_hidden_states, new_cache, _, _, _, _) = match cache {
Cache::BARTCache(cached_layer_states) => self.base_model.forward_t(input_ids.as_ref(),
Cache::BARTCache(cached_layer_states) => self.base_model.forward_t(
input_ids.as_ref(),
attention_mask.as_ref(),
decoder_input_ids.as_ref(),
Some((encoder_outputs.as_ref().unwrap().copy(), None, None)),
None,
cached_layer_states,
train),
train,
),
Cache::None => self.base_model.forward_t(input_ids.as_ref(),
Cache::None => self.base_model.forward_t(
input_ids.as_ref(),
attention_mask.as_ref(),
decoder_input_ids.as_ref(),
Some((encoder_outputs.as_ref().unwrap().copy(), None, None)),
None,
None,
train),
_ => Err("Cache not compatible with BART Model")?
train,
),
_ => Err("Cache not compatible with BART Model")?,
};
let lm_logits = decoder_output.linear::<Tensor>(&self.base_model.embeddings.ws, None);
Ok((lm_logits, Some(encoder_hidden_states), Cache::BARTCache(new_cache), None, None))
Ok((
lm_logits,
Some(encoder_hidden_states),
Cache::BARTCache(new_cache),
None,
None,
))
}
}


View File

@ -11,15 +11,17 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use crate::bart::attention::{SelfAttention, LayerState};
use tch::{nn, Tensor};
use crate::common::dropout::Dropout;
use crate::bart::BartConfig;
use crate::bart::attention::{LayerState, SelfAttention};
use crate::bart::bart::Activation;
use crate::common::activations::{_gelu, _relu, _swish, _gelu_new, _tanh};
use tch::kind::Kind::Int64;
use crate::bart::embeddings::{
EmbeddingOption, LearnedPositionalEmbedding, SinusoidalPositionalEmbedding,
};
use crate::bart::BartConfig;
use crate::common::activations::{_gelu, _gelu_new, _relu, _swish, _tanh};
use crate::common::dropout::Dropout;
use std::borrow::BorrowMut;
use crate::bart::embeddings::{EmbeddingOption, LearnedPositionalEmbedding, SinusoidalPositionalEmbedding};
use tch::kind::Kind::Int64;
use tch::{nn, Tensor};
pub struct DecoderLayer {
self_attention: SelfAttention,
@ -36,51 +38,74 @@ pub struct DecoderLayer {
impl DecoderLayer {
pub fn new(p: nn::Path, config: &BartConfig) -> DecoderLayer {
let layer_norm_config = nn::LayerNormConfig { eps: 1e-5, ..Default::default() };
let layer_norm_config = nn::LayerNormConfig {
eps: 1e-5,
..Default::default()
};
let output_attention = match config.output_attentions {
Some(value) => value,
None => false
None => false,
};
let self_attention = SelfAttention::new(&p / "self_attn",
let self_attention = SelfAttention::new(
&p / "self_attn",
config.d_model,
config.decoder_attention_heads,
config.attention_dropout,
false,
true,
output_attention);
let encoder_attention = SelfAttention::new(&p / "encoder_attn",
output_attention,
);
let encoder_attention = SelfAttention::new(
&p / "encoder_attn",
config.d_model,
config.decoder_attention_heads,
config.attention_dropout,
true,
true,
output_attention);
let self_attention_layer_norm = nn::layer_norm(&p / "self_attn_layer_norm",
output_attention,
);
let self_attention_layer_norm = nn::layer_norm(
&p / "self_attn_layer_norm",
vec![config.d_model],
layer_norm_config);
let encoder_attention_layer_norm = nn::layer_norm(&p / "encoder_attn_layer_norm",
layer_norm_config,
);
let encoder_attention_layer_norm = nn::layer_norm(
&p / "encoder_attn_layer_norm",
vec![config.d_model],
layer_norm_config);
layer_norm_config,
);
let dropout = Dropout::new(config.dropout);
let activation_dropout = Dropout::new(config.activation_dropout);
let activation_function = match &config.activation_function {
Some(act_function) => act_function,
None => &Activation::gelu
None => &Activation::gelu,
};
let activation = Box::new(match activation_function {
Activation::gelu => _gelu,
Activation::relu => _relu,
Activation::swish => _swish,
Activation::gelu_new => _gelu_new,
Activation::tanh => _tanh
Activation::tanh => _tanh,
});
let fc1 = nn::linear(&p / "fc1", config.d_model, config.decoder_ffn_dim, Default::default());
let fc2 = nn::linear(&p / "fc2", config.decoder_ffn_dim, config.d_model, Default::default());
let fc1 = nn::linear(
&p / "fc1",
config.d_model,
config.decoder_ffn_dim,
Default::default(),
);
let fc2 = nn::linear(
&p / "fc2",
config.decoder_ffn_dim,
config.d_model,
Default::default(),
);
let final_layer_norm = nn::layer_norm(&p / "final_layer_norm",
let final_layer_norm = nn::layer_norm(
&p / "final_layer_norm",
vec![config.d_model],
layer_norm_config);
layer_norm_config,
);
DecoderLayer {
self_attention,
@ -96,18 +121,38 @@ impl DecoderLayer {
}
}
pub fn forward_t(&self,
pub fn forward_t(
&self,
x: &Tensor,
encoder_hidden_states: &Tensor,
encoder_attn_mask: Option<&Tensor>,
causal_mask: Option<&Tensor>,
decoder_padding_mask: Option<&Tensor>,
layer_states: (Option<LayerState>, Option<LayerState>),
train: bool) -> (Tensor, Option<Tensor>, (Option<LayerState>, Option<LayerState>)) {
let (output, attention_weights, new_self_layer_states) = self.self_attention.forward_t(x, Some(x), decoder_padding_mask, causal_mask, layer_states.0, train);
train: bool,
) -> (
Tensor,
Option<Tensor>,
(Option<LayerState>, Option<LayerState>),
) {
let (output, attention_weights, new_self_layer_states) = self.self_attention.forward_t(
x,
Some(x),
decoder_padding_mask,
causal_mask,
layer_states.0,
train,
);
let output: Tensor = output.apply_t(&self.dropout, train) + x;
let output = output.apply(&self.self_attention_layer_norm);
let (output1, _, new_encoder_layer_states) = self.encoder_attention.forward_t(&output, Some(encoder_hidden_states), encoder_attn_mask, None, layer_states.1, train);
let (output1, _, new_encoder_layer_states) = self.encoder_attention.forward_t(
&output,
Some(encoder_hidden_states),
encoder_attn_mask,
None,
layer_states.1,
train,
);
let output1: Tensor = output1.apply_t(&self.dropout, train) + output;
let output1 = output1.apply(&self.encoder_attention_layer_norm);
let output2 = (self.activation)(&output1.apply(&self.fc1));
@ -116,7 +161,11 @@ impl DecoderLayer {
.apply(&self.fc2)
.apply_t(&self.dropout, train);
let output2: Tensor = output2 + output1;
(output2.apply(&self.final_layer_norm), attention_weights, (new_self_layer_states, new_encoder_layer_states))
(
output2.apply(&self.final_layer_norm),
attention_weights,
(new_self_layer_states, new_encoder_layer_states),
)
}
}
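Both sub-layers above follow the same post-norm residual pattern: dropout on the sub-layer output, addition of the input, then layer normalization. A stand-in sketch (the attention module, dropout, and norm are replaced by placeholders):

use tch::{Device, Kind, Tensor};

let x = Tensor::randn(&[4, 16], (Kind::Float, Device::Cpu));
let sublayer_output = &x * 0.5; // placeholder for self- or encoder-attention
let residual = sublayer_output + &x; // dropout and layer norm omitted for brevity
assert_eq!(residual.size(), x.size());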
@ -136,61 +185,76 @@ impl BartDecoder {
pub fn new(p: nn::Path, config: &BartConfig, generation_mode: bool) -> BartDecoder {
let output_past = match config.output_past {
Some(value) => value,
None => true
None => true,
};
let output_attentions = match config.output_attentions {
Some(value) => value,
None => false
None => false,
};
let output_hidden_states = match config.output_hidden_states {
Some(value) => value,
None => false
None => false,
};
let normalize_embedding = match config.normalize_embedding {
Some(value) => value,
None => true
None => true,
};
let static_position_embeddings = match config.static_position_embeddings {
Some(value) => value,
None => false
None => false,
};
let scale_embedding = match config.scale_embedding {
Some(value) => if value { (config.d_model as f64).sqrt() } else { 1.0 },
None => 1.0
Some(value) => {
if value {
(config.d_model as f64).sqrt()
} else {
1.0
}
}
None => 1.0,
};
let dropout = Dropout::new(config.dropout);
let layer_norm_embedding = if normalize_embedding {
let layer_norm_config = nn::LayerNormConfig { eps: 1e-5, ..Default::default() };
Some(nn::layer_norm(&p / "layernorm_embedding",
let layer_norm_config = nn::LayerNormConfig {
eps: 1e-5,
..Default::default()
};
Some(nn::layer_norm(
&p / "layernorm_embedding",
vec![config.d_model],
layer_norm_config))
layer_norm_config,
))
} else {
None
};
let pad_token_id = match config.pad_token_id {
Some(value) => value,
None => 1
None => 1,
};
let embed_positions = if static_position_embeddings {
EmbeddingOption::SinusoidalPositionalEmbedding(SinusoidalPositionalEmbedding::new(&p / "embed_positions",
config.max_position_embeddings,
config.d_model))
} else {
EmbeddingOption::LearnedPositionalEmbedding(LearnedPositionalEmbedding::new(&p / "embed_positions",
EmbeddingOption::SinusoidalPositionalEmbedding(SinusoidalPositionalEmbedding::new(
&p / "embed_positions",
config.max_position_embeddings,
config.d_model,
pad_token_id))
))
} else {
EmbeddingOption::LearnedPositionalEmbedding(LearnedPositionalEmbedding::new(
&p / "embed_positions",
config.max_position_embeddings,
config.d_model,
pad_token_id,
))
};
let mut layers: Vec<DecoderLayer> = vec!();
let mut layers: Vec<DecoderLayer> = vec![];
let p_layers = &p / "layers";
for layer_index in 0..config.decoder_layers {
layers.push(DecoderLayer::new(&p_layers / layer_index, config));
};
}
BartDecoder {
dropout,
@ -205,7 +269,8 @@ impl BartDecoder {
}
}
pub fn forward_t(&self,
pub fn forward_t(
&self,
input_ids: &Tensor,
encoder_hidden_states: &Tensor,
encoder_padding_mask: Option<&Tensor>,
@ -213,33 +278,56 @@ impl BartDecoder {
decoder_causal_mask: Option<&Tensor>,
embeddings: &nn::Embedding,
old_layer_states: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
train: bool)
-> (Tensor,
(Option<Tensor>, Option<Vec<(Option<LayerState>, Option<LayerState>)>>),
train: bool,
) -> (
Tensor,
(
Option<Tensor>,
Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
),
Option<Vec<Tensor>>,
Option<Vec<Tensor>>) {
Option<Vec<Tensor>>,
) {
let encoder_padding_mask = match encoder_padding_mask {
Some(mask) => Some(mask.eq(0).to_kind(Int64)),
None => None
None => None,
};
let positions = self.embed_positions.forward(input_ids, self.generation_mode);
let positions = self
.embed_positions
.forward(input_ids, self.generation_mode);
let x: Tensor = if self.generation_mode {
let end_inputs = input_ids.size()[1];
let end_positions = positions.size()[1];
input_ids.narrow(1, end_inputs - 1, 1).apply(embeddings) * self.scale_embedding + positions.narrow(1, end_positions - 1, 1)
input_ids.narrow(1, end_inputs - 1, 1).apply(embeddings) * self.scale_embedding
+ positions.narrow(1, end_positions - 1, 1)
} else {
input_ids.apply(embeddings) * self.scale_embedding + positions
};
let x = if let Some(layer_norm_embedding) = &self.layer_norm_embedding { x.apply(layer_norm_embedding) } else { x };
let mut hidden_state = x
.apply_t(&self.dropout, train)
.transpose(0, 1);
let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states { Some(Vec::with_capacity(self.layers.len())) } else { None };
let mut all_attentions: Option<Vec<Tensor>> = if self.output_attentions { Some(Vec::with_capacity(self.layers.len())) } else { None };
let mut next_decoder_cache: Option<Vec<(Option<LayerState>, Option<LayerState>)>> = if self.output_past {
if old_layer_states.is_some() { old_layer_states } else { Some(vec!((None, None); self.layers.len())) }
let x = if let Some(layer_norm_embedding) = &self.layer_norm_embedding {
x.apply(layer_norm_embedding)
} else {
x
};
let mut hidden_state = x.apply_t(&self.dropout, train).transpose(0, 1);
let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states {
Some(Vec::with_capacity(self.layers.len()))
} else {
None
};
let mut all_attentions: Option<Vec<Tensor>> = if self.output_attentions {
Some(Vec::with_capacity(self.layers.len()))
} else {
None
};
let mut next_decoder_cache: Option<Vec<(Option<LayerState>, Option<LayerState>)>> =
if self.output_past {
if old_layer_states.is_some() {
old_layer_states
} else {
Some(vec![(None, None); self.layers.len()])
}
} else {
None
};
@ -252,15 +340,17 @@ impl BartDecoder {
Some((layer_idx, layer)) => {
let layer_state = match &next_decoder_cache {
Some(values) => values[layer_idx].to_owned(),
None => (None, None)
None => (None, None),
};
let temp = layer.forward_t(&hidden_state,
let temp = layer.forward_t(
&hidden_state,
&encoder_hidden_states,
encoder_padding_mask.as_ref(),
decoder_causal_mask,
decoder_padding_mask,
layer_state,
train);
train,
);
hidden_state = temp.0;
attention_weights = temp.1;
if let Some(hidden_states) = all_hidden_states.borrow_mut() {
@ -269,15 +359,19 @@ impl BartDecoder {
if let Some(attentions) = all_attentions.borrow_mut() {
attentions.push(attention_weights.as_ref().unwrap().copy());
};
if let Some(value) = &mut next_decoder_cache { value[layer_idx] = temp.2 };
if let Some(value) = &mut next_decoder_cache {
value[layer_idx] = temp.2
};
}
None => break
};
None => break,
};
}
(hidden_state.transpose(0, 1),
(
hidden_state.transpose(0, 1),
(encoder_padding_mask, next_decoder_cache),
all_hidden_states,
all_attentions)
all_attentions,
)
}
}
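In generation mode only the newest token is embedded, since the layer cache already holds keys and values for earlier positions; `narrow` extracts that final position. A toy check:

use tch::Tensor;

let input_ids = Tensor::of_slice(&[5i64, 8, 3]).unsqueeze(0); // shape (1, 3)
// Keep only the last column, i.e. the most recently generated token.
let last_token = input_ids.narrow(1, input_ids.size()[1] - 1, 1);
assert_eq!(Vec::<i64>::from(&last_token), vec![3]);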

View File

@ -11,10 +11,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use tch::{nn, Tensor};
use tch::nn::{EmbeddingConfig, embedding};
use tch::kind::Kind::Int64;
use tch::nn::{embedding, EmbeddingConfig};
use tch::{nn, Tensor};
/// # Abstraction that holds an embedding configuration
pub enum EmbeddingOption {
@ -27,8 +26,12 @@ impl EmbeddingOption {
/// Interface method dispatching to the forward() method of the wrapped embedding variant.
pub fn forward(&self, input: &Tensor, generation_mode: bool) -> Tensor {
match *self {
Self::LearnedPositionalEmbedding(ref embeddings) => embeddings.forward(input, generation_mode),
Self::SinusoidalPositionalEmbedding(ref embeddings) => embeddings.forward(input, generation_mode)
Self::LearnedPositionalEmbedding(ref embeddings) => {
embeddings.forward(input, generation_mode)
}
Self::SinusoidalPositionalEmbedding(ref embeddings) => {
embeddings.forward(input, generation_mode)
}
}
}
}
@ -40,15 +43,24 @@ pub struct LearnedPositionalEmbedding {
}
impl LearnedPositionalEmbedding {
pub fn new(p: nn::Path, num_embeddings: i64, embedding_dim: i64, padding_index: i64) -> LearnedPositionalEmbedding {
let embedding_config = EmbeddingConfig { padding_idx: padding_index, ..Default::default() };
pub fn new(
p: nn::Path,
num_embeddings: i64,
embedding_dim: i64,
padding_index: i64,
) -> LearnedPositionalEmbedding {
let embedding_config = EmbeddingConfig {
padding_idx: padding_index,
..Default::default()
};
let num_embeddings = num_embeddings + padding_index + 1;
let embedding: nn::Embedding = embedding(p,
num_embeddings,
embedding_dim,
embedding_config);
LearnedPositionalEmbedding { embedding, padding_index }
let embedding: nn::Embedding =
embedding(p, num_embeddings, embedding_dim, embedding_config);
LearnedPositionalEmbedding {
embedding,
padding_index,
}
}
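`new` enlarges the embedding table by `padding_index + 1` slots because, under the assumed fairseq-style convention, real positions start at `padding_index + 1` and the padding index itself stays reserved for pad tokens. A plain-Rust sketch of the offset:

let padding_index = 1i64;
let seq_len = 4i64;
// Position ids for a sequence of length 4 under the offset convention.
let positions: Vec<i64> = (0..seq_len).map(|i| i + padding_index + 1).collect();
assert_eq!(positions, vec![2, 3, 4, 5]);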
pub fn forward(&self, input: &Tensor, generation_mode: bool) -> Tensor {
@ -74,11 +86,13 @@ pub struct SinusoidalPositionalEmbedding {
}
impl SinusoidalPositionalEmbedding {
pub fn new(p: nn::Path, num_embeddings: i64, embedding_dim: i64) -> SinusoidalPositionalEmbedding {
let embedding: nn::Embedding = embedding(p,
num_embeddings,
embedding_dim,
Default::default());
pub fn new(
p: nn::Path,
num_embeddings: i64,
embedding_dim: i64,
) -> SinusoidalPositionalEmbedding {
let embedding: nn::Embedding =
embedding(p, num_embeddings, embedding_dim, Default::default());
SinusoidalPositionalEmbedding { embedding }
}

View File

@ -12,14 +12,16 @@
// limitations under the License.
use crate::bart::attention::SelfAttention;
use tch::{nn, Tensor};
use crate::common::dropout::Dropout;
use crate::bart::BartConfig;
use crate::bart::bart::Activation;
use crate::common::activations::{_gelu, _relu, _swish, _gelu_new, _tanh};
use crate::bart::embeddings::{EmbeddingOption, LearnedPositionalEmbedding, SinusoidalPositionalEmbedding};
use tch::kind::Kind::Bool;
use crate::bart::embeddings::{
EmbeddingOption, LearnedPositionalEmbedding, SinusoidalPositionalEmbedding,
};
use crate::bart::BartConfig;
use crate::common::activations::{_gelu, _gelu_new, _relu, _swish, _tanh};
use crate::common::dropout::Dropout;
use std::borrow::BorrowMut;
use tch::kind::Kind::Bool;
use tch::{nn, Tensor};
pub struct EncoderLayer {
self_attention: SelfAttention,
@ -34,46 +36,81 @@ pub struct EncoderLayer {
impl EncoderLayer {
pub fn new(p: nn::Path, config: &BartConfig) -> EncoderLayer {
let layer_norm_config = nn::LayerNormConfig { eps: 1e-5, ..Default::default() };
let layer_norm_config = nn::LayerNormConfig {
eps: 1e-5,
..Default::default()
};
let output_attention = match config.output_attentions {
Some(value) => value,
None => false
None => false,
};
let self_attention = SelfAttention::new(&p / "self_attn",
let self_attention = SelfAttention::new(
&p / "self_attn",
config.d_model,
config.encoder_attention_heads,
config.attention_dropout,
false,
false,
output_attention);
let self_attention_layer_norm = nn::layer_norm(&p / "self_attn_layer_norm",
output_attention,
);
let self_attention_layer_norm = nn::layer_norm(
&p / "self_attn_layer_norm",
vec![config.d_model],
layer_norm_config);
layer_norm_config,
);
let dropout = Dropout::new(config.dropout);
let activation_dropout = Dropout::new(config.activation_dropout);
let activation_function = match &config.activation_function {
Some(act_function) => act_function,
None => &Activation::gelu
None => &Activation::gelu,
};
let activation = Box::new(match activation_function {
Activation::gelu => _gelu,
Activation::relu => _relu,
Activation::swish => _swish,
Activation::gelu_new => _gelu_new,
Activation::tanh => _tanh
Activation::tanh => _tanh,
});
let fc1 = nn::linear(&p / "fc1", config.d_model, config.encoder_ffn_dim, Default::default());
let fc2 = nn::linear(&p / "fc2", config.encoder_ffn_dim, config.d_model, Default::default());
let fc1 = nn::linear(
&p / "fc1",
config.d_model,
config.encoder_ffn_dim,
Default::default(),
);
let fc2 = nn::linear(
&p / "fc2",
config.encoder_ffn_dim,
config.d_model,
Default::default(),
);
let final_layer_norm = nn::layer_norm(&p / "final_layer_norm",
let final_layer_norm = nn::layer_norm(
&p / "final_layer_norm",
vec![config.d_model],
layer_norm_config);
layer_norm_config,
);
EncoderLayer { self_attention, self_attention_layer_norm, dropout, activation_dropout, activation, fc1, fc2, final_layer_norm }
EncoderLayer {
self_attention,
self_attention_layer_norm,
dropout,
activation_dropout,
activation,
fc1,
fc2,
final_layer_norm,
}
}
pub fn forward_t(&self, x: &Tensor, encoder_padding_mask: Option<&Tensor>, train: bool) -> (Tensor, Option<Tensor>) {
let (output, attention_weights, _) = self.self_attention.forward_t(x, None, encoder_padding_mask, None, None, train);
pub fn forward_t(
&self,
x: &Tensor,
encoder_padding_mask: Option<&Tensor>,
train: bool,
) -> (Tensor, Option<Tensor>) {
let (output, attention_weights, _) =
self.self_attention
.forward_t(x, None, encoder_padding_mask, None, None, train);
let output: Tensor = output.apply_t(&self.dropout, train) + x;
let output = output.apply(&self.self_attention_layer_norm);
@ -102,57 +139,72 @@ impl BartEncoder {
pub fn new(p: nn::Path, config: &BartConfig) -> BartEncoder {
let output_attentions = match config.output_attentions {
Some(value) => value,
None => false
None => false,
};
let output_hidden_states = match config.output_hidden_states {
Some(value) => value,
None => false
None => false,
};
let normalize_embedding = match config.normalize_embedding {
Some(value) => value,
None => true
None => true,
};
let static_position_embeddings = match config.static_position_embeddings {
Some(value) => value,
None => false
None => false,
};
let scale_embedding = match config.scale_embedding {
Some(value) => if value { (config.d_model as f64).sqrt() } else { 1.0 },
None => 1.0
Some(value) => {
if value {
(config.d_model as f64).sqrt()
} else {
1.0
}
}
None => 1.0,
};
let dropout = Dropout::new(config.dropout);
let layer_norm_embedding = if normalize_embedding {
let layer_norm_config = nn::LayerNormConfig { eps: 1e-5, ..Default::default() };
Some(nn::layer_norm(&p / "layernorm_embedding",
let layer_norm_config = nn::LayerNormConfig {
eps: 1e-5,
..Default::default()
};
Some(nn::layer_norm(
&p / "layernorm_embedding",
vec![config.d_model],
layer_norm_config))
layer_norm_config,
))
} else {
None
};
let pad_token_id = match config.pad_token_id {
Some(value) => value,
None => 1
None => 1,
};
let embed_positions = if static_position_embeddings {
EmbeddingOption::SinusoidalPositionalEmbedding(SinusoidalPositionalEmbedding::new(&p / "embed_positions",
config.max_position_embeddings,
config.d_model))
} else {
EmbeddingOption::LearnedPositionalEmbedding(LearnedPositionalEmbedding::new(&p / "embed_positions",
EmbeddingOption::SinusoidalPositionalEmbedding(SinusoidalPositionalEmbedding::new(
&p / "embed_positions",
config.max_position_embeddings,
config.d_model,
pad_token_id))
))
} else {
EmbeddingOption::LearnedPositionalEmbedding(LearnedPositionalEmbedding::new(
&p / "embed_positions",
config.max_position_embeddings,
config.d_model,
pad_token_id,
))
};
let mut layers: Vec<EncoderLayer> = vec!();
let mut layers: Vec<EncoderLayer> = vec![];
let p_layers = &p / "layers";
for layer_index in 0..config.encoder_layers {
layers.push(EncoderLayer::new(&p_layers / layer_index, config));
};
}
BartEncoder {
dropout,
@ -165,26 +217,37 @@ impl BartEncoder {
}
}
pub fn forward_t(&self,
pub fn forward_t(
&self,
input_ids: &Tensor,
attention_mask: Option<&Tensor>,
embeddings: &nn::Embedding,
train: bool)
-> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
train: bool,
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
let attention_mask = match attention_mask {
Some(mask) => Some(mask.eq(0).to_kind(Bool)),
None => None
None => None,
};
let x = input_ids.apply(embeddings) * self.scale_embedding;
let x: Tensor = x + &self.embed_positions.forward(input_ids, false);
let x = if let Some(layer_norm_embedding) = &self.layer_norm_embedding { x.apply(layer_norm_embedding) } else { x };
let x = x
.apply_t(&self.dropout, train)
.transpose(0, 1);
let x = if let Some(layer_norm_embedding) = &self.layer_norm_embedding {
x.apply(layer_norm_embedding)
} else {
x
};
let x = x.apply_t(&self.dropout, train).transpose(0, 1);
let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states { Some(vec!()) } else { None };
let mut all_attentions: Option<Vec<Tensor>> = if self.output_attentions { Some(vec!()) } else { None };
let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states {
Some(vec![])
} else {
None
};
let mut all_attentions: Option<Vec<Tensor>> = if self.output_attentions {
Some(vec![])
} else {
None
};
let mut hidden_state = x.copy();
let mut attention_weights: Option<Tensor>;
@ -204,13 +267,17 @@ impl BartEncoder {
attentions.push(attention_weights.as_ref().unwrap().copy());
};
}
None => break
};
None => break,
};
}
if let Some(hidden_states) = all_hidden_states.borrow_mut() {
hidden_states.push(hidden_state.as_ref().copy().transpose(0, 1));
};
(hidden_state.transpose(0, 1), all_hidden_states, all_attentions)
(
hidden_state.transpose(0, 1),
all_hidden_states,
all_attentions,
)
}
}
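A numeric check of the `scale_embedding` rule shared by encoder and decoder: when enabled, token embeddings are multiplied by sqrt(d_model); otherwise the factor is 1.0.

let d_model = 1024i64;
let scale = if true { (d_model as f64).sqrt() } else { 1.0 };
assert!((scale - 32.0).abs() < 1e-9); // sqrt(1024) == 32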

View File

@ -20,14 +20,22 @@
//! use rust_tokenizers::RobertaTokenizer;
//! use tch::{nn, Device};
//! # use std::path::PathBuf;
//! use rust_bert::Config;
//! use rust_bert::bart::{BartConfig, BartModel};
//! use rust_bert::resources::{Resource, download_resource, LocalResource};
//! use rust_bert::resources::{download_resource, LocalResource, Resource};
//! use rust_bert::Config;
//!
//! let config_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/config.json")});
//! let vocab_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/vocab.txt")});
//! let merges_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/vocab.txt")});
//! let weights_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/model.ot")});
//! let config_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/config.json"),
//! });
//! let vocab_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/vocab.txt"),
//! });
//! let merges_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/vocab.txt"),
//! });
//! let weights_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/model.ot"),
//! });
//! let config_path = download_resource(&config_resource)?;
//! let vocab_path = download_resource(&vocab_resource)?;
//! let merges_path = download_resource(&merges_resource)?;
@ -35,7 +43,11 @@
//!
//! let device = Device::cuda_if_available();
//! let mut vs = nn::VarStore::new(device);
//! let tokenizer: RobertaTokenizer = RobertaTokenizer::from_file(vocab_path.to_str().unwrap(), merges_path.to_str().unwrap(), true);
//! let tokenizer: RobertaTokenizer = RobertaTokenizer::from_file(
//! vocab_path.to_str().unwrap(),
//! merges_path.to_str().unwrap(),
//! true,
//! );
//! let config = BartConfig::from_file(config_path);
//! let bart_model = BartModel::new(&vs.root(), &config, false);
//! vs.load(weights_path)?;
@ -44,12 +56,15 @@
//! # }
//! ```
mod bart;
mod attention;
mod encoder;
mod bart;
mod decoder;
mod embeddings;
mod encoder;
pub use bart::{BartModelResources, BartConfigResources, BartVocabResources, BartMergesResources,
BartConfig, Activation, BartModel, BartForSequenceClassification, BartForConditionalGeneration};
pub use attention::LayerState;
pub use bart::{
Activation, BartConfig, BartConfigResources, BartForConditionalGeneration,
BartForSequenceClassification, BartMergesResources, BartModel, BartModelResources,
BartVocabResources,
};

View File

@ -11,11 +11,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use crate::common::dropout::Dropout;
use tch::{nn, Tensor};
use tch::kind::Kind::Float;
use crate::bert::bert::{Activation, BertConfig};
use crate::common::activations::{_gelu, _relu, _mish};
use crate::common::activations::{_gelu, _mish, _relu};
use crate::common::dropout::Dropout;
use tch::kind::Kind::Float;
use tch::{nn, Tensor};
#[derive(Debug)]
pub struct BertSelfAttention {
@ -30,17 +30,36 @@ pub struct BertSelfAttention {
impl BertSelfAttention {
pub fn new(p: nn::Path, config: &BertConfig) -> BertSelfAttention {
assert_eq!(config.hidden_size % config.num_attention_heads, 0, "Hidden size not a multiple of the number of attention heads");
assert_eq!(
config.hidden_size % config.num_attention_heads,
0,
"Hidden size not a multiple of the number of attention heads"
);
let query = nn::linear(&p / "query", config.hidden_size, config.hidden_size, Default::default());
let key = nn::linear(&p / "key", config.hidden_size, config.hidden_size, Default::default());
let value = nn::linear(&p / "value", config.hidden_size, config.hidden_size, Default::default());
let query = nn::linear(
&p / "query",
config.hidden_size,
config.hidden_size,
Default::default(),
);
let key = nn::linear(
&p / "key",
config.hidden_size,
config.hidden_size,
Default::default(),
);
let value = nn::linear(
&p / "value",
config.hidden_size,
config.hidden_size,
Default::default(),
);
let dropout = Dropout::new(config.attention_probs_dropout_prob);
let attention_head_size = config.hidden_size / config.num_attention_heads;
let output_attentions = match config.output_attentions {
Some(value) => value,
None => false
None => false,
};
BertSelfAttention {
@ -55,35 +74,44 @@ impl BertSelfAttention {
}
fn split_heads(&self, x: Tensor, bs: i64, dim_per_head: i64) -> Tensor {
x.view((bs, -1, self.num_attention_heads, dim_per_head)).transpose(1, 2)
x.view((bs, -1, self.num_attention_heads, dim_per_head))
.transpose(1, 2)
}
fn flatten(&self, x: Tensor, bs: i64, dim_per_head: i64) -> Tensor {
x.transpose(1, 2).contiguous().view((bs, -1, &self.num_attention_heads * dim_per_head))
x.transpose(1, 2)
.contiguous()
.view((bs, -1, &self.num_attention_heads * dim_per_head))
}
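A shape round-trip sketch for `split_heads` and `flatten` with toy sizes: (bs, seq, hidden) becomes (bs, heads, seq, head_dim) and back.

use tch::{Device, Kind, Tensor};

let (bs, seq, heads, head_dim) = (2i64, 5i64, 12i64, 64i64);
let x = Tensor::zeros(&[bs, seq, heads * head_dim], (Kind::Float, Device::Cpu));
// split_heads: expose a separate head dimension, then move it before seq.
let split = x.view((bs, -1, heads, head_dim)).transpose(1, 2);
assert_eq!(split.size(), &[bs, heads, seq, head_dim]);
// flatten: undo the transpose and merge the heads back into the hidden size.
let flat = split.transpose(1, 2).contiguous().view((bs, -1, heads * head_dim));
assert_eq!(flat.size(), &[bs, seq, heads * head_dim]);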
pub fn forward_t(&self,
pub fn forward_t(
&self,
hidden_states: &Tensor,
mask: &Option<Tensor>,
encoder_hidden_states: &Option<Tensor>,
encoder_mask: &Option<Tensor>,
train: bool) -> (Tensor, Option<Tensor>) {
train: bool,
) -> (Tensor, Option<Tensor>) {
let (key_layer, value_layer, mask) = match encoder_hidden_states {
Some(encoder_hidden_state_values) => {
(encoder_hidden_state_values.apply(&self.key),
Some(encoder_hidden_state_values) => (
encoder_hidden_state_values.apply(&self.key),
encoder_hidden_state_values.apply(&self.value),
encoder_mask)
}
None => {
(hidden_states.apply(&self.key),
encoder_mask,
),
None => (
hidden_states.apply(&self.key),
hidden_states.apply(&self.value),
mask)
}
mask,
),
};
let bs = hidden_states.size()[0];
let query_layer = self.split_heads(hidden_states.apply(&self.query), bs, self.attention_head_size);
let query_layer = self.split_heads(
hidden_states.apply(&self.query),
bs,
self.attention_head_size,
);
let key_layer = self.split_heads(key_layer, bs, self.attention_head_size);
let value_layer = self.split_heads(value_layer, bs, self.attention_head_size);
let query_layer: Tensor = query_layer / (self.attention_head_size as f64).sqrt();
@ -114,16 +142,32 @@ pub struct BertSelfOutput {
impl BertSelfOutput {
pub fn new(p: &nn::Path, config: &BertConfig) -> BertSelfOutput {
let linear = nn::linear(p / "dense", config.hidden_size, config.hidden_size, Default::default());
let layer_norm_config = nn::LayerNormConfig { eps: 1e-12, ..Default::default() };
let layer_norm = nn::layer_norm(p / "LayerNorm", vec![config.hidden_size], layer_norm_config);
let linear = nn::linear(
p / "dense",
config.hidden_size,
config.hidden_size,
Default::default(),
);
let layer_norm_config = nn::LayerNormConfig {
eps: 1e-12,
..Default::default()
};
let layer_norm =
nn::layer_norm(p / "LayerNorm", vec![config.hidden_size], layer_norm_config);
let dropout = Dropout::new(config.hidden_dropout_prob);
BertSelfOutput { linear, layer_norm, dropout }
BertSelfOutput {
linear,
layer_norm,
dropout,
}
}
pub fn forward_t(&self, hidden_states: &Tensor, input_tensor: &Tensor, train: bool) -> Tensor {
let hidden_states: Tensor = input_tensor + hidden_states.apply(&self.linear).apply_t(&self.dropout, train);
let hidden_states: Tensor = input_tensor
+ hidden_states
.apply(&self.linear)
.apply_t(&self.dropout, train);
hidden_states.apply(&self.layer_norm)
}
}
@ -141,14 +185,21 @@ impl BertAttention {
BertAttention { _self, output }
}
pub fn forward_t(&self,
pub fn forward_t(
&self,
hidden_states: &Tensor,
mask: &Option<Tensor>,
encoder_hidden_states: &Option<Tensor>,
encoder_mask: &Option<Tensor>,
train: bool) -> (Tensor, Option<Tensor>) {
let (self_output, attention_weights) = self._self.
forward_t(hidden_states, mask, encoder_hidden_states, encoder_mask, train);
train: bool,
) -> (Tensor, Option<Tensor>) {
let (self_output, attention_weights) = self._self.forward_t(
hidden_states,
mask,
encoder_hidden_states,
encoder_mask,
train,
);
let self_output = self.output.forward_t(&self_output, hidden_states, train);
(self_output, attention_weights)
@ -162,11 +213,16 @@ pub struct BertIntermediate {
impl BertIntermediate {
pub fn new(p: &nn::Path, config: &BertConfig) -> BertIntermediate {
let lin = nn::linear(p / "dense", config.hidden_size, config.intermediate_size, Default::default());
let lin = nn::linear(
p / "dense",
config.hidden_size,
config.intermediate_size,
Default::default(),
);
let activation = Box::new(match &config.hidden_act {
Activation::gelu => _gelu,
Activation::relu => _relu,
Activation::mish => _mish
Activation::mish => _mish,
});
BertIntermediate { lin, activation }
}
@ -184,17 +240,30 @@ pub struct BertOutput {
impl BertOutput {
pub fn new(p: &nn::Path, config: &BertConfig) -> BertOutput {
let lin = nn::linear(p / "dense", config.intermediate_size, config.hidden_size, Default::default());
let layer_norm_config = nn::LayerNormConfig { eps: 1e-12, ..Default::default() };
let layer_norm = nn::layer_norm(p / "LayerNorm", vec![config.hidden_size], layer_norm_config);
let lin = nn::linear(
p / "dense",
config.intermediate_size,
config.hidden_size,
Default::default(),
);
let layer_norm_config = nn::LayerNormConfig {
eps: 1e-12,
..Default::default()
};
let layer_norm =
nn::layer_norm(p / "LayerNorm", vec![config.hidden_size], layer_norm_config);
let dropout = Dropout::new(config.hidden_dropout_prob);
BertOutput { lin, layer_norm, dropout }
BertOutput {
lin,
layer_norm,
dropout,
}
}
pub fn forward_t(&self, hidden_states: &Tensor, input_tensor: &Tensor, train: bool) -> Tensor {
let hidden_states: Tensor = input_tensor + hidden_states.apply(&self.lin).apply_t(&self.dropout, train);
let hidden_states: Tensor =
input_tensor + hidden_states.apply(&self.lin).apply_t(&self.dropout, train);
hidden_states.apply(&self.layer_norm)
}
}

View File

@ -11,17 +11,17 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use serde::{Deserialize, Serialize};
use crate::bert::embeddings::{BertEmbeddings, BertEmbedding};
use crate::bert::embeddings::{BertEmbedding, BertEmbeddings};
use crate::bert::encoder::{BertEncoder, BertPooler};
use tch::{nn, Tensor, Kind};
use tch::kind::Kind::Float;
use crate::common::activations::{_gelu, _relu, _mish};
use crate::common::linear::{LinearNoBias, linear_no_bias};
use tch::nn::Init;
use crate::common::activations::{_gelu, _mish, _relu};
use crate::common::dropout::Dropout;
use std::collections::HashMap;
use crate::common::linear::{linear_no_bias, LinearNoBias};
use crate::Config;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use tch::kind::Kind::Float;
use tch::nn::Init;
use tch::{nn, Kind, Tensor};
/// # BERT Pretrained model weight files
pub struct BertModelResources;
@ -34,23 +34,41 @@ pub struct BertVocabResources;
impl BertModelResources {
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/bert. Modified with conversion to C-array format.
pub const BERT: (&'static str, &'static str) = ("bert/model.ot", "https://cdn.huggingface.co/bert-base-uncased-rust_model.ot");
pub const BERT: (&'static str, &'static str) = (
"bert/model.ot",
"https://cdn.huggingface.co/bert-base-uncased-rust_model.ot",
);
/// Shared under MIT license by the MDZ Digital Library team at the Bavarian State Library at https://github.com/dbmdz/berts. Modified with conversion to C-array format.
pub const BERT_NER: (&'static str, &'static str) = ("bert-ner/model.ot", "https://cdn.huggingface.co/dbmdz/bert-large-cased-finetuned-conll03-english/rust_model.ot");
pub const BERT_NER: (&'static str, &'static str) = (
"bert-ner/model.ot",
"https://cdn.huggingface.co/dbmdz/bert-large-cased-finetuned-conll03-english/rust_model.ot",
);
}
impl BertConfigResources {
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/bert. Modified with conversion to C-array format.
pub const BERT: (&'static str, &'static str) = ("bert/config.json", "https://cdn.huggingface.co/bert-base-uncased-config.json");
pub const BERT: (&'static str, &'static str) = (
"bert/config.json",
"https://cdn.huggingface.co/bert-base-uncased-config.json",
);
/// Shared under MIT license by the MDZ Digital Library team at the Bavarian State Library at https://github.com/dbmdz/berts. Modified with conversion to C-array format.
pub const BERT_NER: (&'static str, &'static str) = ("bert-ner/config.json", "https://cdn.huggingface.co/dbmdz/bert-large-cased-finetuned-conll03-english/config.json");
pub const BERT_NER: (&'static str, &'static str) = (
"bert-ner/config.json",
"https://cdn.huggingface.co/dbmdz/bert-large-cased-finetuned-conll03-english/config.json",
);
}
impl BertVocabResources {
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/bert. Modified with conversion to C-array format.
pub const BERT: (&'static str, &'static str) = ("bert/vocab.txt", "https://cdn.huggingface.co/bert-base-uncased-vocab.txt");
pub const BERT: (&'static str, &'static str) = (
"bert/vocab.txt",
"https://cdn.huggingface.co/bert-base-uncased-vocab.txt",
);
/// Shared under MIT license by the MDZ Digital Library team at the Bavarian State Library at https://github.com/dbmdz/berts. Modified with conversion to C-array format.
pub const BERT_NER: (&'static str, &'static str) = ("bert-ner/vocab.txt", "https://cdn.huggingface.co/dbmdz/bert-large-cased-finetuned-conll03-english/vocab.txt");
pub const BERT_NER: (&'static str, &'static str) = (
"bert-ner/vocab.txt",
"https://cdn.huggingface.co/dbmdz/bert-large-cased-finetuned-conll03-english/vocab.txt",
);
}
#[allow(non_camel_case_types)]
@ -117,10 +135,10 @@ impl<T: BertEmbedding> BertModel<T> {
/// # Example
///
/// ```no_run
/// use rust_bert::bert::{BertModel, BertConfig, BertEmbeddings};
/// use tch::{nn, Device};
/// use rust_bert::bert::{BertConfig, BertEmbeddings, BertModel};
/// use rust_bert::Config;
/// use std::path::Path;
/// use tch::{nn, Device};
///
/// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu;
@ -128,17 +146,21 @@ impl<T: BertEmbedding> BertModel<T> {
/// let config = BertConfig::from_file(config_path);
/// let bert: BertModel<BertEmbeddings> = BertModel::new(&(&p.root() / "bert"), &config);
/// ```
///
pub fn new(p: &nn::Path, config: &BertConfig) -> BertModel<T> {
let is_decoder = match config.is_decoder {
Some(value) => value,
None => false
None => false,
};
let embeddings = T::new(&(p / "embeddings"), config);
let encoder = BertEncoder::new(&(p / "encoder"), config);
let pooler = BertPooler::new(&(p / "pooler"), config);
BertModel { embeddings, encoder, pooler, is_decoder }
BertModel {
embeddings,
encoder,
pooler,
is_decoder,
}
}
/// Forward pass through the model
@ -178,23 +200,26 @@ impl<T: BertEmbedding> BertModel<T> {
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true);
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
/// .expand(&[batch_size, sequence_length], true);
///
/// let (output, pooled_output, all_hidden_states, all_attentions) = no_grad(|| {
/// bert_model
/// .forward_t(Some(input_tensor),
/// .forward_t(
/// Some(input_tensor),
/// Some(mask),
/// Some(token_type_ids),
/// Some(position_ids),
/// None,
/// &None,
/// &None,
/// false).unwrap()
/// false,
/// )
/// .unwrap()
/// });
///
/// ```
///
pub fn forward_t(&self,
pub fn forward_t(
&self,
input_ids: Option<Tensor>,
mask: Option<Tensor>,
token_type_ids: Option<Tensor>,
@ -202,74 +227,109 @@ impl<T: BertEmbedding> BertModel<T> {
input_embeds: Option<Tensor>,
encoder_hidden_states: &Option<Tensor>,
encoder_mask: &Option<Tensor>,
train: bool)
-> Result<(Tensor, Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
train: bool,
) -> Result<(Tensor, Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
let (input_shape, device) = match &input_ids {
Some(input_value) => match &input_embeds {
Some(_) => { return Err("Only one of input ids or input embeddings may be set"); }
None => (input_value.size(), input_value.device())
Some(_) => {
return Err("Only one of input ids or input embeddings may be set");
}
None => (input_value.size(), input_value.device()),
},
None => match &input_embeds {
Some(embeds) => (vec!(embeds.size()[0], embeds.size()[1]), embeds.device()),
None => { return Err("At least one of input ids or input embeddings must be set"); }
Some(embeds) => (vec![embeds.size()[0], embeds.size()[1]], embeds.device()),
None => {
return Err("At least one of input ids or input embeddings must be set");
}
},
};
let mask = match mask {
Some(value) => value,
None => Tensor::ones(&input_shape, (Kind::Int64, device))
None => Tensor::ones(&input_shape, (Kind::Int64, device)),
};
let extended_attention_mask = match mask.dim() {
3 => mask.unsqueeze(1),
2 => if self.is_decoder {
2 => {
if self.is_decoder {
let seq_ids = Tensor::arange(input_shape[1], (Float, device));
let causal_mask = seq_ids.unsqueeze(0).unsqueeze(0).repeat(&vec!(input_shape[0], input_shape[1], 1));
let causal_mask = seq_ids.unsqueeze(0).unsqueeze(0).repeat(&vec![
input_shape[0],
input_shape[1],
1,
]);
let causal_mask = causal_mask.le1(&seq_ids.unsqueeze(0).unsqueeze(-1));
causal_mask * mask.unsqueeze(1).unsqueeze(1)
} else {
mask.unsqueeze(1).unsqueeze(1)
},
_ => { return Err("Invalid attention mask dimension, must be 2 or 3"); }
}
}
_ => {
return Err("Invalid attention mask dimension, must be 2 or 3");
}
};
let extended_attention_mask: Tensor = (extended_attention_mask.ones_like() - extended_attention_mask) * -10000.0;
let extended_attention_mask: Tensor =
(extended_attention_mask.ones_like() - extended_attention_mask) * -10000.0;
let encoder_extended_attention_mask: Option<Tensor> = if self.is_decoder & encoder_hidden_states.is_some() {
let encoder_extended_attention_mask: Option<Tensor> =
if self.is_decoder & encoder_hidden_states.is_some() {
let encoder_hidden_states = encoder_hidden_states.as_ref().unwrap();
let encoder_hidden_states_shape = encoder_hidden_states.size();
let encoder_mask = match encoder_mask {
Some(value) => value.copy(),
None => Tensor::ones(&[encoder_hidden_states_shape[0], encoder_hidden_states_shape[1]], (Kind::Int64, device))
None => Tensor::ones(
&[
encoder_hidden_states_shape[0],
encoder_hidden_states_shape[1],
],
(Kind::Int64, device),
),
};
match encoder_mask.dim() {
2 => Some(encoder_mask.unsqueeze(1).unsqueeze(1)),
3 => Some(encoder_mask.unsqueeze(1)),
_ => { return Err("Invalid encoder attention mask dimension, must be 2 or 3"); }
_ => {
return Err("Invalid encoder attention mask dimension, must be 2 or 3");
}
}
} else {
None
};
let embedding_output = match self.embeddings.forward_t(input_ids, token_type_ids, position_ids, input_embeds, train) {
let embedding_output = match self.embeddings.forward_t(
input_ids,
token_type_ids,
position_ids,
input_embeds,
train,
) {
Ok(value) => value,
Err(e) => { return Err(e); }
Err(e) => {
return Err(e);
}
};
let (hidden_state, all_hidden_states, all_attentions) =
self.encoder.forward_t(&embedding_output,
let (hidden_state, all_hidden_states, all_attentions) = self.encoder.forward_t(
&embedding_output,
&Some(extended_attention_mask),
encoder_hidden_states,
&encoder_extended_attention_mask,
train);
train,
);
let pooled_output = self.pooler.forward(&hidden_state);
Ok((hidden_state, pooled_output, all_hidden_states, all_attentions))
Ok((
hidden_state,
pooled_output,
all_hidden_states,
all_attentions,
))
}
}
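The extended attention mask above is additive: entries where the user mask is 0 receive a large negative bias, so the softmax drives their attention weights toward zero. A toy check:

use tch::{Kind, Tensor};

let mask = Tensor::of_slice(&[1i64, 1, 0]).to_kind(Kind::Float);
// Same transformation as in forward_t: (1 - mask) * -10000.
let extended: Tensor = (mask.ones_like() - &mask) * -10000.0;
assert_eq!(Vec::<f64>::from(&extended), vec![0.0, 0.0, -10000.0]);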
pub struct BertPredictionHeadTransform {
dense: nn::Linear,
activation: Box<dyn Fn(&Tensor) -> Tensor>,
@ -278,16 +338,29 @@ pub struct BertPredictionHeadTransform {
impl BertPredictionHeadTransform {
pub fn new(p: &nn::Path, config: &BertConfig) -> BertPredictionHeadTransform {
let dense = nn::linear(p / "dense", config.hidden_size, config.hidden_size, Default::default());
let dense = nn::linear(
p / "dense",
config.hidden_size,
config.hidden_size,
Default::default(),
);
let activation = Box::new(match &config.hidden_act {
Activation::gelu => _gelu,
Activation::relu => _relu,
Activation::mish => _mish
Activation::mish => _mish,
});
let layer_norm_config = nn::LayerNormConfig { eps: 1e-12, ..Default::default() };
let layer_norm = nn::layer_norm(p / "LayerNorm", vec![config.hidden_size], layer_norm_config);
let layer_norm_config = nn::LayerNormConfig {
eps: 1e-12,
..Default::default()
};
let layer_norm =
nn::layer_norm(p / "LayerNorm", vec![config.hidden_size], layer_norm_config);
BertPredictionHeadTransform { dense, activation, layer_norm }
BertPredictionHeadTransform {
dense,
activation,
layer_norm,
}
}
pub fn forward(&self, hidden_states: &Tensor) -> Tensor {
@ -305,10 +378,19 @@ impl BertLMPredictionHead {
pub fn new(p: &nn::Path, config: &BertConfig) -> BertLMPredictionHead {
let p = &(p / "predictions");
let transform = BertPredictionHeadTransform::new(&(p / "transform"), config);
let decoder = linear_no_bias(&(p / "decoder"), config.hidden_size, config.vocab_size, Default::default());
let decoder = linear_no_bias(
&(p / "decoder"),
config.hidden_size,
config.vocab_size,
Default::default(),
);
let bias = p.var("bias", &[config.vocab_size], Init::KaimingUniform);
BertLMPredictionHead { transform, decoder, bias }
BertLMPredictionHead {
transform,
decoder,
bias,
}
}
pub fn forward(&self, hidden_states: &Tensor) -> Tensor {
@ -338,9 +420,9 @@ impl BertForMaskedLM {
///
/// ```no_run
/// use rust_bert::bert::{BertConfig, BertForMaskedLM};
/// use tch::{nn, Device};
/// use rust_bert::Config;
/// use std::path::Path;
/// use tch::{nn, Device};
///
/// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu;
@ -348,7 +430,6 @@ impl BertForMaskedLM {
/// let config = BertConfig::from_file(config_path);
/// let bert = BertForMaskedLM::new(&(&p.root() / "bert"), &config);
/// ```
///
pub fn new(p: &nn::Path, config: &BertConfig) -> BertForMaskedLM {
let bert = BertModel::new(&(p / "bert"), config);
let cls = BertLMPredictionHead::new(&(p / "cls"), config);
@ -392,23 +473,24 @@ impl BertForMaskedLM {
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true);
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
/// .expand(&[batch_size, sequence_length], true);
///
/// let (output, all_hidden_states, all_attentions) = no_grad(|| {
/// bert_model
/// .forward_t(Some(input_tensor),
/// bert_model.forward_t(
/// Some(input_tensor),
/// Some(mask),
/// Some(token_type_ids),
/// Some(position_ids),
/// None,
/// &None,
/// &None,
/// false)
/// false,
/// )
/// });
///
/// ```
///
pub fn forward_t(&self,
pub fn forward_t(
&self,
input_ids: Option<Tensor>,
mask: Option<Tensor>,
token_type_ids: Option<Tensor>,
@ -416,9 +498,21 @@ impl BertForMaskedLM {
input_embeds: Option<Tensor>,
encoder_hidden_states: &Option<Tensor>,
encoder_mask: &Option<Tensor>,
train: bool) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
let (hidden_state, _, all_hidden_states, all_attentions) = self.bert.forward_t(input_ids, mask, token_type_ids, position_ids,
input_embeds, encoder_hidden_states, encoder_mask, train).unwrap();
train: bool,
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
let (hidden_state, _, all_hidden_states, all_attentions) = self
.bert
.forward_t(
input_ids,
mask,
token_type_ids,
position_ids,
input_embeds,
encoder_hidden_states,
encoder_mask,
train,
)
.unwrap();
let prediction_scores = self.cls.forward(&hidden_state);
(prediction_scores, all_hidden_states, all_attentions)
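// Shape note: `hidden_state` is (batch, seq_len, hidden_size) and the prediction head
// maps it to vocabulary logits of shape (batch, seq_len, vocab_size).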
@ -448,9 +542,9 @@ impl BertForSequenceClassification {
///
/// ```no_run
/// use rust_bert::bert::{BertConfig, BertForSequenceClassification};
/// use tch::{nn, Device};
/// use rust_bert::Config;
/// use std::path::Path;
/// use tch::{nn, Device};
///
/// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu;
@ -461,10 +555,23 @@ impl BertForSequenceClassification {
pub fn new(p: &nn::Path, config: &BertConfig) -> BertForSequenceClassification {
let bert = BertModel::new(&(p / "bert"), config);
let dropout = Dropout::new(config.hidden_dropout_prob);
let num_labels = config.id2label.as_ref().expect("num_labels not provided in configuration").len() as i64;
let classifier = nn::linear(p / "classifier", config.hidden_size, num_labels, Default::default());
let num_labels = config
.id2label
.as_ref()
.expect("num_labels not provided in configuration")
.len() as i64;
let classifier = nn::linear(
p / "classifier",
config.hidden_size,
num_labels,
Default::default(),
);
BertForSequenceClassification { bert, dropout, classifier }
BertForSequenceClassification {
bert,
dropout,
classifier,
}
}
/// Forward pass through the model
@ -501,31 +608,46 @@ impl BertForSequenceClassification {
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true);
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
/// .expand(&[batch_size, sequence_length], true);
///
/// let (labels, all_hidden_states, all_attentions) = no_grad(|| {
/// bert_model
/// .forward_t(Some(input_tensor),
/// bert_model.forward_t(
/// Some(input_tensor),
/// Some(mask),
/// Some(token_type_ids),
/// Some(position_ids),
/// None,
/// false)
/// false,
/// )
/// });
///
/// ```
///
pub fn forward_t(&self,
pub fn forward_t(
&self,
input_ids: Option<Tensor>,
mask: Option<Tensor>,
token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
let (_, pooled_output, all_hidden_states, all_attentions) = self.bert.forward_t(input_ids, mask, token_type_ids, position_ids,
input_embeds, &None, &None, train).unwrap();
train: bool,
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
let (_, pooled_output, all_hidden_states, all_attentions) = self
.bert
.forward_t(
input_ids,
mask,
token_type_ids,
position_ids,
input_embeds,
&None,
&None,
train,
)
.unwrap();
let output = pooled_output.apply_t(&self.dropout, train).apply(&self.classifier);
let output = pooled_output
.apply_t(&self.dropout, train)
.apply(&self.classifier);
(output, all_hidden_states, all_attentions)
}
}
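// Shape note: the pooled output is (batch, hidden_size); dropout plus the classifier
// map it to logits of shape (batch, num_labels), with num_labels taken from
// config.id2label as above.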
@ -555,9 +677,9 @@ impl BertForMultipleChoice {
///
/// ```no_run
/// use rust_bert::bert::{BertConfig, BertForMultipleChoice};
/// use tch::{nn, Device};
/// use rust_bert::Config;
/// use std::path::Path;
/// use tch::{nn, Device};
///
/// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu;
@ -570,7 +692,11 @@ impl BertForMultipleChoice {
let dropout = Dropout::new(config.hidden_dropout_prob);
let classifier = nn::linear(p / "classifier", config.hidden_size, 1, Default::default());
BertForMultipleChoice { bert, dropout, classifier }
BertForMultipleChoice {
bert,
dropout,
classifier,
}
}
/// Forward pass through the model
@ -606,46 +732,61 @@ impl BertForMultipleChoice {
/// let input_tensor = Tensor::rand(&[num_choices, sequence_length], (Int64, device));
/// let mask = Tensor::zeros(&[num_choices, sequence_length], (Int64, device));
/// let token_type_ids = Tensor::zeros(&[num_choices, sequence_length], (Int64, device));
/// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[num_choices, sequence_length], true);
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
/// .expand(&[num_choices, sequence_length], true);
///
/// let (choices, all_hidden_states, all_attentions) = no_grad(|| {
/// bert_model
/// .forward_t(input_tensor,
/// bert_model.forward_t(
/// input_tensor,
/// Some(mask),
/// Some(token_type_ids),
/// Some(position_ids),
/// false)
/// false,
/// )
/// });
///
/// ```
///
pub fn forward_t(&self,
pub fn forward_t(
&self,
input_ids: Tensor,
mask: Option<Tensor>,
token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>,
train: bool) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
train: bool,
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
let num_choices = input_ids.size()[1];
let input_ids = input_ids.view((-1, *input_ids.size().last().unwrap()));
let mask = match mask {
Some(value) => Some(value.view((-1, *value.size().last().unwrap()))),
None => None
None => None,
};
let token_type_ids = match token_type_ids {
Some(value) => Some(value.view((-1, *value.size().last().unwrap()))),
None => None
None => None,
};
let position_ids = match position_ids {
Some(value) => Some(value.view((-1, *value.size().last().unwrap()))),
None => None
None => None,
};
let (_, pooled_output, all_hidden_states, all_attentions) = self
.bert
.forward_t(
Some(input_ids),
mask,
token_type_ids,
position_ids,
None,
&None,
&None,
train,
)
.unwrap();
let (_, pooled_output, all_hidden_states, all_attentions) = self.bert.forward_t(Some(input_ids), mask, token_type_ids, position_ids,
None, &None, &None, train).unwrap();
let output = pooled_output.apply_t(&self.dropout, train).apply(&self.classifier).view((-1, num_choices));
let output = pooled_output
.apply_t(&self.dropout, train)
.apply(&self.classifier)
.view((-1, num_choices));
(output, all_hidden_states, all_attentions)
}
}
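// Shape note: given input ids of shape (batch, num_choices, seq_len), num_choices is
// read from dimension 1, the inputs are flattened to (batch * num_choices, seq_len)
// for the shared encoder, and the single logit per choice is viewed back to
// (batch, num_choices).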
@ -674,9 +815,9 @@ impl BertForTokenClassification {
///
/// ```no_run
/// use rust_bert::bert::{BertConfig, BertForTokenClassification};
/// use tch::{nn, Device};
/// use rust_bert::Config;
/// use std::path::Path;
/// use tch::{nn, Device};
///
/// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu;
@ -687,10 +828,23 @@ impl BertForTokenClassification {
pub fn new(p: &nn::Path, config: &BertConfig) -> BertForTokenClassification {
let bert = BertModel::new(&(p / "bert"), config);
let dropout = Dropout::new(config.hidden_dropout_prob);
let num_labels = config.id2label.as_ref().expect("num_labels not provided in configuration").len() as i64;
let classifier = nn::linear(p / "classifier", config.hidden_size, num_labels, Default::default());
let num_labels = config
.id2label
.as_ref()
.expect("num_labels not provided in configuration")
.len() as i64;
let classifier = nn::linear(
p / "classifier",
config.hidden_size,
num_labels,
Default::default(),
);
BertForTokenClassification { bert, dropout, classifier }
BertForTokenClassification {
bert,
dropout,
classifier,
}
}
/// Forward pass through the model
@ -727,31 +881,46 @@ impl BertForTokenClassification {
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true);
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
/// .expand(&[batch_size, sequence_length], true);
///
/// let (token_labels, all_hidden_states, all_attentions) = no_grad(|| {
/// bert_model
/// .forward_t(Some(input_tensor),
/// bert_model.forward_t(
/// Some(input_tensor),
/// Some(mask),
/// Some(token_type_ids),
/// Some(position_ids),
/// None,
/// false)
/// false,
/// )
/// });
///
/// ```
///
pub fn forward_t(&self,
pub fn forward_t(
&self,
input_ids: Option<Tensor>,
mask: Option<Tensor>,
token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
let (hidden_state, _, all_hidden_states, all_attentions) = self.bert.forward_t(input_ids, mask, token_type_ids, position_ids,
input_embeds, &None, &None, train).unwrap();
train: bool,
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
let (hidden_state, _, all_hidden_states, all_attentions) = self
.bert
.forward_t(
input_ids,
mask,
token_type_ids,
position_ids,
input_embeds,
&None,
&None,
train,
)
.unwrap();
let sequence_output = hidden_state.apply_t(&self.dropout, train).apply(&self.classifier);
let sequence_output = hidden_state
.apply_t(&self.dropout, train)
.apply(&self.classifier);
(sequence_output, all_hidden_states, all_attentions)
}
}
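// Shape note: token classification keeps the full sequence output, so the logits have
// shape (batch, seq_len, num_labels) rather than being pooled to a single vector.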
@ -780,9 +949,9 @@ impl BertForQuestionAnswering {
///
/// ```no_run
/// use rust_bert::bert::{BertConfig, BertForQuestionAnswering};
/// use tch::{nn, Device};
/// use rust_bert::Config;
/// use std::path::Path;
/// use tch::{nn, Device};
///
/// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu;
@ -793,7 +962,12 @@ impl BertForQuestionAnswering {
pub fn new(p: &nn::Path, config: &BertConfig) -> BertForQuestionAnswering {
let bert = BertModel::new(&(p / "bert"), config);
let num_labels = 2;
let qa_outputs = nn::linear(p / "qa_outputs", config.hidden_size, num_labels, Default::default());
let qa_outputs = nn::linear(
p / "qa_outputs",
config.hidden_size,
num_labels,
Default::default(),
);
BertForQuestionAnswering { bert, qa_outputs }
}
@ -833,29 +1007,42 @@ impl BertForQuestionAnswering {
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true);
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
/// .expand(&[batch_size, sequence_length], true);
///
/// let (start_scores, end_scores, all_hidden_states, all_attentions) = no_grad(|| {
/// bert_model
/// .forward_t(Some(input_tensor),
/// bert_model.forward_t(
/// Some(input_tensor),
/// Some(mask),
/// Some(token_type_ids),
/// Some(position_ids),
/// None,
/// false)
/// false,
/// )
/// });
///
/// ```
///
pub fn forward_t(&self,
pub fn forward_t(
&self,
input_ids: Option<Tensor>,
mask: Option<Tensor>,
token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool) -> (Tensor, Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
let (hidden_state, _, all_hidden_states, all_attentions) = self.bert.forward_t(input_ids, mask, token_type_ids, position_ids,
input_embeds, &None, &None, train).unwrap();
train: bool,
) -> (Tensor, Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
let (hidden_state, _, all_hidden_states, all_attentions) = self
.bert
.forward_t(
input_ids,
mask,
token_type_ids,
position_ids,
input_embeds,
&None,
&None,
train,
)
.unwrap();
let sequence_output = hidden_state.apply(&self.qa_outputs);
let logits = sequence_output.split(1, -1);
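// The split(1, -1) above separates the two output channels into per-token start and
// end scores; squeezing the trailing dimension (as in the DistilBERT QA head later in
// this diff) yields two (batch, seq_len) tensors.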

View File

@ -11,22 +11,24 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use tch::{nn, Tensor, Kind};
use tch::nn::{EmbeddingConfig, embedding};
use crate::common::dropout::Dropout;
use crate::bert::bert::BertConfig;
use crate::common::dropout::Dropout;
use tch::nn::{embedding, EmbeddingConfig};
use tch::{nn, Kind, Tensor};
/// # BertEmbedding trait (for use in BertModel or RoBERTaModel)
/// Defines an interface for the embedding layers in BERT-based models
pub trait BertEmbedding {
fn new(p: &nn::Path, config: &BertConfig) -> Self;
fn forward_t(&self,
fn forward_t(
&self,
input_ids: Option<Tensor>,
token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool) -> Result<Tensor, &'static str>;
train: bool,
) -> Result<Tensor, &'static str>;
}
#[derive(Debug)]
@ -51,10 +53,10 @@ impl BertEmbedding for BertEmbeddings {
/// # Example
///
/// ```no_run
/// use rust_bert::bert::{BertConfig, BertEmbeddings, BertEmbedding};
/// use tch::{nn, Device};
/// use rust_bert::bert::{BertConfig, BertEmbedding, BertEmbeddings};
/// use rust_bert::Config;
/// use std::path::Path;
/// use tch::{nn, Device};
///
/// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu;
@ -62,29 +64,47 @@ impl BertEmbedding for BertEmbeddings {
/// let config = BertConfig::from_file(config_path);
/// let bert_embeddings = BertEmbeddings::new(&(&p.root() / "bert_embeddings"), &config);
/// ```
///
fn new(p: &nn::Path, config: &BertConfig) -> BertEmbeddings {
let embedding_config = EmbeddingConfig { padding_idx: 0, ..Default::default() };
let embedding_config = EmbeddingConfig {
padding_idx: 0,
..Default::default()
};
let word_embeddings: nn::Embedding = embedding(p / "word_embeddings",
let word_embeddings: nn::Embedding = embedding(
p / "word_embeddings",
config.vocab_size,
config.hidden_size,
embedding_config);
embedding_config,
);
let position_embeddings: nn::Embedding = embedding(p / "position_embeddings",
let position_embeddings: nn::Embedding = embedding(
p / "position_embeddings",
config.max_position_embeddings,
config.hidden_size,
Default::default());
Default::default(),
);
let token_type_embeddings: nn::Embedding = embedding(p / "token_type_embeddings",
let token_type_embeddings: nn::Embedding = embedding(
p / "token_type_embeddings",
config.type_vocab_size,
config.hidden_size,
Default::default());
Default::default(),
);
let layer_norm_config = nn::LayerNormConfig { eps: 1e-12, ..Default::default() };
let layer_norm: nn::LayerNorm = nn::layer_norm(p / "LayerNorm", vec![config.hidden_size], layer_norm_config);
let layer_norm_config = nn::LayerNormConfig {
eps: 1e-12,
..Default::default()
};
let layer_norm: nn::LayerNorm =
nn::layer_norm(p / "LayerNorm", vec![config.hidden_size], layer_norm_config);
let dropout: Dropout = Dropout::new(config.hidden_dropout_prob);
BertEmbeddings { word_embeddings, position_embeddings, token_type_embeddings, layer_norm, dropout }
BertEmbeddings {
word_embeddings,
position_embeddings,
token_type_embeddings,
layer_norm,
dropout,
}
}
/// Forward pass through the embedding layer
@ -118,36 +138,48 @@ impl BertEmbedding for BertEmbeddings {
/// let (batch_size, sequence_length) = (64, 128);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true);
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
/// .expand(&[batch_size, sequence_length], true);
///
/// let embedded_output = no_grad(|| {
/// bert_embeddings
/// .forward_t(Some(input_tensor),
/// .forward_t(
/// Some(input_tensor),
/// Some(token_type_ids),
/// Some(position_ids),
/// None,
/// false).unwrap()
/// false,
/// )
/// .unwrap()
/// });
/// ```
///
fn forward_t(&self,
fn forward_t(
&self,
input_ids: Option<Tensor>,
token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool) -> Result<Tensor, &'static str> {
train: bool,
) -> Result<Tensor, &'static str> {
let (input_embeddings, input_shape) = match input_ids {
Some(input_value) => match input_embeds {
Some(_) => { return Err("Only one of input ids or input embeddings may be set"); }
None => (input_value.apply_t(&self.word_embeddings, train), input_value.size())
Some(_) => {
return Err("Only one of input ids or input embeddings may be set");
}
None => (
input_value.apply_t(&self.word_embeddings, train),
input_value.size(),
),
},
None => match input_embeds {
Some(embeds) => {
let size = vec!(embeds.size()[0], embeds.size()[1]);
let size = vec![embeds.size()[0], embeds.size()[1]];
(embeds, size)
},
None => { return Err("Only one of input ids or input embeddings may be set"); }
}
None => {
return Err("Only one of input ids or input embeddings may be set");
}
},
};
let seq_length = input_embeddings.as_ref().size()[1].to_owned();
@ -155,19 +187,22 @@ impl BertEmbedding for BertEmbeddings {
let position_ids = match position_ids {
Some(value) => value,
None => Tensor::arange(seq_length, (Kind::Int64, input_embeddings.device()))
.unsqueeze(0).
expand(&input_shape, true)
.unsqueeze(0)
.expand(&input_shape, true),
};
let token_type_ids = match token_type_ids {
Some(value) => value,
None => Tensor::zeros(&input_shape, (Kind::Int64, input_embeddings.device()))
None => Tensor::zeros(&input_shape, (Kind::Int64, input_embeddings.device())),
};
let position_embeddings = position_ids.apply(&self.position_embeddings);
let token_type_embeddings = token_type_ids.apply(&self.token_type_embeddings);
let input_embeddings: Tensor = input_embeddings + position_embeddings + token_type_embeddings;
Ok(input_embeddings.apply(&self.layer_norm).apply_t(&self.dropout, train))
let input_embeddings: Tensor =
input_embeddings + position_embeddings + token_type_embeddings;
Ok(input_embeddings
.apply(&self.layer_norm)
.apply_t(&self.dropout, train))
}
}
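// Summary of the pipeline above: position ids default to arange(seq_length) expanded
// to the input shape and token type ids default to zeros; the three embeddings are
// summed, then layer-normalized and passed through dropout.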

View File

@ -11,10 +11,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use tch::{Tensor, nn};
use crate::bert::attention::{BertAttention, BertIntermediate, BertOutput};
use std::borrow::BorrowMut;
use crate::bert::bert::BertConfig;
use std::borrow::BorrowMut;
use tch::{nn, Tensor};
pub struct BertLayer {
attention: BertAttention,
@ -30,35 +30,55 @@ impl BertLayer {
let (is_decoder, cross_attention) = match config.is_decoder {
Some(value) => {
if value == true {
(value, Some(BertAttention::new(&(p / "cross_attention"), &config)))
(
value,
Some(BertAttention::new(&(p / "cross_attention"), &config)),
)
} else {
(value, None)
}
}
None => (false, None)
None => (false, None),
};
let intermediate = BertIntermediate::new(&(p / "intermediate"), &config);
let output = BertOutput::new(&(p / "output"), &config);
BertLayer { attention, is_decoder, cross_attention, intermediate, output }
BertLayer {
attention,
is_decoder,
cross_attention,
intermediate,
output,
}
}
pub fn forward_t(&self,
pub fn forward_t(
&self,
hidden_states: &Tensor,
mask: &Option<Tensor>,
encoder_hidden_states: &Option<Tensor>,
encoder_mask: &Option<Tensor>,
train: bool) -> (Tensor, Option<Tensor>, Option<Tensor>) {
let (attention_output, attention_weights, cross_attention_weights) = if self.is_decoder & encoder_hidden_states.is_some() {
train: bool,
) -> (Tensor, Option<Tensor>, Option<Tensor>) {
let (attention_output, attention_weights, cross_attention_weights) =
if self.is_decoder & encoder_hidden_states.is_some() {
let (attention_output, attention_weights) =
self.attention.forward_t(hidden_states, mask, &None, &None, train);
self.attention
.forward_t(hidden_states, mask, &None, &None, train);
let (attention_output, cross_attention_weights) =
self.cross_attention.as_ref().unwrap().forward_t(&attention_output, mask, encoder_hidden_states, encoder_mask, train);
self.cross_attention.as_ref().unwrap().forward_t(
&attention_output,
mask,
encoder_hidden_states,
encoder_mask,
train,
);
(attention_output, attention_weights, cross_attention_weights)
} else {
let (attention_output, attention_weights) =
self.attention.forward_t(hidden_states, mask, &None, &None, train);
self.attention
.forward_t(hidden_states, mask, &None, &None, train);
(attention_output, attention_weights, None)
};
@ -78,26 +98,47 @@ pub struct BertEncoder {
impl BertEncoder {
pub fn new(p: &nn::Path, config: &BertConfig) -> BertEncoder {
let p = &(p / "layer");
let output_attentions = if let Some(value) = config.output_attentions { value } else { false };
let output_hidden_states = if let Some(value) = config.output_hidden_states { value } else { false };
let mut layers: Vec<BertLayer> = vec!();
for layer_index in 0..config.num_hidden_layers {
layers.push(BertLayer::new(&(p / layer_index), config));
let output_attentions = if let Some(value) = config.output_attentions {
value
} else {
false
};
let output_hidden_states = if let Some(value) = config.output_hidden_states {
value
} else {
false
};
BertEncoder { output_attentions, output_hidden_states, layers }
let mut layers: Vec<BertLayer> = vec![];
for layer_index in 0..config.num_hidden_layers {
layers.push(BertLayer::new(&(p / layer_index), config));
}
pub fn forward_t(&self,
BertEncoder {
output_attentions,
output_hidden_states,
layers,
}
}
pub fn forward_t(
&self,
hidden_states: &Tensor,
mask: &Option<Tensor>,
encoder_hidden_states: &Option<Tensor>,
encoder_mask: &Option<Tensor>,
train: bool)
-> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states { Some(vec!()) } else { None };
let mut all_attentions: Option<Vec<Tensor>> = if self.output_attentions { Some(vec!()) } else { None };
train: bool,
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states {
Some(vec![])
} else {
None
};
let mut all_attentions: Option<Vec<Tensor>> = if self.output_attentions {
Some(vec![])
} else {
None
};
let mut hidden_state = hidden_states.copy();
let mut attention_weights: Option<Tensor>;
@ -109,16 +150,22 @@ impl BertEncoder {
hidden_states.push(hidden_state.as_ref().copy());
};
let temp = layer.forward_t(&hidden_state, &mask, encoder_hidden_states, encoder_mask, train);
let temp = layer.forward_t(
&hidden_state,
&mask,
encoder_hidden_states,
encoder_mask,
train,
);
hidden_state = temp.0;
attention_weights = temp.1;
if let Some(attentions) = all_attentions.borrow_mut() {
attentions.push(attention_weights.as_ref().unwrap().copy());
};
}
None => break
};
None => break,
};
}
(hidden_state, all_hidden_states, all_attentions)
}
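// The encoder loop threads `hidden_state` through each BertLayer in turn, collecting
// per-layer hidden states and attention weights only when output_hidden_states /
// output_attentions are enabled in the config; otherwise both accumulators stay None.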
@ -130,14 +177,16 @@ pub struct BertPooler {
impl BertPooler {
pub fn new(p: &nn::Path, config: &BertConfig) -> BertPooler {
let lin = nn::linear(&(p / "dense"), config.hidden_size, config.hidden_size, Default::default());
let lin = nn::linear(
&(p / "dense"),
config.hidden_size,
config.hidden_size,
Default::default(),
);
BertPooler { lin }
}
pub fn forward(&self, hidden_states: &Tensor) -> Tensor {
hidden_states
.select(1, 0)
.apply(&self.lin)
.tanh()
hidden_states.select(1, 0).apply(&self.lin).tanh()
}
}

View File

@ -24,13 +24,19 @@
//! use rust_tokenizers::BertTokenizer;
//! use tch::{nn, Device};
//! # use std::path::PathBuf;
//! use rust_bert::bert::{BertForMaskedLM, BertConfig};
//! use rust_bert::bert::{BertConfig, BertForMaskedLM};
//! use rust_bert::resources::{download_resource, LocalResource, Resource};
//! use rust_bert::Config;
//! use rust_bert::resources::{Resource, download_resource, LocalResource};
//!
//! let config_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/config.json")});
//! let vocab_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/vocab.txt")});
//! let weights_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/model.ot")});
//! let config_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/config.json"),
//! });
//! let vocab_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/vocab.txt"),
//! });
//! let weights_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/model.ot"),
//! });
//! let config_path = download_resource(&config_resource)?;
//! let vocab_path = download_resource(&vocab_resource)?;
//! let weights_path = download_resource(&weights_resource)?;
@ -45,13 +51,14 @@
//! # }
//! ```
mod attention;
mod bert;
mod embeddings;
mod attention;
pub(crate) mod encoder;
pub use bert::{BertModelResources, BertConfigResources, BertVocabResources,
BertConfig, Activation, BertModel, BertForTokenClassification, BertForMultipleChoice,
BertForMaskedLM, BertForSequenceClassification, BertForQuestionAnswering};
pub use bert::{
Activation, BertConfig, BertConfigResources, BertForMaskedLM, BertForMultipleChoice,
BertForQuestionAnswering, BertForSequenceClassification, BertForTokenClassification, BertModel,
BertModelResources, BertVocabResources,
};
pub use embeddings::{BertEmbedding, BertEmbeddings};

View File

@ -1,14 +1,26 @@
use tch::Tensor;
use std::f64::consts::PI;
use tch::Tensor;
pub fn _gelu(x: &Tensor) -> Tensor { x * 0.5 * (1.0 + (x / ((2.0 as f64).sqrt())).erf()) }
pub fn _gelu(x: &Tensor) -> Tensor {
x * 0.5 * (1.0 + (x / ((2.0 as f64).sqrt())).erf())
}
pub fn _relu(x: &Tensor) -> Tensor { x.relu() }
pub fn _relu(x: &Tensor) -> Tensor {
x.relu()
}
pub fn _swish(x: &Tensor) -> Tensor { x * x.sigmoid() }
pub fn _swish(x: &Tensor) -> Tensor {
x * x.sigmoid()
}
pub fn _mish(x: &Tensor) -> Tensor { x * (x.softplus().tanh()) }
pub fn _mish(x: &Tensor) -> Tensor {
x * (x.softplus().tanh())
}
pub fn _gelu_new(x: &Tensor) -> Tensor { x * 0.5 * (((x.pow(3.0f64) * 0.044715 + x) * ((2f64 / PI).sqrt())).tanh() + 1) }
pub fn _gelu_new(x: &Tensor) -> Tensor {
x * 0.5 * (((x.pow(3.0f64) * 0.044715 + x) * ((2f64 / PI).sqrt())).tanh() + 1)
}
pub fn _tanh(x: &Tensor) -> Tensor { x.tanh() }
pub fn _tanh(x: &Tensor) -> Tensor {
x.tanh()
}
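// A minimal sketch (assuming the tch tensor ops already used in this file) comparing
// the exact erf-based `_gelu` with the `_gelu_new` tanh approximation; the two agree
// closely over typical activation ranges:
use std::f64::consts::PI;
use tch::Tensor;

fn main() {
    let x = Tensor::of_slice(&[-2.0f64, -1.0, 0.0, 1.0, 2.0]);
    let exact = &x * 0.5 * (1.0 + (&x / ((2.0 as f64).sqrt())).erf());
    let approx =
        &x * 0.5 * (((x.pow(3.0f64) * 0.044715 + &x) * ((2f64 / PI).sqrt())).tanh() + 1);
    (exact - approx).abs().max().print(); // prints a small maximum deviation
}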

View File

@ -9,15 +9,16 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::path::Path;
use serde::Deserialize;
use std::fs::File;
use std::io::BufReader;
use serde::Deserialize;
use std::path::Path;
/// # Utility to deserialize JSON config files
pub trait Config<T>
where for<'de> T: Deserialize<'de> {
where
for<'de> T: Deserialize<'de>,
{
/// Loads a `Config` object from a JSON file. The format is expected to be aligned with the [Transformers library](https://github.com/huggingface/transformers) configuration files for each model.
/// The parsing will fail if non-optional keys expected by the model are missing.
///
@ -28,14 +29,13 @@ pub trait Config<T>
/// # Example
///
/// ```no_run
/// use rust_bert::gpt2::Gpt2Config;
/// use rust_bert::Config;
/// use std::path::Path;
/// use rust_bert::gpt2::Gpt2Config;
///
/// let config_path = Path::new("path/to/config.json");
/// let config = Gpt2Config::from_file(config_path);
/// ```
///
fn from_file(path: &Path) -> T {
let f = File::open(path).expect("Could not open configuration file.");
let br = BufReader::new(f);

View File

@ -10,9 +10,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use tch::nn::{Init, Path, Module};
use tch::Tensor;
use std::borrow::Borrow;
use tch::nn::{Init, Module, Path};
use tch::Tensor;
#[derive(Debug, Clone, Copy)]
pub struct LinearNoBiasConfig {
@ -32,7 +32,6 @@ pub struct LinearNoBias {
pub ws: Tensor,
}
pub fn linear_no_bias<'a, T: Borrow<Path<'a>>>(
vs: T,
in_dim: i64,

View File

@ -1,7 +1,7 @@
pub mod config;
pub mod resources;
pub(crate) mod dropout;
pub(crate) mod activations;
pub mod config;
pub(crate) mod dropout;
pub(crate) mod linear;
pub mod resources;
pub use config::Config;

View File

@ -18,9 +18,9 @@
//! pre-trained models in each model module.
use lazy_static::lazy_static;
use std::path::PathBuf;
use reqwest::Client;
use std::{fs, env};
use std::path::PathBuf;
use std::{env, fs};
use tokio::prelude::*;
use tokio::runtime::Runtime;
use tokio::task;
@ -47,12 +47,13 @@ impl Resource {
/// # Example
///
/// ```no_run
/// use rust_bert::resources::{Resource, LocalResource};
/// use rust_bert::resources::{LocalResource, Resource};
/// use std::path::PathBuf;
/// let config_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/config.json")});
/// let config_resource = Resource::Local(LocalResource {
/// local_path: PathBuf::from("path/to/config.json"),
/// });
/// let config_path = config_resource.get_local_path();
/// ```
///
pub fn get_local_path(&self) -> &PathBuf {
match self {
Resource::Local(resource) => &resource.local_path,
@ -65,7 +66,7 @@ impl Resource {
#[derive(PartialEq, Clone)]
pub struct LocalResource {
/// Local path for the resource
pub local_path: PathBuf
pub local_path: PathBuf,
}
/// # Remote resource
@ -93,13 +94,18 @@ impl RemoteResource {
/// # Example
///
/// ```no_run
/// use rust_bert::resources::{Resource, RemoteResource};
/// use rust_bert::resources::{RemoteResource, Resource};
/// use std::path::PathBuf;
/// let config_resource = Resource::Remote(RemoteResource::new("http://config_json_location", PathBuf::from("path/to/config.json")));
/// let config_resource = Resource::Remote(RemoteResource::new(
/// "http://config_json_location",
/// PathBuf::from("path/to/config.json"),
/// ));
/// ```
///
pub fn new(url: &str, target: PathBuf) -> RemoteResource {
RemoteResource { url: url.to_string(), local_path: target }
RemoteResource {
url: url.to_string(),
local_path: target,
}
}
/// Creates a new RemoteResource from a URL and a local name. Will define a local path pointing to
@ -117,14 +123,12 @@ impl RemoteResource {
/// # Example
///
/// ```no_run
/// use rust_bert::resources::{Resource, RemoteResource};
/// let model_resource = Resource::Remote(RemoteResource::from_pretrained(
/// ("distilbert-sst2/model.ot",
/// "https://cdn.huggingface.co/distilbert-base-uncased-finetuned-sst-2-english-rust_model.ot"
/// )
/// ));
/// use rust_bert::resources::{RemoteResource, Resource};
/// let model_resource = Resource::Remote(RemoteResource::from_pretrained((
/// "distilbert-sst2/model.ot",
/// "https://cdn.huggingface.co/distilbert-base-uncased-finetuned-sst-2-english-rust_model.ot",
/// )));
/// ```
///
pub fn from_pretrained(name_url_tuple: (&str, &str)) -> RemoteResource {
let name = name_url_tuple.0;
let url = name_url_tuple.1.to_string();
@ -171,15 +175,13 @@ fn _get_cache_directory() -> PathBuf {
/// # Example
///
/// ```no_run
/// use rust_bert::resources::{Resource, RemoteResource, download_resource};
/// let model_resource = Resource::Remote(RemoteResource::from_pretrained(
/// ("distilbert-sst2/model.ot",
/// "https://cdn.huggingface.co/distilbert-base-uncased-finetuned-sst-2-english-rust_model.ot"
/// )
/// ));
/// use rust_bert::resources::{download_resource, RemoteResource, Resource};
/// let model_resource = Resource::Remote(RemoteResource::from_pretrained((
/// "distilbert-sst2/model.ot",
/// "https://cdn.huggingface.co/distilbert-base-uncased-finetuned-sst-2-english-rust_model.ot",
/// )));
/// let local_path = download_resource(&model_resource);
/// ```
///
pub fn download_resource(resource: &Resource) -> failure::Fallible<&PathBuf> {
match resource {
Resource::Remote(remote_resource) => {
@ -202,8 +204,6 @@ pub fn download_resource(resource: &Resource) -> failure::Fallible<&PathBuf> {
Ok(resource.get_local_path())
}
Resource::Local(_) => {
Ok(resource.get_local_path())
}
Resource::Local(_) => Ok(resource.get_local_path()),
}
}
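// Local resources resolve directly to their configured path, while remote resources
// are downloaded under the cache directory from _get_cache_directory before their
// local path is returned.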

View File

@ -16,7 +16,11 @@ extern crate tch;
pub fn main() -> failure::Fallible<()> {
let args: Vec<_> = std::env::args().collect();
ensure!(args.len() == 3, "usage: {} source.npz destination.ot", args[0]);
ensure!(
args.len() == 3,
"usage: {} source.npz destination.ot",
args[0]
);
let source_file = &args[1];
let destination_file = &args[2];

View File

@ -10,11 +10,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use tch::{nn, Tensor};
use crate::common::dropout::Dropout;
use crate::distilbert::distilbert::DistilBertConfig;
use tch::kind::Kind::Float;
use crate::common::dropout::Dropout;
use tch::{nn, Tensor};
#[derive(Debug)]
pub struct MultiHeadSelfAttention {
@ -39,7 +38,7 @@ impl MultiHeadSelfAttention {
let output_attentions = match config.output_attentions {
Some(value) => value,
None => false
None => false,
};
MultiHeadSelfAttention {
@ -59,10 +58,19 @@ impl MultiHeadSelfAttention {
}
fn flatten(&self, x: Tensor, bs: i64, dim_per_head: i64) -> Tensor {
x.transpose(1, 2).contiguous().view((bs, -1, &self.n_heads * dim_per_head))
x.transpose(1, 2)
.contiguous()
.view((bs, -1, &self.n_heads * dim_per_head))
}
pub fn forward_t(&self, query: &Tensor, key: &Tensor, value: &Tensor, mask: &Option<Tensor>, train: bool) -> (Tensor, Option<Tensor>) {
pub fn forward_t(
&self,
query: &Tensor,
key: &Tensor,
value: &Tensor,
mask: &Option<Tensor>,
train: bool,
) -> (Tensor, Option<Tensor>) {
let bs = query.size()[0];
let k_length = key.size()[1];
@ -73,14 +81,19 @@ impl MultiHeadSelfAttention {
let scores = if let Some(mask) = mask {
let unmasked_scores = q.matmul(&k.transpose(2, 3));
let mask = mask.le1(&(mask.zeros_like() + 0.1)).view((bs, 1i64, 1i64, k_length)).expand_as(&unmasked_scores);
let mask = mask
.le1(&(mask.zeros_like() + 0.1))
.view((bs, 1i64, 1i64, k_length))
.expand_as(&unmasked_scores);
unmasked_scores.masked_fill(&mask, std::f64::NEG_INFINITY)
} else {
q.matmul(&k.transpose(2, 3))
};
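// Masking note: entries where the padding mask is <= 0.1 are set to -inf here, so the
// softmax below assigns them zero attention weight; the view/expand broadcasts the
// (batch, k_length) mask across heads and query positions.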
let weights = scores.softmax(-1, Float).apply_t(&self.dropout, train);
let context = self.flatten(weights.matmul(&v), bs, self.dim_per_head).apply(&self.out_lin);
let context = self
.flatten(weights.matmul(&v), bs, self.dim_per_head)
.apply(&self.out_lin);
if !self.output_attentions {
(context, None)

View File

@ -12,13 +12,13 @@
extern crate tch;
use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use crate::distilbert::embeddings::DistilBertEmbedding;
use crate::distilbert::transformer::Transformer;
use self::tch::{nn, Tensor};
use crate::common::dropout::Dropout;
use crate::distilbert::embeddings::DistilBertEmbedding;
use crate::distilbert::transformer::Transformer;
use crate::Config;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// # DistilBERT Pretrained model weight files
pub struct DistilBertModelResources;
@ -31,29 +31,56 @@ pub struct DistilBertVocabResources;
impl DistilBertModelResources {
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
pub const DISTIL_BERT_SST2: (&'static str, &'static str) = ("distilbert-sst2/model.ot", "https://cdn.huggingface.co/distilbert-base-uncased-finetuned-sst-2-english-rust_model.ot");
pub const DISTIL_BERT_SST2: (&'static str, &'static str) = (
"distilbert-sst2/model.ot",
"https://cdn.huggingface.co/distilbert-base-uncased-finetuned-sst-2-english-rust_model.ot",
);
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
pub const DISTIL_BERT: (&'static str, &'static str) = ("distilbert/model.ot", "https://cdn.huggingface.co/distilbert-base-uncased-rust_model.ot");
pub const DISTIL_BERT: (&'static str, &'static str) = (
"distilbert/model.ot",
"https://cdn.huggingface.co/distilbert-base-uncased-rust_model.ot",
);
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
pub const DISTIL_BERT_SQUAD: (&'static str, &'static str) = ("distilbert-qa/model.ot", "https://cdn.huggingface.co/distilbert-base-cased-distilled-squad-rust_model.ot");
pub const DISTIL_BERT_SQUAD: (&'static str, &'static str) = (
"distilbert-qa/model.ot",
"https://cdn.huggingface.co/distilbert-base-cased-distilled-squad-rust_model.ot",
);
}
impl DistilBertConfigResources {
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
pub const DISTIL_BERT_SST2: (&'static str, &'static str) = ("distilbert-sst2/config.json", "https://cdn.huggingface.co/distilbert-base-uncased-finetuned-sst-2-english-config.json");
pub const DISTIL_BERT_SST2: (&'static str, &'static str) = (
"distilbert-sst2/config.json",
"https://cdn.huggingface.co/distilbert-base-uncased-finetuned-sst-2-english-config.json",
);
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
pub const DISTIL_BERT: (&'static str, &'static str) = ("distilbert/config.json", "https://cdn.huggingface.co/distilbert-base-uncased-config.json");
pub const DISTIL_BERT: (&'static str, &'static str) = (
"distilbert/config.json",
"https://cdn.huggingface.co/distilbert-base-uncased-config.json",
);
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
pub const DISTIL_BERT_SQUAD: (&'static str, &'static str) = ("distilbert-qa/config.json", "https://cdn.huggingface.co/distilbert-base-cased-distilled-squad-config.json");
pub const DISTIL_BERT_SQUAD: (&'static str, &'static str) = (
"distilbert-qa/config.json",
"https://cdn.huggingface.co/distilbert-base-cased-distilled-squad-config.json",
);
}
impl DistilBertVocabResources {
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
pub const DISTIL_BERT_SST2: (&'static str, &'static str) = ("distilbert-sst2/vocab.txt", "https://cdn.huggingface.co/distilbert-base-uncased-finetuned-sst-2-english-vocab.txt");
pub const DISTIL_BERT_SST2: (&'static str, &'static str) = (
"distilbert-sst2/vocab.txt",
"https://cdn.huggingface.co/distilbert-base-uncased-finetuned-sst-2-english-vocab.txt",
);
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
pub const DISTIL_BERT: (&'static str, &'static str) = ("distilbert/vocab.txt", "https://cdn.huggingface.co/bert-base-uncased-vocab.txt");
pub const DISTIL_BERT: (&'static str, &'static str) = (
"distilbert/vocab.txt",
"https://cdn.huggingface.co/bert-base-uncased-vocab.txt",
);
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
pub const DISTIL_BERT_SQUAD: (&'static str, &'static str) = ("distilbert-qa/vocab.txt", "https://cdn.huggingface.co/bert-large-cased-vocab.txt");
pub const DISTIL_BERT_SQUAD: (&'static str, &'static str) = (
"distilbert-qa/vocab.txt",
"https://cdn.huggingface.co/bert-large-cased-vocab.txt",
);
}
#[allow(non_camel_case_types)]
@ -118,10 +145,10 @@ impl DistilBertModel {
/// # Example
///
/// ```no_run
/// use tch::{nn, Device};
/// use rust_bert::distilbert::{DistilBertConfig, DistilBertModel};
/// use rust_bert::Config;
/// use std::path::Path;
/// use rust_bert::distilbert::{DistilBertConfig, DistilBertModel};
/// use tch::{nn, Device};
///
/// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu;
@ -129,12 +156,14 @@ impl DistilBertModel {
/// let config = DistilBertConfig::from_file(config_path);
/// let distil_bert: DistilBertModel = DistilBertModel::new(&(&p.root() / "distilbert"), &config);
/// ```
///
pub fn new(p: &nn::Path, config: &DistilBertConfig) -> DistilBertModel {
let p = &(p / "distilbert");
let embeddings = DistilBertEmbedding::new(&(p / "embeddings"), config);
let transformer = Transformer::new(&(p / "transformer"), config);
DistilBertModel { embeddings, transformer }
DistilBertModel {
embeddings,
transformer,
}
}
/// Forward pass through the model
@ -172,28 +201,32 @@ impl DistilBertModel {
///
/// let (output, all_hidden_states, all_attentions) = no_grad(|| {
/// distilbert_model
/// .forward_t(Some(input_tensor),
/// Some(mask),
/// None,
/// false).unwrap()
/// .forward_t(Some(input_tensor), Some(mask), None, false)
/// .unwrap()
/// });
///
/// ```
///
pub fn forward_t(&self, input: Option<Tensor>, mask: Option<Tensor>, input_embeds: Option<Tensor>, train: bool)
-> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
pub fn forward_t(
&self,
input: Option<Tensor>,
mask: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool,
) -> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
let input_embeddings = match input {
Some(input_value) => match input_embeds {
Some(_) => { return Err("Only one of input ids or input embeddings may be set"); }
None => input_value.apply_t(&self.embeddings, train)
Some(_) => {
return Err("Only one of input ids or input embeddings may be set");
}
None => input_value.apply_t(&self.embeddings, train),
},
None => match input_embeds {
Some(embeds) => embeds,
None => { return Err("At least one of input ids or input embeddings must be set"); }
None => {
return Err("At least one of input ids or input embeddings must be set");
}
},
};
let transformer_output = (&self.transformer).forward_t(&input_embeddings, mask, train);
Ok(transformer_output)
}
@ -223,28 +256,47 @@ impl DistilBertModelClassifier {
/// # Example
///
/// ```no_run
/// use tch::{nn, Device};
/// use rust_bert::distilbert::{DistilBertConfig, DistilBertModelClassifier};
/// use rust_bert::Config;
/// use std::path::Path;
/// use rust_bert::distilbert::{DistilBertConfig, DistilBertModelClassifier};
/// use tch::{nn, Device};
///
/// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu;
/// let p = nn::VarStore::new(device);
/// let config = DistilBertConfig::from_file(config_path);
/// let distil_bert: DistilBertModelClassifier = DistilBertModelClassifier::new(&(&p.root() / "distilbert"), &config);
/// let distil_bert: DistilBertModelClassifier =
/// DistilBertModelClassifier::new(&(&p.root() / "distilbert"), &config);
/// ```
///
pub fn new(p: &nn::Path, config: &DistilBertConfig) -> DistilBertModelClassifier {
let distil_bert_model = DistilBertModel::new(&p, config);
let num_labels = config.id2label.as_ref().expect("id2label must be provided for classifiers").len() as i64;
let num_labels = config
.id2label
.as_ref()
.expect("id2label must be provided for classifiers")
.len() as i64;
let pre_classifier = nn::linear(&(p / "pre_classifier"), config.dim, config.dim, Default::default());
let classifier = nn::linear(&(p / "classifier"), config.dim, num_labels, Default::default());
let pre_classifier = nn::linear(
&(p / "pre_classifier"),
config.dim,
config.dim,
Default::default(),
);
let classifier = nn::linear(
&(p / "classifier"),
config.dim,
num_labels,
Default::default(),
);
let dropout = Dropout::new(config.seq_classif_dropout);
DistilBertModelClassifier { distil_bert_model, pre_classifier, classifier, dropout }
DistilBertModelClassifier {
distil_bert_model,
pre_classifier,
classifier,
dropout,
}
}
/// Forward pass through the model
@ -287,14 +339,21 @@ impl DistilBertModelClassifier {
/// None,
/// false).unwrap()
/// });
///
/// ```
///
pub fn forward_t(&self, input: Option<Tensor>, mask: Option<Tensor>, input_embeds: Option<Tensor>, train: bool)
-> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
let (output, all_hidden_states, all_attentions) = match self.distil_bert_model.forward_t(input, mask, input_embeds, train) {
pub fn forward_t(
&self,
input: Option<Tensor>,
mask: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool,
) -> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
let (output, all_hidden_states, all_attentions) =
match self
.distil_bert_model
.forward_t(input, mask, input_embeds, train)
{
Ok(value) => value,
Err(err) => return Err(err)
Err(err) => return Err(err),
};
let output = output
@ -322,7 +381,6 @@ pub struct DistilBertModelMaskedLM {
vocab_projector: nn::Linear,
}
impl DistilBertModelMaskedLM {
/// Build a new `DistilBertModelMaskedLM` for masked language modeling
///
@ -334,10 +392,10 @@ impl DistilBertModelMaskedLM {
/// # Example
///
/// ```no_run
/// use tch::{nn, Device};
/// use rust_bert::distilbert::{DistilBertConfig, DistilBertModelMaskedLM};
/// use rust_bert::Config;
/// use std::path::Path;
/// use rust_bert::distilbert::{DistilBertConfig, DistilBertModelMaskedLM};
/// use tch::{nn, Device};
///
/// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu;
@ -345,15 +403,33 @@ impl DistilBertModelMaskedLM {
/// let config = DistilBertConfig::from_file(config_path);
/// let distil_bert = DistilBertModelMaskedLM::new(&(&p.root() / "distilbert"), &config);
/// ```
///
pub fn new(p: &nn::Path, config: &DistilBertConfig) -> DistilBertModelMaskedLM {
let distil_bert_model = DistilBertModel::new(&p, config);
let vocab_transform = nn::linear(&(p / "vocab_transform"), config.dim, config.dim, Default::default());
let layer_norm_config = nn::LayerNormConfig { eps: 1e-12, ..Default::default() };
let vocab_layer_norm = nn::layer_norm(p / "vocab_layer_norm", vec![config.dim], layer_norm_config);
let vocab_projector = nn::linear(&(p / "vocab_projector"), config.dim, config.vocab_size, Default::default());
let vocab_transform = nn::linear(
&(p / "vocab_transform"),
config.dim,
config.dim,
Default::default(),
);
let layer_norm_config = nn::LayerNormConfig {
eps: 1e-12,
..Default::default()
};
let vocab_layer_norm =
nn::layer_norm(p / "vocab_layer_norm", vec![config.dim], layer_norm_config);
let vocab_projector = nn::linear(
&(p / "vocab_projector"),
config.dim,
config.vocab_size,
Default::default(),
);
DistilBertModelMaskedLM { distil_bert_model, vocab_transform, vocab_layer_norm, vocab_projector }
DistilBertModelMaskedLM {
distil_bert_model,
vocab_transform,
vocab_layer_norm,
vocab_projector,
}
}
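// Shape note: vocab_transform keeps the hidden width (dim -> dim), vocab_layer_norm
// normalizes it, and vocab_projector maps to (batch, seq_len, vocab_size) logits.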
/// Forward pass through the model
@ -391,19 +467,24 @@ impl DistilBertModelMaskedLM {
///
/// let (output, all_hidden_states, all_attentions) = no_grad(|| {
/// distilbert_model
/// .forward_t(Some(input_tensor),
/// Some(mask),
/// None,
/// false).unwrap()
/// .forward_t(Some(input_tensor), Some(mask), None, false)
/// .unwrap()
/// });
///
/// ```
///
pub fn forward_t(&self, input: Option<Tensor>, mask: Option<Tensor>, input_embeds: Option<Tensor>, train: bool)
-> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
let (output, all_hidden_states, all_attentions) = match self.distil_bert_model.forward_t(input, mask, input_embeds, train) {
pub fn forward_t(
&self,
input: Option<Tensor>,
mask: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool,
) -> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
let (output, all_hidden_states, all_attentions) =
match self
.distil_bert_model
.forward_t(input, mask, input_embeds, train)
{
Ok(value) => value,
Err(err) => return Err(err)
Err(err) => return Err(err),
};
let output = output
@ -440,10 +521,10 @@ impl DistilBertForQuestionAnswering {
/// # Example
///
/// ```no_run
/// use tch::{nn, Device};
/// use rust_bert::distilbert::{DistilBertConfig, DistilBertForQuestionAnswering};
/// use rust_bert::Config;
/// use std::path::Path;
/// use rust_bert::distilbert::{DistilBertConfig, DistilBertForQuestionAnswering};
/// use tch::{nn, Device};
///
/// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu;
@ -451,13 +532,16 @@ impl DistilBertForQuestionAnswering {
/// let config = DistilBertConfig::from_file(config_path);
/// let distil_bert = DistilBertForQuestionAnswering::new(&(&p.root() / "distilbert"), &config);
/// ```
///
pub fn new(p: &nn::Path, config: &DistilBertConfig) -> DistilBertForQuestionAnswering {
let distil_bert_model = DistilBertModel::new(&p, config);
let qa_outputs = nn::linear(&(p / "qa_outputs"), config.dim, 2, Default::default());
let dropout = Dropout::new(config.qa_dropout);
DistilBertForQuestionAnswering { distil_bert_model, qa_outputs, dropout }
DistilBertForQuestionAnswering {
distil_bert_model,
qa_outputs,
dropout,
}
}
/// Forward pass through the model
@ -496,35 +580,33 @@ impl DistilBertForQuestionAnswering {
///
/// let (start_scores, end_score, all_hidden_states, all_attentions) = no_grad(|| {
/// distilbert_model
/// .forward_t(Some(input_tensor),
/// Some(mask),
/// None,
/// false).unwrap()
/// .forward_t(Some(input_tensor), Some(mask), None, false)
/// .unwrap()
/// });
///
/// ```
///
pub fn forward_t(&self,
pub fn forward_t(
&self,
input: Option<Tensor>,
mask: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool)
-> Result<(Tensor, Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
let (output, all_hidden_states, all_attentions) = match self.distil_bert_model.forward_t(input, mask, input_embeds, train) {
train: bool,
) -> Result<(Tensor, Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
let (output, all_hidden_states, all_attentions) =
match self
.distil_bert_model
.forward_t(input, mask, input_embeds, train)
{
Ok(value) => value,
Err(err) => return Err(err)
Err(err) => return Err(err),
};
let output = output
.apply_t(&self.dropout, train)
.apply(&self.qa_outputs);
let output = output.apply_t(&self.dropout, train).apply(&self.qa_outputs);
let logits = output.split(1, -1);
let (start_logits, end_logits) = (&logits[0], &logits[1]);
let start_logits = start_logits.squeeze1(-1);
let end_logits = end_logits.squeeze1(-1);
Ok((start_logits, end_logits, all_hidden_states, all_attentions))
}
}
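// Shape note: qa_outputs produces (batch, seq_len, 2); split(1, -1) plus squeeze1(-1)
// yield start and end logits of shape (batch, seq_len) each.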
@ -552,10 +634,10 @@ impl DistilBertForTokenClassification {
/// # Example
///
/// ```no_run
/// use tch::{nn, Device};
/// use rust_bert::distilbert::{DistilBertConfig, DistilBertForTokenClassification};
/// use rust_bert::Config;
/// use std::path::Path;
/// use rust_bert::distilbert::{DistilBertConfig, DistilBertForTokenClassification};
/// use tch::{nn, Device};
///
/// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu;
@ -563,16 +645,28 @@ impl DistilBertForTokenClassification {
/// let config = DistilBertConfig::from_file(config_path);
/// let distil_bert = DistilBertForTokenClassification::new(&(&p.root() / "distilbert"), &config);
/// ```
///
pub fn new(p: &nn::Path, config: &DistilBertConfig) -> DistilBertForTokenClassification {
let distil_bert_model = DistilBertModel::new(&p, config);
let num_labels = config.id2label.as_ref().expect("id2label must be provided for classifiers").len() as i64;
let num_labels = config
.id2label
.as_ref()
.expect("id2label must be provided for classifiers")
.len() as i64;
let classifier = nn::linear(&(p / "classifier"), config.dim, num_labels, Default::default());
let classifier = nn::linear(
&(p / "classifier"),
config.dim,
num_labels,
Default::default(),
);
let dropout = Dropout::new(config.seq_classif_dropout);
DistilBertForTokenClassification { distil_bert_model, classifier, dropout }
DistilBertForTokenClassification {
distil_bert_model,
classifier,
dropout,
}
}
/// Forward pass through the model
@ -610,24 +704,27 @@ impl DistilBertForTokenClassification {
///
/// let (output, all_hidden_states, all_attentions) = no_grad(|| {
/// distilbert_model
/// .forward_t(Some(input_tensor),
/// Some(mask),
/// None,
/// false).unwrap()
/// .forward_t(Some(input_tensor), Some(mask), None, false)
/// .unwrap()
/// });
///
/// ```
///
pub fn forward_t(&self, input: Option<Tensor>, mask: Option<Tensor>, input_embeds: Option<Tensor>, train: bool)
-> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
let (output, all_hidden_states, all_attentions) = match self.distil_bert_model.forward_t(input, mask, input_embeds, train) {
pub fn forward_t(
&self,
input: Option<Tensor>,
mask: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool,
) -> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
let (output, all_hidden_states, all_attentions) =
match self
.distil_bert_model
.forward_t(input, mask, input_embeds, train)
{
Ok(value) => value,
Err(err) => return Err(err)
Err(err) => return Err(err),
};
let output = output
.apply_t(&self.dropout, train)
.apply(&self.classifier);
let output = output.apply_t(&self.dropout, train).apply(&self.classifier);
Ok((output, all_hidden_states, all_attentions))
}

View File

@ -10,22 +10,26 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use tch::{nn, Tensor, Kind, Device};
use tch::nn::{ModuleT, embedding, EmbeddingConfig};
use crate::distilbert::distilbert::DistilBertConfig;
use crate::common::dropout::Dropout;
use crate::distilbert::distilbert::DistilBertConfig;
use tch::kind::Kind::Float;
use tch::nn::{embedding, EmbeddingConfig, ModuleT};
use tch::{nn, Device, Kind, Tensor};
fn create_sinusoidal_embeddings(config: &DistilBertConfig, device: Device) -> nn::Embedding {
let mut sinusoidal_embedding: Vec<Tensor> = Vec::with_capacity(config.max_position_embeddings as usize);
let mut sinusoidal_embedding: Vec<Tensor> =
Vec::with_capacity(config.max_position_embeddings as usize);
for pos in 0..config.max_position_embeddings {
let mut temp_vec: Vec<f64> = Vec::with_capacity(config.dim as usize);
for j in 0..config.dim {
if j % 2 == 0 {
temp_vec.push((pos as f64 / 10000f64.powf((2 * (j / 2)) as f64 / config.dim as f64)).sin());
temp_vec.push(
(pos as f64 / 10000f64.powf((2 * (j / 2)) as f64 / config.dim as f64)).sin(),
);
} else {
temp_vec.push((pos as f64 / 10000f64.powf((2 * (j / 2)) as f64 / config.dim as f64)).cos());
temp_vec.push(
(pos as f64 / 10000f64.powf((2 * (j / 2)) as f64 / config.dim as f64)).cos(),
);
}
}
let temp_vec = Tensor::of_slice(&temp_vec);
@ -35,17 +39,21 @@ fn create_sinusoidal_embeddings(config: &DistilBertConfig, device: Device) -> nn
.to_kind(Float)
.to_device(device);
let embedding_config = EmbeddingConfig { padding_idx: 0, ..Default::default() };
let mut embeddings = embedding(&nn::VarStore::new(device).root(),
let embedding_config = EmbeddingConfig {
padding_idx: 0,
..Default::default()
};
let mut embeddings = embedding(
&nn::VarStore::new(device).root(),
config.max_position_embeddings,
config.dim,
embedding_config);
embedding_config,
);
embeddings.ws = sinusoidal_embedding;
embeddings
}
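// The loop above implements the standard transformer sinusoidal position encoding:
// PE[pos][j] = sin(pos / 10000^(2*(j/2)/dim)) for even j, cos of the same angle for
// odd j. A scalar sketch of a single entry (plain f64 math, no tch types assumed):
fn sinusoidal_entry(pos: i64, j: i64, dim: i64) -> f64 {
    let angle = pos as f64 / 10000f64.powf((2 * (j / 2)) as f64 / dim as f64);
    if j % 2 == 0 {
        angle.sin()
    } else {
        angle.cos()
    }
}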
#[derive(Debug)]
pub struct DistilBertEmbedding {
word_embeddings: nn::Embedding,
@ -56,24 +64,40 @@ pub struct DistilBertEmbedding {
impl DistilBertEmbedding {
pub fn new(p: &nn::Path, config: &DistilBertConfig) -> DistilBertEmbedding {
let embedding_config = EmbeddingConfig { padding_idx: 0, ..Default::default() };
let embedding_config = EmbeddingConfig {
padding_idx: 0,
..Default::default()
};
let word_embeddings: nn::Embedding = embedding(p / "word_embeddings",
let word_embeddings: nn::Embedding = embedding(
p / "word_embeddings",
config.vocab_size,
config.dim,
embedding_config);
embedding_config,
);
let position_embeddings: nn::Embedding = match config.sinusoidal_pos_embds {
false => embedding(p / "position_embeddings",
false => embedding(
p / "position_embeddings",
config.max_position_embeddings,
config.dim,
embedding_config),
embedding_config,
),
true => create_sinusoidal_embeddings(&config, p.device())
true => create_sinusoidal_embeddings(&config, p.device()),
};
let layer_norm_config = nn::LayerNormConfig { eps: 1e-12, ..Default::default() };
let layer_norm: nn::LayerNorm = nn::layer_norm(p / "LayerNorm", vec![config.dim], layer_norm_config);
let layer_norm_config = nn::LayerNormConfig {
eps: 1e-12,
..Default::default()
};
let layer_norm: nn::LayerNorm =
nn::layer_norm(p / "LayerNorm", vec![config.dim], layer_norm_config);
let dropout: Dropout = Dropout::new(config.dropout);
DistilBertEmbedding { word_embeddings, position_embeddings, layer_norm, dropout }
DistilBertEmbedding {
word_embeddings,
position_embeddings,
layer_norm,
dropout,
}
}
pub fn _get_word_embeddings(&self) -> &nn::Embedding {
@ -96,7 +120,9 @@ impl ModuleT for DistilBertEmbedding {
// position_embed.get(0).get(0).print();
let embeddings = word_embed + position_embed;
let embeddings = embeddings.apply(&self.layer_norm).apply_t(&self.dropout, train);
let embeddings = embeddings
.apply(&self.layer_norm)
.apply_t(&self.dropout, train);
embeddings
}

View File

@ -23,13 +23,22 @@
//! use rust_tokenizers::BertTokenizer;
//! use tch::{nn, Device};
//! # use std::path::PathBuf;
//! use rust_bert::distilbert::{
//! DistilBertConfig, DistilBertConfigResources, DistilBertModelMaskedLM,
//! DistilBertModelResources, DistilBertVocabResources,
//! };
//! use rust_bert::resources::{download_resource, LocalResource, RemoteResource, Resource};
//! use rust_bert::Config;
//! use rust_bert::distilbert::{DistilBertModelMaskedLM, DistilBertConfig, DistilBertConfigResources, DistilBertVocabResources, DistilBertModelResources};
//! use rust_bert::resources::{Resource, download_resource, RemoteResource, LocalResource};
//!
//! let config_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/config.json")});
//! let vocab_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/vocab.txt")});
//! let weights_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/model.ot")});
//! let config_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/config.json"),
//! });
//! let vocab_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/vocab.txt"),
//! });
//! let weights_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/model.ot"),
//! });
//! let config_path = download_resource(&config_resource)?;
//! let vocab_path = download_resource(&vocab_resource)?;
//! let weights_path = download_resource(&weights_resource)?;
@ -44,13 +53,13 @@
//! # }
//! ```
mod attention;
mod distilbert;
mod embeddings;
mod attention;
mod transformer;
pub use distilbert::{DistilBertModelResources, DistilBertConfigResources, DistilBertVocabResources,
DistilBertConfig, Activation, DistilBertModel, DistilBertForQuestionAnswering, DistilBertForTokenClassification,
DistilBertModelMaskedLM, DistilBertModelClassifier};
pub use distilbert::{
Activation, DistilBertConfig, DistilBertConfigResources, DistilBertForQuestionAnswering,
DistilBertForTokenClassification, DistilBertModel, DistilBertModelClassifier,
DistilBertModelMaskedLM, DistilBertModelResources, DistilBertVocabResources,
};

View File

@ -10,13 +10,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use tch::{Tensor, nn};
use crate::distilbert::distilbert::{DistilBertConfig, Activation};
use crate::distilbert::attention::MultiHeadSelfAttention;
use tch::nn::LayerNorm;
use std::borrow::BorrowMut;
use crate::common::dropout::Dropout;
use crate::common::activations::{_gelu, _relu};
use crate::common::dropout::Dropout;
use crate::distilbert::attention::MultiHeadSelfAttention;
use crate::distilbert::distilbert::{Activation, DistilBertConfig};
use std::borrow::BorrowMut;
use tch::nn::LayerNorm;
use tch::{nn, Tensor};
pub struct FeedForwardNetwork {
lin1: nn::Linear,
@ -27,18 +27,35 @@ pub struct FeedForwardNetwork {
impl FeedForwardNetwork {
pub fn new(p: nn::Path, config: &DistilBertConfig) -> FeedForwardNetwork {
let lin1 = nn::linear(&p / "lin1", config.dim, config.hidden_dim, Default::default());
let lin2 = nn::linear(&p / "lin2", config.hidden_dim, config.dim, Default::default());
let lin1 = nn::linear(
&p / "lin1",
config.dim,
config.hidden_dim,
Default::default(),
);
let lin2 = nn::linear(
&p / "lin2",
config.hidden_dim,
config.dim,
Default::default(),
);
let dropout = Dropout::new(config.dropout);
let activation = Box::new(match &config.activation {
Activation::gelu => _gelu,
Activation::relu => _relu
Activation::relu => _relu,
});
FeedForwardNetwork { lin1, lin2, dropout, activation }
FeedForwardNetwork {
lin1,
lin2,
dropout,
activation,
}
}
pub fn forward_t(&self, input: &Tensor, train: bool) -> Tensor {
(self.activation)(&input.apply(&self.lin1)).apply(&self.lin2).apply_t(&self.dropout, train)
(self.activation)(&input.apply(&self.lin1))
.apply(&self.lin2)
.apply_t(&self.dropout, train)
}
}
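The constructor above selects the activation once and stores it as a boxed function pointer, so `forward_t` avoids re-matching on the config at every call. The same pattern in isolation (the `Act` enum and both functions are illustrative stand-ins, not this crate's types):

use tch::Tensor;

// erf-based GELU, mirroring the crate's `_gelu` helper
fn gelu(x: &Tensor) -> Tensor {
    x * 0.5 * ((x / std::f64::consts::SQRT_2).erf() + 1.0)
}

fn relu(x: &Tensor) -> Tensor {
    x.relu()
}

enum Act {
    Gelu,
    Relu,
}

fn pick_activation(act: Act) -> Box<fn(&Tensor) -> Tensor> {
    // each arm coerces to the same fn-pointer type
    Box::new(match act {
        Act::Gelu => gelu,
        Act::Relu => relu,
    })
}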
@ -52,10 +69,15 @@ pub struct TransformerBlock {
impl TransformerBlock {
pub fn new(p: &nn::Path, config: &DistilBertConfig) -> TransformerBlock {
let attention = MultiHeadSelfAttention::new(p / "attention", &config);
let layer_norm_config = nn::LayerNormConfig { eps: 1e-12, ..Default::default() };
let sa_layer_norm = nn::layer_norm(p / "sa_layer_norm", vec![config.dim], layer_norm_config);
let layer_norm_config = nn::LayerNormConfig {
eps: 1e-12,
..Default::default()
};
let sa_layer_norm =
nn::layer_norm(p / "sa_layer_norm", vec![config.dim], layer_norm_config);
let ffn = FeedForwardNetwork::new(p / "ffn", &config);
let output_layer_norm = nn::layer_norm(p / "output_layer_norm", vec![config.dim], layer_norm_config);
let output_layer_norm =
nn::layer_norm(p / "output_layer_norm", vec![config.dim], layer_norm_config);
TransformerBlock {
attention,
@ -65,8 +87,15 @@ impl TransformerBlock {
}
}
pub fn forward_t(&self, input: &Tensor, mask: &Option<Tensor>, train: bool) -> (Tensor, Option<Tensor>) {
let (output, sa_weights) = self.attention.forward_t(&input, &input, &input, mask, train);
pub fn forward_t(
&self,
input: &Tensor,
mask: &Option<Tensor>,
train: bool,
) -> (Tensor, Option<Tensor>) {
let (output, sa_weights) = self
.attention
.forward_t(&input, &input, &input, mask, train);
let output = (input + &output).apply(&self.sa_layer_norm);
let output = (&output + self.ffn.forward_t(&output, train)).apply(&self.output_layer_norm);
(output, sa_weights)
@ -84,25 +113,41 @@ impl Transformer {
let p = &(p / "layer");
let output_attentions = match config.output_attentions {
Some(value) => value,
None => false
None => false,
};
let output_hidden_states = match config.output_hidden_states {
Some(value) => value,
None => false
None => false,
};
let mut layers: Vec<TransformerBlock> = vec!();
let mut layers: Vec<TransformerBlock> = vec![];
for layer_index in 0..config.n_layers {
layers.push(TransformerBlock::new(&(p / layer_index), config));
};
Transformer { output_attentions, output_hidden_states, layers }
}
pub fn forward_t(&self, input: &Tensor, mask: Option<Tensor>, train: bool)
-> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states { Some(vec!()) } else { None };
let mut all_attentions: Option<Vec<Tensor>> = if self.output_attentions { Some(vec!()) } else { None };
Transformer {
output_attentions,
output_hidden_states,
layers,
}
}
pub fn forward_t(
&self,
input: &Tensor,
mask: Option<Tensor>,
train: bool,
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states {
Some(vec![])
} else {
None
};
let mut all_attentions: Option<Vec<Tensor>> = if self.output_attentions {
Some(vec![])
} else {
None
};
let mut hidden_state = input.copy();
let mut attention_weights: Option<Tensor>;
@ -121,9 +166,9 @@ impl Transformer {
attentions.push(attention_weights.as_ref().unwrap().copy());
};
}
None => break
};
None => break,
};
}
(hidden_state, all_hidden_states, all_attentions)
}
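Note the accumulator pattern above: the `Vec`s of per-layer outputs are only allocated when the configuration requested them, so the common path stores nothing. The same idea reduced to plain Rust (a toy payload standing in for `hidden_state.copy()`):

fn collect_if_enabled(enabled: bool, n_layers: usize) -> Option<Vec<usize>> {
    let mut all_states: Option<Vec<usize>> = if enabled { Some(vec![]) } else { None };
    for layer_index in 0..n_layers {
        // ... run the layer here ...
        if let Some(states) = all_states.as_mut() {
            states.push(layer_index); // stand-in for pushing a tensor copy
        }
    }
    all_states
}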

View File

@ -12,15 +12,15 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use crate::bert::encoder::BertEncoder;
use crate::bert::{Activation, BertConfig};
use crate::common::activations::{_gelu, _mish, _relu};
use crate::common::dropout::Dropout;
use crate::electra::embeddings::ElectraEmbeddings;
use crate::Config;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use crate::bert::{Activation, BertConfig};
use crate::Config;
use crate::electra::embeddings::ElectraEmbeddings;
use tch::{nn, Tensor, Kind};
use crate::bert::encoder::BertEncoder;
use crate::common::activations::{_gelu, _relu, _mish};
use crate::common::dropout::Dropout;
use tch::{nn, Kind, Tensor};
/// # Electra Pretrained model weight files
pub struct ElectraModelResources;
@ -33,23 +33,41 @@ pub struct ElectraVocabResources;
impl ElectraModelResources {
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/electra. Modified with conversion to C-array format.
pub const BASE_GENERATOR: (&'static str, &'static str) = ("electra-base-generator/model.ot", "https://cdn.huggingface.co/google/electra-base-generator/rust_model.ot");
pub const BASE_GENERATOR: (&'static str, &'static str) = (
"electra-base-generator/model.ot",
"https://cdn.huggingface.co/google/electra-base-generator/rust_model.ot",
);
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/electra. Modified with conversion to C-array format.
pub const BASE_DISCRIMINATOR: (&'static str, &'static str) = ("electra-base-discriminator/model.ot", "https://cdn.huggingface.co/google/electra-base-discriminator/rust_model.ot");
pub const BASE_DISCRIMINATOR: (&'static str, &'static str) = (
"electra-base-discriminator/model.ot",
"https://cdn.huggingface.co/google/electra-base-discriminator/rust_model.ot",
);
}
impl ElectraConfigResources {
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/electra. Modified with conversion to C-array format.
pub const BASE_GENERATOR: (&'static str, &'static str) = ("electra-base-generator/config.json", "https://cdn.huggingface.co/google/electra-base-generator/config.json");
pub const BASE_GENERATOR: (&'static str, &'static str) = (
"electra-base-generator/config.json",
"https://cdn.huggingface.co/google/electra-base-generator/config.json",
);
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/electra. Modified with conversion to C-array format.
pub const BASE_DISCRIMINATOR: (&'static str, &'static str) = ("electra-base-discriminator/config.json", "https://cdn.huggingface.co/google/electra-base-discriminator/config.json");
pub const BASE_DISCRIMINATOR: (&'static str, &'static str) = (
"electra-base-discriminator/config.json",
"https://cdn.huggingface.co/google/electra-base-discriminator/config.json",
);
}
impl ElectraVocabResources {
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/electra. Modified with conversion to C-array format.
pub const BASE_GENERATOR: (&'static str, &'static str) = ("electra-base-generator/vocab.txt", "https://cdn.huggingface.co/google/electra-base-generator/vocab.txt");
pub const BASE_GENERATOR: (&'static str, &'static str) = (
"electra-base-generator/vocab.txt",
"https://cdn.huggingface.co/google/electra-base-generator/vocab.txt",
);
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/electra. Modified with conversion to C-array format.
pub const BASE_DISCRIMINATOR: (&'static str, &'static str) = ("electra-base-discriminator/vocab.txt", "https://cdn.huggingface.co/google/electra-base-discriminator/vocab.txt");
pub const BASE_DISCRIMINATOR: (&'static str, &'static str) = (
"electra-base-discriminator/vocab.txt",
"https://cdn.huggingface.co/google/electra-base-discriminator/vocab.txt",
);
}
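Each constant above is a `(cache sub-path, URL)` pair meant to be wrapped in the crate's resource API, the same way other model resources are consumed elsewhere in this changeset. A hedged usage sketch (error handling elided to the `Fallible` return):

use rust_bert::electra::ElectraConfigResources;
use rust_bert::resources::{download_resource, RemoteResource, Resource};

fn fetch_generator_config() -> failure::Fallible<()> {
    let resource = Resource::Remote(RemoteResource::from_pretrained(
        ElectraConfigResources::BASE_GENERATOR,
    ));
    // downloads on first use, then serves the cached copy
    let config_path = download_resource(&resource)?;
    println!("ELECTRA generator config cached at {:?}", config_path);
    Ok(())
}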
#[derive(Debug, Serialize, Deserialize)]
@ -103,10 +121,10 @@ impl ElectraModel {
/// # Example
///
/// ```no_run
/// use rust_bert::electra::{ElectraModel, ElectraConfig};
/// use tch::{nn, Device};
/// use rust_bert::electra::{ElectraConfig, ElectraModel};
/// use rust_bert::Config;
/// use std::path::Path;
/// use tch::{nn, Device};
///
/// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu;
@ -114,11 +132,15 @@ impl ElectraModel {
/// let config = ElectraConfig::from_file(config_path);
/// let electra_model: ElectraModel = ElectraModel::new(&(&p.root() / "electra"), &config);
/// ```
///
pub fn new(p: &nn::Path, config: &ElectraConfig) -> ElectraModel {
let embeddings = ElectraEmbeddings::new(&(p / "embeddings"), config);
let embeddings_project = if config.embedding_size != config.hidden_size {
Some(nn::linear(&(p / "embeddings_project"), config.embedding_size, config.hidden_size, Default::default()))
Some(nn::linear(
&(p / "embeddings_project"),
config.embedding_size,
config.hidden_size,
Default::default(),
))
} else {
None
};
@ -141,7 +163,11 @@ impl ElectraModel {
label2id: config.label2id.clone(),
};
let encoder = BertEncoder::new(&(p / "encoder"), &bert_config);
ElectraModel { embeddings, embeddings_project, encoder }
ElectraModel {
embeddings,
embeddings_project,
encoder,
}
}
/// Forward pass through the model
@ -178,66 +204,84 @@ impl ElectraModel {
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true);
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
/// .expand(&[batch_size, sequence_length], true);
///
/// let (output, all_hidden_states, all_attentions) = no_grad(|| {
/// electra_model
/// .forward_t(Some(input_tensor),
/// .forward_t(
/// Some(input_tensor),
/// Some(mask),
/// Some(token_type_ids),
/// Some(position_ids),
/// None,
/// false).unwrap()
/// false,
/// )
/// .unwrap()
/// });
///
/// ```
///
pub fn forward_t(&self,
pub fn forward_t(
&self,
input_ids: Option<Tensor>,
mask: Option<Tensor>,
token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool)
-> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
train: bool,
) -> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
let (input_shape, device) = match &input_ids {
Some(input_value) => match &input_embeds {
Some(_) => { return Err("Only one of input ids or input embeddings may be set"); }
None => (input_value.size(), input_value.device())
Some(_) => {
return Err("Only one of input ids or input embeddings may be set");
}
None => (input_value.size(), input_value.device()),
},
None => match &input_embeds {
Some(embeds) => (vec!(embeds.size()[0], embeds.size()[1]), embeds.device()),
None => { return Err("At least one of input ids or input embeddings must be set"); }
Some(embeds) => (vec![embeds.size()[0], embeds.size()[1]], embeds.device()),
None => {
return Err("At least one of input ids or input embeddings must be set");
}
},
};
let mask = match mask {
Some(value) => value,
None => Tensor::ones(&input_shape, (Kind::Int64, device))
None => Tensor::ones(&input_shape, (Kind::Int64, device)),
};
let extended_attention_mask = match mask.dim() {
3 => mask.unsqueeze(1),
2 => mask.unsqueeze(1).unsqueeze(1),
_ => { return Err("Invalid attention mask dimension, must be 2 or 3"); }
_ => {
return Err("Invalid attention mask dimension, must be 2 or 3");
}
};
let hidden_states = match self.embeddings.forward_t(input_ids, token_type_ids, position_ids, input_embeds, train) {
let hidden_states = match self.embeddings.forward_t(
input_ids,
token_type_ids,
position_ids,
input_embeds,
train,
) {
Ok(value) => value,
Err(e) => { return Err(e); }
Err(e) => {
return Err(e);
}
};
let hidden_states = match &self.embeddings_project {
Some(layer) => hidden_states.apply(layer),
None => hidden_states
None => hidden_states,
};
let (hidden_state, all_hidden_states, all_attentions) =
self.encoder.forward_t(&hidden_states,
let (hidden_state, all_hidden_states, all_attentions) = self.encoder.forward_t(
&hidden_states,
&Some(extended_attention_mask),
&None,
&None,
train);
train,
);
Ok((hidden_state, all_hidden_states, all_attentions))
}
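The two `unsqueeze` calls above turn a `[batch, seq]` padding mask into `[batch, 1, 1, seq]`, which broadcasts over every head and query position of the `[batch, heads, seq, seq]` attention scores. A standalone shape check:

use tch::{Device, Kind, Tensor};

fn extended_mask_demo() {
    let mask = Tensor::ones(&[2, 5], (Kind::Int64, Device::Cpu)); // [batch, seq]
    let extended = mask.unsqueeze(1).unsqueeze(1); // [batch, 1, 1, seq]
    assert_eq!(extended.size(), vec![2, 1, 1, 5]);
}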
@ -268,9 +312,9 @@ impl ElectraDiscriminatorHead {
///
/// ```no_run
/// use rust_bert::electra::{ElectraConfig, ElectraDiscriminatorHead};
/// use tch::{nn, Device};
/// use rust_bert::Config;
/// use std::path::Path;
/// use tch::{nn, Device};
///
/// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu;
@ -278,16 +322,29 @@ impl ElectraDiscriminatorHead {
/// let config = ElectraConfig::from_file(config_path);
/// let discriminator_head = ElectraDiscriminatorHead::new(&(&p.root() / "electra"), &config);
/// ```
///
pub fn new(p: &nn::Path, config: &ElectraConfig) -> ElectraDiscriminatorHead {
let dense = nn::linear(&(p / "dense"), config.hidden_size, config.hidden_size, Default::default());
let dense_prediction = nn::linear(&(p / "dense_prediction"), config.hidden_size, 1, Default::default());
let dense = nn::linear(
&(p / "dense"),
config.hidden_size,
config.hidden_size,
Default::default(),
);
let dense_prediction = nn::linear(
&(p / "dense_prediction"),
config.hidden_size,
1,
Default::default(),
);
let activation = Box::new(match &config.hidden_act {
Activation::gelu => _gelu,
Activation::relu => _relu,
Activation::mish => _mish
Activation::mish => _mish,
});
ElectraDiscriminatorHead { dense, dense_prediction, activation }
ElectraDiscriminatorHead {
dense,
dense_prediction,
activation,
}
}
/// Forward pass through the discriminator head
@ -316,14 +373,13 @@ impl ElectraDiscriminatorHead {
/// # let config = ElectraConfig::from_file(config_path);
/// # let discriminator_head = ElectraDiscriminatorHead::new(&vs.root(), &config);
/// let (batch_size, sequence_length) = (64, 128);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length, config.hidden_size], (Float, device));
///
/// let output = no_grad(|| {
/// discriminator_head.forward(&input_tensor)
/// });
/// let input_tensor = Tensor::rand(
/// &[batch_size, sequence_length, config.hidden_size],
/// (Float, device),
/// );
///
/// let output = no_grad(|| discriminator_head.forward(&input_tensor));
/// ```
///
pub fn forward(&self, encoder_hidden_states: &Tensor) -> Tensor {
let output = encoder_hidden_states.apply(&self.dense);
let output = (self.activation)(&output);
@ -356,9 +412,9 @@ impl ElectraGeneratorHead {
///
/// ```no_run
/// use rust_bert::electra::{ElectraConfig, ElectraGeneratorHead};
/// use tch::{nn, Device};
/// use rust_bert::Config;
/// use std::path::Path;
/// use tch::{nn, Device};
///
/// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu;
@ -366,13 +422,25 @@ impl ElectraGeneratorHead {
/// let config = ElectraConfig::from_file(config_path);
/// let generator_head = ElectraGeneratorHead::new(&(&p.root() / "electra"), &config);
/// ```
///
pub fn new(p: &nn::Path, config: &ElectraConfig) -> ElectraGeneratorHead {
let layer_norm = nn::layer_norm(p / "LayerNorm", vec![config.embedding_size], Default::default());
let dense = nn::linear(&(p / "dense"), config.hidden_size, config.embedding_size, Default::default());
let layer_norm = nn::layer_norm(
p / "LayerNorm",
vec![config.embedding_size],
Default::default(),
);
let dense = nn::linear(
&(p / "dense"),
config.hidden_size,
config.embedding_size,
Default::default(),
);
let activation = Box::new(_gelu);
ElectraGeneratorHead { layer_norm, dense, activation }
ElectraGeneratorHead {
layer_norm,
dense,
activation,
}
}
/// Forward pass through the generator head
@ -399,14 +467,13 @@ impl ElectraGeneratorHead {
/// # let config = ElectraConfig::from_file(config_path);
/// # let generator_head = ElectraGeneratorHead::new(&vs.root(), &config);
/// let (batch_size, sequence_length) = (64, 128);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length, config.hidden_size], (Float, device));
///
/// let output = no_grad(|| {
/// generator_head.forward(&input_tensor)
/// });
/// let input_tensor = Tensor::rand(
/// &[batch_size, sequence_length, config.hidden_size],
/// (Float, device),
/// );
///
/// let output = no_grad(|| generator_head.forward(&input_tensor));
/// ```
///
pub fn forward(&self, encoder_hidden_states: &Tensor) -> Tensor {
let output = encoder_hidden_states.apply(&self.dense);
let output = (self.activation)(&output);
@ -438,10 +505,10 @@ impl ElectraForMaskedLM {
/// # Example
///
/// ```no_run
/// use rust_bert::electra::{ElectraForMaskedLM, ElectraConfig};
/// use tch::{nn, Device};
/// use rust_bert::electra::{ElectraConfig, ElectraForMaskedLM};
/// use rust_bert::Config;
/// use std::path::Path;
/// use tch::{nn, Device};
///
/// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu;
@ -449,13 +516,21 @@ impl ElectraForMaskedLM {
/// let config = ElectraConfig::from_file(config_path);
/// let electra_model: ElectraForMaskedLM = ElectraForMaskedLM::new(&p.root(), &config);
/// ```
///
pub fn new(p: &nn::Path, config: &ElectraConfig) -> ElectraForMaskedLM {
let electra = ElectraModel::new(&(p / "electra"), config);
let generator_head = ElectraGeneratorHead::new(&(p / "generator_predictions"), config);
let lm_head = nn::linear(&(p / "generator_lm_head"), config.embedding_size, config.vocab_size, Default::default());
let lm_head = nn::linear(
&(p / "generator_lm_head"),
config.embedding_size,
config.vocab_size,
Default::default(),
);
ElectraForMaskedLM { electra, generator_head, lm_head }
ElectraForMaskedLM {
electra,
generator_head,
lm_head,
}
}
/// Forward pass through the model
@ -492,32 +567,39 @@ impl ElectraForMaskedLM {
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true);
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
/// .expand(&[batch_size, sequence_length], true);
///
/// let (output, all_hidden_states, all_attentions) = no_grad(|| {
/// electra_model
/// .forward_t(Some(input_tensor),
/// electra_model.forward_t(
/// Some(input_tensor),
/// Some(mask),
/// Some(token_type_ids),
/// Some(position_ids),
/// None,
/// false)
/// false,
/// )
/// });
///
/// ```
///
pub fn forward_t(&self,
pub fn forward_t(
&self,
input_ids: Option<Tensor>,
mask: Option<Tensor>,
token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool)
-> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
let (hidden_states,
all_hidden_states,
all_attentions) = self.electra
.forward_t(input_ids, mask, token_type_ids, position_ids, input_embeds, train)
train: bool,
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
let (hidden_states, all_hidden_states, all_attentions) = self
.electra
.forward_t(
input_ids,
mask,
token_type_ids,
position_ids,
input_embeds,
train,
)
.unwrap();
let hidden_states = self.generator_head.forward(&hidden_states);
let hidden_states = hidden_states.apply(&self.lm_head);
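Downstream, the `[batch, seq, vocab]` logits produced by `lm_head` above are typically reduced to one predicted token id per masked position. An illustrative helper (the function name and indexing convention are assumptions, not crate API):

use tch::Tensor;

// argmax over the vocabulary axis at one (batch, position) pair
fn predicted_token_id(lm_logits: &Tensor, batch_index: i64, position: i64) -> i64 {
    lm_logits
        .get(batch_index)
        .get(position)
        .argmax(-1, false)
        .int64_value(&[])
}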
@ -547,10 +629,10 @@ impl ElectraDiscriminator {
/// # Example
///
/// ```no_run
/// use rust_bert::electra::{ElectraDiscriminator, ElectraConfig};
/// use tch::{nn, Device};
/// use rust_bert::electra::{ElectraConfig, ElectraDiscriminator};
/// use rust_bert::Config;
/// use std::path::Path;
/// use tch::{nn, Device};
///
/// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu;
@ -558,12 +640,15 @@ impl ElectraDiscriminator {
/// let config = ElectraConfig::from_file(config_path);
/// let electra_model: ElectraDiscriminator = ElectraDiscriminator::new(&p.root(), &config);
/// ```
///
pub fn new(p: &nn::Path, config: &ElectraConfig) -> ElectraDiscriminator {
let electra = ElectraModel::new(&(p / "electra"), config);
let discriminator_head = ElectraDiscriminatorHead::new(&(p / "discriminator_predictions"), config);
let discriminator_head =
ElectraDiscriminatorHead::new(&(p / "discriminator_predictions"), config);
ElectraDiscriminator { electra, discriminator_head }
ElectraDiscriminator {
electra,
discriminator_head,
}
}
/// Forward pass through the model
@ -611,21 +696,26 @@ impl ElectraDiscriminator {
/// None,
/// false)
/// });
///
/// ```
///
pub fn forward_t(&self,
pub fn forward_t(
&self,
input_ids: Option<Tensor>,
mask: Option<Tensor>,
token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool)
-> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
let (hidden_states,
all_hidden_states,
all_attentions) = self.electra
.forward_t(input_ids, mask, token_type_ids, position_ids, input_embeds, train)
train: bool,
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
let (hidden_states, all_hidden_states, all_attentions) = self
.electra
.forward_t(
input_ids,
mask,
token_type_ids,
position_ids,
input_embeds,
train,
)
.unwrap();
let probabilities = self.discriminator_head.forward(&hidden_states).sigmoid();
(probabilities, all_hidden_states, all_attentions)
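Because the discriminator emits an independent replaced-token probability per position, a caller can threshold the output directly. A sketch (the threshold value is the caller's choice, not something this crate fixes):

use tch::Tensor;

// boolean mask, shape [batch, seq], of positions flagged as replaced
fn flagged_positions(probabilities: &Tensor, threshold: f64) -> Tensor {
    probabilities.ge(threshold)
}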
@ -656,24 +746,37 @@ impl ElectraForTokenClassification {
/// # Example
///
/// ```no_run
/// use rust_bert::electra::{ElectraForTokenClassification, ElectraConfig};
/// use tch::{nn, Device};
/// use rust_bert::electra::{ElectraConfig, ElectraForTokenClassification};
/// use rust_bert::Config;
/// use std::path::Path;
/// use tch::{nn, Device};
/// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu;
/// let p = nn::VarStore::new(device);
/// let config = ElectraConfig::from_file(config_path);
/// let electra_model: ElectraForTokenClassification = ElectraForTokenClassification::new(&p.root(), &config);
/// let electra_model: ElectraForTokenClassification =
/// ElectraForTokenClassification::new(&p.root(), &config);
/// ```
///
pub fn new(p: &nn::Path, config: &ElectraConfig) -> ElectraForTokenClassification {
let electra = ElectraModel::new(&(p / "electra"), config);
let dropout = Dropout::new(config.hidden_dropout_prob);
let num_labels = config.id2label.as_ref().expect("id2label must be provided for classifiers").len() as i64;
let classifier = nn::linear(&(p / "classifier"), config.hidden_size, num_labels, Default::default());
let num_labels = config
.id2label
.as_ref()
.expect("id2label must be provided for classifiers")
.len() as i64;
let classifier = nn::linear(
&(p / "classifier"),
config.hidden_size,
num_labels,
Default::default(),
);
ElectraForTokenClassification { electra, dropout, classifier }
ElectraForTokenClassification {
electra,
dropout,
classifier,
}
}
/// Forward pass through the model
@ -721,21 +824,26 @@ impl ElectraForTokenClassification {
/// None,
/// false)
/// });
///
/// ```
///
pub fn forward_t(&self,
pub fn forward_t(
&self,
input_ids: Option<Tensor>,
mask: Option<Tensor>,
token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool)
-> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
let (hidden_states,
all_hidden_states,
all_attentions) = self.electra
.forward_t(input_ids, mask, token_type_ids, position_ids, input_embeds, train)
train: bool,
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
let (hidden_states, all_hidden_states, all_attentions) = self
.electra
.forward_t(
input_ids,
mask,
token_type_ids,
position_ids,
input_embeds,
train,
)
.unwrap();
let output = hidden_states
.apply_t(&self.dropout, train)

View File

@ -12,10 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use tch::{nn, Tensor, Kind};
use crate::common::dropout::Dropout;
use crate::electra::electra::ElectraConfig;
use tch::nn::{EmbeddingConfig, embedding};
use tch::nn::{embedding, EmbeddingConfig};
use tch::{nn, Kind, Tensor};
#[derive(Debug)]
/// # Embeddings implementation for Electra model
@ -34,49 +34,77 @@ impl ElectraEmbeddings {
..Default::default()
};
let word_embeddings: nn::Embedding = embedding(p / "word_embeddings",
let word_embeddings: nn::Embedding = embedding(
p / "word_embeddings",
config.vocab_size,
config.embedding_size,
embedding_config);
embedding_config,
);
let position_embeddings: nn::Embedding = embedding(p / "position_embeddings",
let position_embeddings: nn::Embedding = embedding(
p / "position_embeddings",
config.max_position_embeddings,
config.embedding_size,
Default::default());
Default::default(),
);
let token_type_embeddings: nn::Embedding = embedding(p / "token_type_embeddings",
let token_type_embeddings: nn::Embedding = embedding(
p / "token_type_embeddings",
config.type_vocab_size,
config.embedding_size,
Default::default());
Default::default(),
);
let layer_norm_eps = match config.layer_norm_eps {
Some(value) => value,
None => 1e-12
None => 1e-12,
};
let layer_norm_config = nn::LayerNormConfig { eps: layer_norm_eps, ..Default::default() };
let layer_norm: nn::LayerNorm = nn::layer_norm(p / "LayerNorm", vec![config.embedding_size], layer_norm_config);
let layer_norm_config = nn::LayerNormConfig {
eps: layer_norm_eps,
..Default::default()
};
let layer_norm: nn::LayerNorm = nn::layer_norm(
p / "LayerNorm",
vec![config.embedding_size],
layer_norm_config,
);
let dropout: Dropout = Dropout::new(config.hidden_dropout_prob);
ElectraEmbeddings { word_embeddings, position_embeddings, token_type_embeddings, layer_norm, dropout}
ElectraEmbeddings {
word_embeddings,
position_embeddings,
token_type_embeddings,
layer_norm,
dropout,
}
}
pub fn forward_t(&self,
pub fn forward_t(
&self,
input_ids: Option<Tensor>,
token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool) -> Result<Tensor, &'static str> {
train: bool,
) -> Result<Tensor, &'static str> {
let (input_embeddings, input_shape) = match input_ids {
Some(input_value) => match input_embeds {
Some(_) => { return Err("Only one of input ids or input embeddings may be set"); }
None => (input_value.apply_t(&self.word_embeddings, train), input_value.size())
Some(_) => {
return Err("Only one of input ids or input embeddings may be set");
}
None => (
input_value.apply_t(&self.word_embeddings, train),
input_value.size(),
),
},
None => match input_embeds {
Some(embeds) => {
let size = vec!(embeds.size()[0], embeds.size()[1]);
let size = vec![embeds.size()[0], embeds.size()[1]];
(embeds, size)
},
None => { return Err("Only one of input ids or input embeddings may be set"); }
}
None => {
return Err("Only one of input ids or input embeddings may be set");
}
},
};
let seq_length = input_embeddings.as_ref().size()[1].to_owned();
@ -84,19 +112,22 @@ impl ElectraEmbeddings {
let position_ids = match position_ids {
Some(value) => value,
None => Tensor::arange(seq_length, (Kind::Int64, input_embeddings.device()))
.unsqueeze(0).
expand(&input_shape, true)
.unsqueeze(0)
.expand(&input_shape, true),
};
let token_type_ids = match token_type_ids {
Some(value) => value,
None => Tensor::zeros(&input_shape, (Kind::Int64, input_embeddings.device()))
None => Tensor::zeros(&input_shape, (Kind::Int64, input_embeddings.device())),
};
let position_embeddings = position_ids.apply(&self.position_embeddings);
let token_type_embeddings = token_type_ids.apply(&self.token_type_embeddings);
let input_embeddings: Tensor = input_embeddings + position_embeddings + token_type_embeddings;
Ok(input_embeddings.apply(&self.layer_norm).apply_t(&self.dropout, train))
let input_embeddings: Tensor =
input_embeddings + position_embeddings + token_type_embeddings;
Ok(input_embeddings
.apply(&self.layer_norm)
.apply_t(&self.dropout, train))
}
}
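The default `position_ids` above come from an `arange` over the sequence that is broadcast to the batch. The same construction in isolation:

use tch::{Device, Kind, Tensor};

fn default_position_ids(batch_size: i64, seq_length: i64) -> Tensor {
    Tensor::arange(seq_length, (Kind::Int64, Device::Cpu))
        .unsqueeze(0) // [1, seq]
        .expand(&[batch_size, seq_length], true) // broadcast view: [batch, seq]
}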

View File

@ -28,13 +28,19 @@
//! use rust_tokenizers::BertTokenizer;
//! use tch::{nn, Device};
//! # use std::path::PathBuf;
//! use rust_bert::electra::{ElectraForMaskedLM, ElectraConfig};
//! use rust_bert::electra::{ElectraConfig, ElectraForMaskedLM};
//! use rust_bert::resources::{download_resource, LocalResource, Resource};
//! use rust_bert::Config;
//! use rust_bert::resources::{Resource, download_resource, LocalResource};
//!
//! let config_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/config.json")});
//! let vocab_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/vocab.txt")});
//! let weights_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/model.ot")});
//! let config_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/config.json"),
//! });
//! let vocab_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/vocab.txt"),
//! });
//! let weights_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/model.ot"),
//! });
//! let config_path = download_resource(&config_resource)?;
//! let vocab_path = download_resource(&vocab_resource)?;
//! let weights_path = download_resource(&weights_resource)?;
@ -49,9 +55,11 @@
//! # }
//! ```
mod embeddings;
mod electra;
mod embeddings;
pub use electra::{ElectraModelResources, ElectraVocabResources, ElectraConfigResources, ElectraConfig,
ElectraModel, ElectraDiscriminator, ElectraForMaskedLM, ElectraDiscriminatorHead, ElectraGeneratorHead, ElectraForTokenClassification};
pub use electra::{
ElectraConfig, ElectraConfigResources, ElectraDiscriminator, ElectraDiscriminatorHead,
ElectraForMaskedLM, ElectraForTokenClassification, ElectraGeneratorHead, ElectraModel,
ElectraModelResources, ElectraVocabResources,
};

View File

@ -12,11 +12,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use tch::{Tensor, nn};
use crate::common::dropout::Dropout;
use crate::gpt2::gpt2::Gpt2Config;
use tch::kind::Kind::Float;
use tch::nn::{Init, Module};
use tch::{nn, Tensor};
#[derive(Debug)]
pub struct GPTConv1D {
@ -26,7 +26,14 @@ pub struct GPTConv1D {
impl GPTConv1D {
pub fn new(p: &nn::Path, nf: i64, nx: i64) -> GPTConv1D {
let weight = p.var("weight", &[nx, nf], Init::Randn { mean: 0., stdev: 0.02 });
let weight = p.var(
"weight",
&[nx, nf],
Init::Randn {
mean: 0.,
stdev: 0.02,
},
);
let bias = p.var("bias", &[nf], Init::Const(0.));
GPTConv1D { weight, bias }
}
@ -38,7 +45,6 @@ impl Module for GPTConv1D {
}
}
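`GPTConv1D` keeps the weight in the `[in_features, out_features]` layout of the original GPT-2 TensorFlow checkpoints, so its forward pass is a plain `matmul` plus bias rather than `nn::linear`'s transposed convention. A sketch of that forward under this assumption:

use tch::Tensor;

// x: [batch, seq, nx], weight: [nx, nf], bias: [nf] -> [batch, seq, nf]
fn conv1d_forward(x: &Tensor, weight: &Tensor, bias: &Tensor) -> Tensor {
    x.matmul(weight) + bias
}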
pub struct Attention {
bias: Tensor,
c_attn: GPTConv1D,
@ -62,23 +68,27 @@ impl Attention {
let attn_pdrop = match config.attn_pdrop {
Some(value) => value,
None => 0.1
None => 0.1,
};
let resid_pdrop = match config.resid_pdrop {
Some(value) => value,
None => 0.1
None => 0.1,
};
let output_attentions = match config.output_attentions {
Some(value) => value,
None => false
None => false,
};
let attn_dropout = Dropout::new(attn_pdrop);
let resid_dropout = Dropout::new(resid_pdrop);
assert_eq!(config.n_embd % config.n_head, 0, "Attention hidden states not a multiple of the number of heads");
assert_eq!(
config.n_embd % config.n_head,
0,
"Attention hidden states not a multiple of the number of heads"
);
let dim_per_head = config.n_embd / config.n_head;
Attention {
@ -105,19 +115,31 @@ impl Attention {
}
fn flatten(&self, x: Tensor) -> Tensor {
x.transpose(1, 2).contiguous().view((x.size()[0], -1, &self.n_head * self.dim_per_head))
x.transpose(1, 2)
.contiguous()
.view((x.size()[0], -1, &self.n_head * self.dim_per_head))
}
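`flatten` above is the inverse of `split_heads`: it moves the head axis next to the per-head features and collapses the two into one hidden dimension. The same shape walk as a standalone function:

use tch::Tensor;

// [batch, heads, seq, head_dim] -> [batch, seq, heads * head_dim]
fn merge_heads(x: &Tensor, n_head: i64, dim_per_head: i64) -> Tensor {
    x.transpose(1, 2)
        .contiguous()
        .view((x.size()[0], -1, n_head * dim_per_head))
}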
fn attention(&self, q: &Tensor, k: &Tensor, v: &Tensor, attention_mask: &Option<Tensor>, train: bool)
-> (Tensor, Option<Tensor>) {
fn attention(
&self,
q: &Tensor,
k: &Tensor,
v: &Tensor,
attention_mask: &Option<Tensor>,
train: bool,
) -> (Tensor, Option<Tensor>) {
let mut w = q.matmul(&k);
if self.scale { w = w / (*v.size().last().unwrap() as f64).sqrt(); }
if self.scale {
w = w / (*v.size().last().unwrap() as f64).sqrt();
}
let (nd, ns) = (w.size()[2], w.size()[3]);
let b = self.bias.narrow(2, ns - nd, nd).narrow(3, 0, ns);
let mut w: Tensor = w * &b + 1e4 * (&b - 1);
if let Some(mask) = attention_mask { w = w + mask; }
if let Some(mask) = attention_mask {
w = w + mask;
}
w = w.softmax(-1, Float).apply_t(&self.attn_dropout, train);
let output = w.matmul(&v);
@ -128,15 +150,19 @@ impl Attention {
}
}
pub fn forward_t(&self, x: &Tensor, layer_past: &Option<Tensor>, attention_mask: &Option<Tensor>, train: bool)
-> (Tensor, Tensor, Option<Tensor>) {
pub fn forward_t(
&self,
x: &Tensor,
layer_past: &Option<Tensor>,
attention_mask: &Option<Tensor>,
train: bool,
) -> (Tensor, Tensor, Option<Tensor>) {
let x = x.apply(&self.c_attn).split(self.n_state, 2);
let (query, key, value) =
(
let (query, key, value) = (
self.split_heads(&x[0], false),
self.split_heads(&x[1], true),
self.split_heads(&x[2], false)
self.split_heads(&x[2], false),
);
let (key, value) = match layer_past {
Some(past) => {
@ -144,12 +170,15 @@ impl Attention {
let value = Tensor::cat(&[past.get(1), value], -2);
(key, value)
}
None => (key, value)
None => (key, value),
};
let present = Tensor::stack(&[key.transpose(-2, -1), value.copy()], 0);
let (a, attentions) = self.attention(&query, &key, &value, &attention_mask, train);
let a = self.flatten(a).apply(&self.c_proj).apply_t(&self.resid_dropout, train);
let a = self
.flatten(a)
.apply(&self.c_proj)
.apply_t(&self.resid_dropout, train);
(a, present, attentions)
}
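The expression `w * &b + 1e4 * (&b - 1)` above applies the causal mask without an explicit `-inf`: `b` is a lower-triangular 0/1 matrix, so permitted scores pass through unchanged while future positions are shifted down by roughly `1e4`, which the softmax then maps to near-zero probability. A small runnable demonstration:

use tch::{Device, Kind, Tensor};

fn causal_mask_demo() {
    let b = Tensor::ones(&[4, 4], (Kind::Float, Device::Cpu)).tril(0);
    let scores = Tensor::rand(&[4, 4], (Kind::Float, Device::Cpu));
    let masked = &scores * &b + 1e4 * (&b - 1);
    // rows now put ~0 probability on positions above the diagonal
    masked.softmax(-1, Kind::Float).print();
}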

View File

@ -12,16 +12,16 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use serde::{Deserialize, Serialize};
use tch::{nn, Tensor};
use crate::common::dropout::Dropout;
use tch::nn::embedding;
use crate::common::linear::{linear_no_bias, LinearNoBias};
use crate::gpt2::transformer::Block;
use tch::kind::Kind::Int64;
use std::borrow::BorrowMut;
use crate::common::linear::{LinearNoBias, linear_no_bias};
use crate::pipelines::generation::{Cache, LMHeadModel};
use crate::Config;
use crate::pipelines::generation::{LMHeadModel, Cache};
use serde::{Deserialize, Serialize};
use std::borrow::BorrowMut;
use tch::kind::Kind::Int64;
use tch::nn::embedding;
use tch::{nn, Tensor};
/// # GPT2 Pretrained model weight files
pub struct Gpt2ModelResources;
@ -37,54 +37,114 @@ pub struct Gpt2MergesResources;
impl Gpt2ModelResources {
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
pub const GPT2: (&'static str, &'static str) = ("gpt2/model.ot", "https://cdn.huggingface.co/gpt2-rust_model.ot");
pub const GPT2: (&'static str, &'static str) = (
"gpt2/model.ot",
"https://cdn.huggingface.co/gpt2-rust_model.ot",
);
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
pub const GPT2_MEDIUM: (&'static str, &'static str) = ("gpt2-medium/model.ot", "https://cdn.huggingface.co/gpt2-medium-rust_model.ot");
pub const GPT2_MEDIUM: (&'static str, &'static str) = (
"gpt2-medium/model.ot",
"https://cdn.huggingface.co/gpt2-medium-rust_model.ot",
);
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
pub const GPT2_LARGE: (&'static str, &'static str) = ("gpt2-large/model.ot", "https://cdn.huggingface.co/gpt2-large-rust_model.ot");
pub const GPT2_LARGE: (&'static str, &'static str) = (
"gpt2-large/model.ot",
"https://cdn.huggingface.co/gpt2-large-rust_model.ot",
);
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
pub const GPT2_XL: (&'static str, &'static str) = ("gpt2-xl/model.ot", "https://cdn.huggingface.co/gpt2-xl-rust_model.ot");
pub const GPT2_XL: (&'static str, &'static str) = (
"gpt2-xl/model.ot",
"https://cdn.huggingface.co/gpt2-xl-rust_model.ot",
);
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
pub const DISTIL_GPT2: (&'static str, &'static str) = ("distilgpt2/model.ot", "https://cdn.huggingface.co/distilgpt2-rust_model.ot");
pub const DISTIL_GPT2: (&'static str, &'static str) = (
"distilgpt2/model.ot",
"https://cdn.huggingface.co/distilgpt2-rust_model.ot",
);
}
impl Gpt2ConfigResources {
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
pub const GPT2: (&'static str, &'static str) = ("gpt2/config.json", "https://cdn.huggingface.co/gpt2-config.json");
pub const GPT2: (&'static str, &'static str) = (
"gpt2/config.json",
"https://cdn.huggingface.co/gpt2-config.json",
);
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
pub const GPT2_MEDIUM: (&'static str, &'static str) = ("gpt2-medium/config.json", "https://cdn.huggingface.co/gpt2-medium-config.json");
pub const GPT2_MEDIUM: (&'static str, &'static str) = (
"gpt2-medium/config.json",
"https://cdn.huggingface.co/gpt2-medium-config.json",
);
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
pub const GPT2_LARGE: (&'static str, &'static str) = ("gpt2-large/config.json", "https://cdn.huggingface.co/gpt2-large-config.json");
pub const GPT2_LARGE: (&'static str, &'static str) = (
"gpt2-large/config.json",
"https://cdn.huggingface.co/gpt2-large-config.json",
);
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
pub const GPT2_XL: (&'static str, &'static str) = ("gpt2-xl/config.json", "https://cdn.huggingface.co/gpt2-xl-config.json");
pub const GPT2_XL: (&'static str, &'static str) = (
"gpt2-xl/config.json",
"https://cdn.huggingface.co/gpt2-xl-config.json",
);
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
pub const DISTIL_GPT2: (&'static str, &'static str) = ("distilgpt2/config.json", "https://cdn.huggingface.co/distilgpt2-config.json");
pub const DISTIL_GPT2: (&'static str, &'static str) = (
"distilgpt2/config.json",
"https://cdn.huggingface.co/distilgpt2-config.json",
);
}
impl Gpt2VocabResources {
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
pub const GPT2: (&'static str, &'static str) = ("gpt2/vocab.txt", "https://cdn.huggingface.co/gpt2-vocab.json");
pub const GPT2: (&'static str, &'static str) = (
"gpt2/vocab.txt",
"https://cdn.huggingface.co/gpt2-vocab.json",
);
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
pub const GPT2_MEDIUM: (&'static str, &'static str) = ("gpt2-medium/vocab.txt", "https://cdn.huggingface.co/gpt2-medium-vocab.json");
pub const GPT2_MEDIUM: (&'static str, &'static str) = (
"gpt2-medium/vocab.txt",
"https://cdn.huggingface.co/gpt2-medium-vocab.json",
);
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
pub const GPT2_LARGE: (&'static str, &'static str) = ("gpt2-large/vocab.txt", "https://cdn.huggingface.co/gpt2-large-vocab.json");
pub const GPT2_LARGE: (&'static str, &'static str) = (
"gpt2-large/vocab.txt",
"https://cdn.huggingface.co/gpt2-large-vocab.json",
);
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
pub const GPT2_XL: (&'static str, &'static str) = ("gpt2-xl/vocab.txt", "https://cdn.huggingface.co/gpt2-xl-vocab.json");
pub const GPT2_XL: (&'static str, &'static str) = (
"gpt2-xl/vocab.txt",
"https://cdn.huggingface.co/gpt2-xl-vocab.json",
);
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
pub const DISTIL_GPT2: (&'static str, &'static str) = ("distilgpt2/vocab.txt", "https://cdn.huggingface.co/distilgpt2-vocab.json");
pub const DISTIL_GPT2: (&'static str, &'static str) = (
"distilgpt2/vocab.txt",
"https://cdn.huggingface.co/distilgpt2-vocab.json",
);
}
impl Gpt2MergesResources {
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
pub const GPT2: (&'static str, &'static str) = ("gpt2/merges.txt", "https://cdn.huggingface.co/gpt2-merges.txt");
pub const GPT2: (&'static str, &'static str) = (
"gpt2/merges.txt",
"https://cdn.huggingface.co/gpt2-merges.txt",
);
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
pub const GPT2_MEDIUM: (&'static str, &'static str) = ("gpt2-medium/merges.txt", "https://cdn.huggingface.co/gpt2-medium-merges.txt");
pub const GPT2_MEDIUM: (&'static str, &'static str) = (
"gpt2-medium/merges.txt",
"https://cdn.huggingface.co/gpt2-medium-merges.txt",
);
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
pub const GPT2_LARGE: (&'static str, &'static str) = ("gpt2-large/merges.txt", "https://cdn.huggingface.co/gpt2-large-merges.txt");
pub const GPT2_LARGE: (&'static str, &'static str) = (
"gpt2-large/merges.txt",
"https://cdn.huggingface.co/gpt2-large-merges.txt",
);
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
pub const GPT2_XL: (&'static str, &'static str) = ("gpt2-xl/merges.txt", "https://cdn.huggingface.co/gpt2-xl-merges.txt");
pub const GPT2_XL: (&'static str, &'static str) = (
"gpt2-xl/merges.txt",
"https://cdn.huggingface.co/gpt2-xl-merges.txt",
);
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
pub const DISTIL_GPT2: (&'static str, &'static str) = ("distilgpt2/merges.txt", "https://cdn.huggingface.co/distilgpt2-merges.txt");
pub const DISTIL_GPT2: (&'static str, &'static str) = (
"distilgpt2/merges.txt",
"https://cdn.huggingface.co/distilgpt2-merges.txt",
);
}
#[allow(non_camel_case_types)]
@ -156,10 +216,10 @@ impl Gpt2Model {
/// # Example
///
/// ```no_run
/// use tch::{nn, Device};
/// use rust_bert::gpt2::{Gpt2Config, Gpt2Model};
/// use rust_bert::Config;
/// use std::path::Path;
/// use rust_bert::gpt2::{Gpt2Config, Gpt2Model};
/// use tch::{nn, Device};
///
/// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu;
@ -167,37 +227,58 @@ impl Gpt2Model {
/// let config = Gpt2Config::from_file(config_path);
/// let gpt2: Gpt2Model = Gpt2Model::new(&(&p.root() / "gpt2"), &config);
/// ```
///
pub fn new(p: &nn::Path, config: &Gpt2Config) -> Gpt2Model {
let p = &(p / "transformer");
let wte = embedding(&(p / "wte"), config.vocab_size, config.n_embd, Default::default());
let wpe = embedding(&(p / "wpe"), config.n_positions, config.n_embd, Default::default());
let wte = embedding(
&(p / "wte"),
config.vocab_size,
config.n_embd,
Default::default(),
);
let wpe = embedding(
&(p / "wpe"),
config.n_positions,
config.n_embd,
Default::default(),
);
let embd_pdrop = match config.embd_pdrop {
Some(value) => value,
None => 0.1
None => 0.1,
};
let drop = Dropout::new(embd_pdrop);
let layer_norm_config = nn::LayerNormConfig { eps: config.layer_norm_epsilon, ..Default::default() };
let layer_norm_config = nn::LayerNormConfig {
eps: config.layer_norm_epsilon,
..Default::default()
};
let ln_f = nn::layer_norm(p / "ln_f", vec![config.n_embd], layer_norm_config);
let mut h: Vec<Block> = vec!();
let mut h: Vec<Block> = vec![];
let h_path = &(p / "h");
for layer_index in 0..config.n_layer {
h.push(Block::new(&(h_path / layer_index), config, true));
};
}
let output_attentions = match config.output_attentions {
Some(value) => value,
None => false
None => false,
};
let output_past = match config.output_past {
Some(value) => value,
None => true
None => true,
};
let output_hidden_states = match config.output_hidden_states {
Some(value) => value,
None => false
None => false,
};
Gpt2Model { wte, wpe, drop, ln_f, h, output_past, output_hidden_states, output_attentions }
Gpt2Model {
wte,
wpe,
drop,
ln_f,
h,
output_past,
output_hidden_states,
output_attentions,
}
}
/// Forward pass through the model
@ -226,7 +307,7 @@ impl Gpt2Model {
/// # use rust_bert::Config;
/// # use std::path::Path;
/// # use tch::kind::Kind::{Int64, Double};
/// use rust_bert::gpt2::{Gpt2Model, Gpt2Config};
/// use rust_bert::gpt2::{Gpt2Config, Gpt2Model};
/// # let config_path = Path::new("path/to/config.json");
/// # let vocab_path = Path::new("path/to/vocab.txt");
/// # let device = Device::Cpu;
@ -237,48 +318,86 @@ impl Gpt2Model {
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let mut past: Vec<Tensor> = Vec::with_capacity(config.n_layer as usize);
/// for _ in 0..config.n_layer as usize {
/// past.push(Tensor::rand(&[2, batch_size, config.n_head, past_sequence_length, config.n_embd / config.n_head], (Double, device)))
/// past.push(Tensor::rand(
/// &[
/// 2,
/// batch_size,
/// config.n_head,
/// past_sequence_length,
/// config.n_embd / config.n_head,
/// ],
/// (Double, device),
/// ))
/// }
/// let attention_mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let token_type_ids = Tensor::ones(&[batch_size, sequence_length], (Int64, device));
/// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true);
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
/// .expand(&[batch_size, sequence_length], true);
///
/// let (output, past, hidden_states, attentions) = no_grad(|| {
/// gpt2_model
/// .forward_t(&Some(input_tensor),
/// .forward_t(
/// &Some(input_tensor),
/// &Some(past),
/// &Some(attention_mask),
/// &Some(token_type_ids),
/// &Some(position_ids),
/// &None,
/// false).unwrap()
/// false,
/// )
/// .unwrap()
/// });
///
/// ```
///
pub fn forward_t(&self,
pub fn forward_t(
&self,
input_ids: &Option<Tensor>,
layer_past: &Option<Vec<Tensor>>,
attention_mask: &Option<Tensor>,
token_type_ids: &Option<Tensor>,
position_ids: &Option<Tensor>,
input_embeds: &Option<Tensor>,
train: bool) -> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
train: bool,
) -> Result<
(
Tensor,
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
),
&'static str,
> {
let (input_embeddings, seq_length) = match input_ids {
Some(input_value) => match input_embeds {
Some(_) => { return Err("Only one of input ids or input embeddings may be set"); }
None => (input_value.apply(&self.wte), *input_value.size().last().unwrap())
Some(_) => {
return Err("Only one of input ids or input embeddings may be set");
}
None => (
input_value.apply(&self.wte),
*input_value.size().last().unwrap(),
),
},
None => match input_embeds {
Some(embeds) => (embeds.copy(), embeds.size()[1]),
None => { return Err("At least one of input ids or input embeddings must be set"); }
None => {
return Err("At least one of input ids or input embeddings must be set");
}
},
};
let (layer_past, layer_past_length) = match layer_past {
Some(value) => {
assert_eq!(value.len(), self.h.len(), "Past activations vector must be of length equal to the number of layers");
(value.iter().map(|v| Some(v.copy())).collect::<Vec<Option<Tensor>>>(), value[0].size()[3])
assert_eq!(
value.len(),
self.h.len(),
"Past activations vector must be of length equal to the number of layers"
);
(
value
.iter()
.map(|v| Some(v.copy()))
.collect::<Vec<Option<Tensor>>>(),
value[0].size()[3],
)
}
None => {
let mut out = Vec::with_capacity(self.h.len());
@ -289,31 +408,45 @@ impl Gpt2Model {
let position_ids = match position_ids {
Some(value) => value.copy(),
None => Tensor::arange1(layer_past_length, seq_length + layer_past_length, (Int64, input_embeddings.device())).unsqueeze(0)
None => Tensor::arange1(
layer_past_length,
seq_length + layer_past_length,
(Int64, input_embeddings.device()),
)
.unsqueeze(0),
};
let attention_mask: Option<Tensor> = match attention_mask {
Some(value) => {
Some(
Some(value) => Some(
(value
.view((input_embeddings.size()[0], -1))
.unsqueeze(1)
.unsqueeze(2)
- 1.0
) * 10000.0)
}
None => None
- 1.0)
* 10000.0,
),
None => None,
};
let position_embeds = position_ids.apply(&self.wpe);
let token_type_embeds = match token_type_ids {
Some(value) => value.apply(&self.wte),
None => Tensor::zeros_like(&position_embeds)
None => Tensor::zeros_like(&position_embeds),
};
let mut hidden_state: Tensor =
(input_embeddings + position_embeds + token_type_embeds).apply_t(&self.drop, train);
let mut all_presents: Option<Vec<Tensor>> =
if self.output_past { Some(vec![]) } else { None };
let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states {
Some(vec![])
} else {
None
};
let mut all_attentions: Option<Vec<Tensor>> = if self.output_attentions {
Some(vec![])
} else {
None
};
let mut hidden_state: Tensor = (input_embeddings + position_embeds + token_type_embeds).apply_t(&self.drop, train);
let mut all_presents: Option<Vec<Tensor>> = if self.output_past { Some(vec!()) } else { None };
let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states { Some(vec!()) } else { None };
let mut all_attentions: Option<Vec<Tensor>> = if self.output_attentions { Some(vec!()) } else { None };
let mut layer_iter = self.h.iter().zip(layer_past);
loop {
@ -333,11 +466,16 @@ impl Gpt2Model {
attentions.push(temp.2.as_ref().unwrap().copy());
};
}
None => break
};
None => break,
};
}
Ok((hidden_state.apply(&self.ln_f), all_presents, all_hidden_states, all_attentions))
Ok((
hidden_state.apply(&self.ln_f),
all_presents,
all_hidden_states,
all_attentions,
))
}
}
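When a cache is supplied, the position ids above start at the cached length rather than zero, so freshly fed tokens receive the correct positional embeddings. In isolation:

use tch::{Device, Kind, Tensor};

// e.g. past_len = 10, new_len = 1 -> [[10]]
fn positions_for_new_tokens(past_len: i64, new_len: i64) -> Tensor {
    Tensor::arange1(past_len, past_len + new_len, (Kind::Int64, Device::Cpu)).unsqueeze(0)
}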
@ -362,10 +500,10 @@ impl GPT2LMHeadModel {
/// # Example
///
/// ```no_run
/// use tch::{nn, Device};
/// use rust_bert::gpt2::{GPT2LMHeadModel, Gpt2Config};
/// use rust_bert::Config;
/// use std::path::Path;
/// use rust_bert::gpt2::{Gpt2Config, GPT2LMHeadModel};
/// use tch::{nn, Device};
///
/// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu;
@ -373,11 +511,18 @@ impl GPT2LMHeadModel {
/// let config = Gpt2Config::from_file(config_path);
/// let gpt2: GPT2LMHeadModel = GPT2LMHeadModel::new(&(&p.root() / "gpt2"), &config);
/// ```
///
pub fn new(p: &nn::Path, config: &Gpt2Config) -> GPT2LMHeadModel {
let transformer = Gpt2Model::new(&p, config);
let lm_head = linear_no_bias(&(p / "lm_head"), config.n_embd, config.vocab_size, Default::default());
GPT2LMHeadModel { transformer, lm_head }
let lm_head = linear_no_bias(
&(p / "lm_head"),
config.n_embd,
config.vocab_size,
Default::default(),
);
GPT2LMHeadModel {
transformer,
lm_head,
}
}
}
@ -412,8 +557,8 @@ impl LMHeadModel for GPT2LMHeadModel {
/// # use rust_bert::Config;
/// # use std::path::Path;
/// # use tch::kind::Kind::{Int64, Double};
/// use rust_bert::gpt2::{Gpt2Config, GPT2LMHeadModel};
/// use rust_bert::pipelines::generation::{LMHeadModel, Cache};
/// use rust_bert::gpt2::{GPT2LMHeadModel, Gpt2Config};
/// use rust_bert::pipelines::generation::{Cache, LMHeadModel};
/// # let config_path = Path::new("path/to/config.json");
/// # let vocab_path = Path::new("path/to/vocab.txt");
/// # let device = Device::Cpu;
@ -424,15 +569,26 @@ impl LMHeadModel for GPT2LMHeadModel {
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let mut past: Vec<Tensor> = Vec::with_capacity(config.n_layer as usize);
/// for _ in 0..config.n_layer as usize {
/// past.push(Tensor::rand(&[2, batch_size, config.n_head, past_sequence_length, config.n_embd / config.n_head], (Double, device)))
/// past.push(Tensor::rand(
/// &[
/// 2,
/// batch_size,
/// config.n_head,
/// past_sequence_length,
/// config.n_embd / config.n_head,
/// ],
/// (Double, device),
/// ))
/// }
/// let attention_mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let token_type_ids = Tensor::ones(&[batch_size, sequence_length], (Int64, device));
/// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true);
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
/// .expand(&[batch_size, sequence_length], true);
///
/// let (output, _, past, hidden_states, attentions) = no_grad(|| {
/// gpt2_model
/// .forward_t(&Some(input_tensor),
/// .forward_t(
/// &Some(input_tensor),
/// Cache::GPT2Cache(Some(past)),
/// &Some(attention_mask),
/// &Some(token_type_ids),
@ -440,12 +596,13 @@ impl LMHeadModel for GPT2LMHeadModel {
/// &None,
/// None,
/// &None,
/// false).unwrap()
/// false,
/// )
/// .unwrap()
/// });
///
/// ```
///
fn forward_t(&self,
fn forward_t(
&self,
input_ids: &Option<Tensor>,
layer_past: Cache,
attention_mask: &Option<Tensor>,
@ -454,29 +611,46 @@ impl LMHeadModel for GPT2LMHeadModel {
input_embeds: &Option<Tensor>,
_encoder_outputs: Option<&Tensor>,
_decoder_input_ids: &Option<Tensor>,
train: bool) -> Result<(Tensor, Option<Tensor>, Cache, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
let (output,
past,
all_hidden_states,
all_attentions) = match layer_past {
Cache::GPT2Cache(layer_past) => Ok(self.transformer.forward_t(input_ids,
train: bool,
) -> Result<
(
Tensor,
Option<Tensor>,
Cache,
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
),
&'static str,
> {
let (output, past, all_hidden_states, all_attentions) = match layer_past {
Cache::GPT2Cache(layer_past) => Ok(self.transformer.forward_t(
input_ids,
&layer_past,
attention_mask,
token_type_ids,
position_ids,
input_embeds,
train)?),
Cache::None => Ok(self.transformer.forward_t(input_ids,
train,
)?),
Cache::None => Ok(self.transformer.forward_t(
input_ids,
&None,
attention_mask,
token_type_ids,
position_ids,
input_embeds,
train)?),
_ => Err("Cache not compatible with GPT2 model")
train,
)?),
_ => Err("Cache not compatible with GPT2 model"),
}?;
let lm_logits = output.apply(&self.lm_head);
Ok((lm_logits, None, Cache::GPT2Cache(past), all_hidden_states, all_attentions))
Ok((
lm_logits,
None,
Cache::GPT2Cache(past),
all_hidden_states,
all_attentions,
))
}
}
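The `Cache` threaded through `forward_t` above exists for incremental decoding: after a first full pass, each generation step feeds only the newest token and reuses the cached keys and values. A hedged sketch of one greedy step (the helper name, error conversion, and argmax decoding are illustrative assumptions, not crate API):

use rust_bert::gpt2::GPT2LMHeadModel;
use rust_bert::pipelines::generation::{Cache, LMHeadModel};
use tch::Tensor;

// one greedy step: feed the latest token (shape [1, 1]) plus the cache,
// return the argmax next-token id and the updated cache
fn greedy_step(
    model: &GPT2LMHeadModel,
    last_token: Tensor,
    cache: Cache,
) -> failure::Fallible<(i64, Cache)> {
    let (lm_logits, _, new_cache, _, _) = model
        .forward_t(
            &Some(last_token),
            cache,
            &None,
            &None,
            &None,
            &None,
            None,
            &None,
            false,
        )
        .map_err(failure::err_msg)?;
    let next_token = lm_logits.get(0).get(-1).argmax(-1, false).int64_value(&[]);
    Ok((next_token, new_cache))
}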

View File

@ -19,14 +19,22 @@
//! use rust_tokenizers::Gpt2Tokenizer;
//! use tch::{nn, Device};
//! # use std::path::PathBuf;
//! use rust_bert::gpt2::{GPT2LMHeadModel, Gpt2Config};
//! use rust_bert::resources::{download_resource, LocalResource, Resource};
//! use rust_bert::Config;
//! use rust_bert::gpt2::{Gpt2Config, GPT2LMHeadModel};
//! use rust_bert::resources::{Resource, download_resource, LocalResource};
//!
//! let config_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/config.json")});
//! let vocab_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/vocab.txt")});
//! let merges_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/vocab.txt")});
//! let weights_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/model.ot")});
//! let config_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/config.json"),
//! });
//! let vocab_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/vocab.txt"),
//! });
//! let merges_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/vocab.txt"),
//! });
//! let weights_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/model.ot"),
//! });
//! let config_path = download_resource(&config_resource)?;
//! let vocab_path = download_resource(&vocab_resource)?;
//! let merges_path = download_resource(&merges_resource)?;
@ -34,7 +42,11 @@
//!
//! let device = Device::cuda_if_available();
//! let mut vs = nn::VarStore::new(device);
//! let tokenizer: Gpt2Tokenizer = Gpt2Tokenizer::from_file(vocab_path.to_str().unwrap(), merges_path.to_str().unwrap(), true);
//! let tokenizer: Gpt2Tokenizer = Gpt2Tokenizer::from_file(
//! vocab_path.to_str().unwrap(),
//! merges_path.to_str().unwrap(),
//! true,
//! );
//! let config = Gpt2Config::from_file(config_path);
//! let gpt2_model = GPT2LMHeadModel::new(&vs.root(), &config);
//! vs.load(weights_path)?;
@ -43,9 +55,11 @@
//! # }
//! ```
mod gpt2;
pub(crate) mod attention;
mod gpt2;
pub(crate) mod transformer;
pub use gpt2::{Gpt2ModelResources, Gpt2ConfigResources, Gpt2VocabResources, Gpt2MergesResources,
Gpt2Config, Gpt2Model, GptActivation, GPT2LMHeadModel};
pub use gpt2::{
GPT2LMHeadModel, Gpt2Config, Gpt2ConfigResources, Gpt2MergesResources, Gpt2Model,
Gpt2ModelResources, Gpt2VocabResources, GptActivation,
};
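
After the set-up shown in the module example, a single forward pass looks roughly as follows. This is a sketch reusing the `tokenizer`, `gpt2_model` and `device` variables from above; passing `Cache::None` makes it a plain, non-incremental pass:

use rust_bert::pipelines::generation::{Cache, LMHeadModel};
use rust_tokenizers::{Tokenizer, TruncationStrategy};
use tch::Tensor;

// Tokenize a prompt and shape it into a [batch, sequence] Int64 tensor.
let encoded = tokenizer.encode("The dog", None, 128, &TruncationStrategy::LongestFirst, 0);
let input_tensor = Tensor::of_slice(&encoded.token_ids).unsqueeze(0).to(device);

// Logits come back as [batch, sequence, vocab]; the returned cache could seed
// a subsequent incremental decoding step.
let (lm_logits, _, _cache, _, _) = gpt2_model
    .forward_t(&Some(input_tensor), Cache::None, &None, &None, &None, &None, None, &None, false)
    .unwrap();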

View File

@ -12,11 +12,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use crate::gpt2::attention::{GPTConv1D, Attention};
use tch::{Tensor, nn};
use crate::common::dropout::Dropout;
use crate::gpt2::gpt2::{Gpt2Config, GptActivation};
use crate::common::activations::{_gelu_new, _relu, _swish};
use crate::common::dropout::Dropout;
use crate::gpt2::attention::{Attention, GPTConv1D};
use crate::gpt2::gpt2::{Gpt2Config, GptActivation};
use tch::{nn, Tensor};
pub struct MLP {
c_fc: GPTConv1D,
@ -35,14 +35,19 @@ impl MLP {
GptActivation::relu => _relu,
GptActivation::swish => _swish,
},
None => _gelu_new
None => _gelu_new,
});
let resid_pdrop = match config.resid_pdrop {
Some(value) => value,
None => 0.1
None => 0.1,
};
let dropout = Dropout::new(resid_pdrop);
MLP { c_fc, c_proj, activation, dropout }
MLP {
c_fc,
c_proj,
activation,
dropout,
}
}
pub fn forward_t(&self, x: &Tensor, train: bool) -> Tensor {
@ -60,18 +65,33 @@ pub struct Block {
impl Block {
pub fn new(p: &nn::Path, config: &Gpt2Config, scale: bool) -> Block {
let layer_norm_config = nn::LayerNormConfig { eps: config.layer_norm_epsilon, ..Default::default() };
let layer_norm_config = nn::LayerNormConfig {
eps: config.layer_norm_epsilon,
..Default::default()
};
let ln_1 = nn::layer_norm(p / "ln_1", vec![config.n_embd], layer_norm_config);
let ln_2 = nn::layer_norm(p / "ln_2", vec![config.n_embd], layer_norm_config);
let attn = Attention::new(&(p / "attn"), config, scale);
let mlp = MLP::new(&(p / "mlp"), config);
Block { ln_1, attn, ln_2, mlp }
Block {
ln_1,
attn,
ln_2,
mlp,
}
}
pub fn forward_t(&self, x: &Tensor, layer_past: &Option<Tensor>, attention_mask: &Option<Tensor>, train: bool)
-> (Tensor, Tensor, Option<Tensor>) {
let (output, present, attentions) = self.attn.forward_t(&x.apply(&self.ln_1), layer_past, attention_mask, train);
pub fn forward_t(
&self,
x: &Tensor,
layer_past: &Option<Tensor>,
attention_mask: &Option<Tensor>,
train: bool,
) -> (Tensor, Tensor, Option<Tensor>) {
let (output, present, attentions) =
self.attn
.forward_t(&x.apply(&self.ln_1), layer_past, attention_mask, train);
let x = x + output;
let m = self.mlp.forward_t(&x.apply(&self.ln_2), train);
let x = x + m;
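
Note the ordering: this block normalizes *before* each sub-module and adds the residual afterwards (pre-LayerNorm, as in GPT-2), while the original GPT block later in this commit applies `ln_1`/`ln_2` *after* the residual add. A toy sketch of the two wirings, with scalar stand-ins for tensors and hypothetical `norm`/`f` closures standing in for the LayerNorm and attention/MLP sub-modules:

// Toy illustration only; the real blocks operate on tch Tensors.
fn pre_ln_step(x: f32, norm: impl Fn(f32) -> f32, f: impl Fn(f32) -> f32) -> f32 {
    x + f(norm(x)) // GPT-2: normalize, transform, then add the residual
}

fn post_ln_step(x: f32, norm: impl Fn(f32) -> f32, f: impl Fn(f32) -> f32) -> f32 {
    norm(x + f(x)) // original GPT: add the residual first, then normalize
}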

View File

@ -16,14 +16,14 @@
//!
//! More information on these can be found in the [`pipelines` module](./pipelines/index.html)
//! ```no_run
//! use rust_bert::pipelines::question_answering::{QuestionAnsweringModel, QaInput};
//! use rust_bert::pipelines::question_answering::{QaInput, QuestionAnsweringModel};
//!
//! # fn main() -> failure::Fallible<()> {
//! let qa_model = QuestionAnsweringModel::new(Default::default())?;
//!
//! let question = String::from("Where does Amy live ?");
//! let context = String::from("Amy lives in Amsterdam");
//! let answers = qa_model.predict(&vec!(QaInput { question, context }), 1, 32);
//! let answers = qa_model.predict(&vec![QaInput { question, context }], 1, 32);
//! # Ok(())
//! # }
//! ```
@ -55,19 +55,18 @@
//! - Set up a virtual environment and install dependencies
//! - Run the conversion script python /utils/download-dependencies_{MODEL_TO_DOWNLOAD}.py. The dependencies will be downloaded to the user's home directory, under ~/rustbert/{}
//! 3. Run the example cargo run --release
//!
pub mod distilbert;
pub mod bert;
pub mod roberta;
pub mod openai_gpt;
pub mod gpt2;
pub mod bart;
pub mod electra;
pub mod marian;
pub mod albert;
pub mod bart;
pub mod bert;
mod common;
pub mod distilbert;
pub mod electra;
pub mod gpt2;
pub mod marian;
pub mod openai_gpt;
pub mod pipelines;
pub mod roberta;
pub use common::Config;
pub use common::resources;
pub use common::Config;
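
The re-exported `Config` trait is what gives every model family its `from_file` constructor. A minimal sketch (the path is a placeholder):

use rust_bert::gpt2::Gpt2Config;
use rust_bert::Config; // the trait provides `from_file` for every model configuration
use std::path::Path;

let config = Gpt2Config::from_file(Path::new("path/to/config.json"));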

View File

@ -11,10 +11,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use crate::bart::{BartModel, BartConfig, LayerState};
use tch::{Tensor, nn};
use crate::pipelines::generation::{LMHeadModel, Cache};
use crate::bart::{BartConfig, BartModel, LayerState};
use crate::pipelines::generation::{Cache, LMHeadModel};
use tch::nn::Init;
use tch::{nn, Tensor};
/// # Marian Pretrained model weight files
pub struct MarianModelResources;
@ -33,78 +33,174 @@ pub struct MarianPrefix;
impl MarianModelResources {
/// Shared under Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format.
pub const ENGLISH2ROMANCE: (&'static str, &'static str) = ("marian-mt-en-ROMANCE/model.ot", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-ROMANCE/rust_model.ot");
pub const ENGLISH2ROMANCE: (&'static str, &'static str) = (
"marian-mt-en-ROMANCE/model.ot",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-ROMANCE/rust_model.ot",
);
/// Shared under Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format.
pub const ROMANCE2ENGLISH: (&'static str, &'static str) = ("marian-mt-ROMANCE-en/model.ot", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-ROMANCE-en/rust_model.ot");
pub const ROMANCE2ENGLISH: (&'static str, &'static str) = (
"marian-mt-ROMANCE-en/model.ot",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-ROMANCE-en/rust_model.ot",
);
/// Shared under Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format.
pub const ENGLISH2GERMAN: (&'static str, &'static str) = ("marian-mt-en-de/model.ot", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-de/rust_model.ot");
pub const ENGLISH2GERMAN: (&'static str, &'static str) = (
"marian-mt-en-de/model.ot",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-de/rust_model.ot",
);
/// Shared under Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format.
pub const GERMAN2ENGLISH: (&'static str, &'static str) = ("marian-mt-de-en/model.ot", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-de-en/rust_model.ot");
pub const GERMAN2ENGLISH: (&'static str, &'static str) = (
"marian-mt-de-en/model.ot",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-de-en/rust_model.ot",
);
/// Shared under Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format.
pub const ENGLISH2RUSSIAN: (&'static str, &'static str) = ("marian-mt-en-ru/model.ot", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-ru/rust_model.ot");
pub const ENGLISH2RUSSIAN: (&'static str, &'static str) = (
"marian-mt-en-ru/model.ot",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-ru/rust_model.ot",
);
/// Shared under Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format.
pub const RUSSIAN2ENGLISH: (&'static str, &'static str) = ("marian-mt-ru-en/model.ot", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-ru-en/rust_model.ot");
pub const RUSSIAN2ENGLISH: (&'static str, &'static str) = (
"marian-mt-ru-en/model.ot",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-ru-en/rust_model.ot",
);
/// Shared under Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format.
pub const FRENCH2GERMAN: (&'static str, &'static str) = ("marian-mt-fr-de/model.ot", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-fr-de/rust_model.ot");
pub const FRENCH2GERMAN: (&'static str, &'static str) = (
"marian-mt-fr-de/model.ot",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-fr-de/rust_model.ot",
);
/// Shared under Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format.
pub const GERMAN2FRENCH: (&'static str, &'static str) = ("marian-mt-de-fr/model.ot", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-de-fr/rust_model.ot");
pub const GERMAN2FRENCH: (&'static str, &'static str) = (
"marian-mt-de-fr/model.ot",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-de-fr/rust_model.ot",
);
}
impl MarianConfigResources {
/// Shared under Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const ENGLISH2ROMANCE: (&'static str, &'static str) = ("marian-mt-en-ROMANCE/config.json", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-ROMANCE/config.json");
pub const ENGLISH2ROMANCE: (&'static str, &'static str) = (
"marian-mt-en-ROMANCE/config.json",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-ROMANCE/config.json",
);
/// Shared under Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const ROMANCE2ENGLISH: (&'static str, &'static str) = ("marian-mt-ROMANCE-en/config.json", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-ROMANCE-en/config.json");
pub const ROMANCE2ENGLISH: (&'static str, &'static str) = (
"marian-mt-ROMANCE-en/config.json",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-ROMANCE-en/config.json",
);
/// Shared under Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const ENGLISH2GERMAN: (&'static str, &'static str) = ("marian-mt-en-de/config.json", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-de/config.json");
pub const ENGLISH2GERMAN: (&'static str, &'static str) = (
"marian-mt-en-de/config.json",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-de/config.json",
);
/// Shared under Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const GERMAN2ENGLISH: (&'static str, &'static str) = ("marian-mt-de-en/config.json", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-de-en/config.json");
pub const GERMAN2ENGLISH: (&'static str, &'static str) = (
"marian-mt-de-en/config.json",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-de-en/config.json",
);
/// Shared under Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const ENGLISH2RUSSIAN: (&'static str, &'static str) = ("marian-mt-en-ru/config.json", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-ru/config.json");
pub const ENGLISH2RUSSIAN: (&'static str, &'static str) = (
"marian-mt-en-ru/config.json",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-ru/config.json",
);
/// Shared under Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const RUSSIAN2ENGLISH: (&'static str, &'static str) = ("marian-mt-ru-en/config.json", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-ru-en/config.json");
pub const RUSSIAN2ENGLISH: (&'static str, &'static str) = (
"marian-mt-ru-en/config.json",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-ru-en/config.json",
);
/// Shared under Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const FRENCH2GERMAN: (&'static str, &'static str) = ("marian-mt-fr-de/config.json", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-fr-de/config.json");
pub const FRENCH2GERMAN: (&'static str, &'static str) = (
"marian-mt-fr-de/config.json",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-fr-de/config.json",
);
/// Shared under Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const GERMAN2FRENCH: (&'static str, &'static str) = ("marian-mt-de-fr/config.json", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-de-fr/config.json");
pub const GERMAN2FRENCH: (&'static str, &'static str) = (
"marian-mt-de-fr/config.json",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-de-fr/config.json",
);
}
impl MarianVocabResources {
/// Shared under Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const ENGLISH2ROMANCE: (&'static str, &'static str) = ("marian-mt-en-ROMANCE/vocab.json", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-ROMANCE/vocab.json");
pub const ENGLISH2ROMANCE: (&'static str, &'static str) = (
"marian-mt-en-ROMANCE/vocab.json",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-ROMANCE/vocab.json",
);
/// Shared under Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const ROMANCE2ENGLISH: (&'static str, &'static str) = ("marian-mt-ROMANCE-en/vocab.json", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-ROMANCE-en/vocab.json");
pub const ROMANCE2ENGLISH: (&'static str, &'static str) = (
"marian-mt-ROMANCE-en/vocab.json",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-ROMANCE-en/vocab.json",
);
/// Shared under Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const ENGLISH2GERMAN: (&'static str, &'static str) = ("marian-mt-en-de/vocab.json", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-de/vocab.json");
pub const ENGLISH2GERMAN: (&'static str, &'static str) = (
"marian-mt-en-de/vocab.json",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-de/vocab.json",
);
/// Shared under Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const GERMAN2ENGLISH: (&'static str, &'static str) = ("marian-mt-de-en/vocab.json", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-de-en/vocab.json");
pub const GERMAN2ENGLISH: (&'static str, &'static str) = (
"marian-mt-de-en/vocab.json",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-de-en/vocab.json",
);
/// Shared under Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const ENGLISH2RUSSIAN: (&'static str, &'static str) = ("marian-mt-en-ru/vocab.json", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-ru/vocab.json");
pub const ENGLISH2RUSSIAN: (&'static str, &'static str) = (
"marian-mt-en-ru/vocab.json",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-ru/vocab.json",
);
/// Shared under Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const RUSSIAN2ENGLISH: (&'static str, &'static str) = ("marian-mt-ru-en/vocab.json", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-ru-en/vocab.json");
pub const RUSSIAN2ENGLISH: (&'static str, &'static str) = (
"marian-mt-ru-en/vocab.json",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-ru-en/vocab.json",
);
/// Shared under Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const FRENCH2GERMAN: (&'static str, &'static str) = ("marian-mt-fr-de/vocab.json", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-fr-de/vocab.json");
pub const FRENCH2GERMAN: (&'static str, &'static str) = (
"marian-mt-fr-de/vocab.json",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-fr-de/vocab.json",
);
/// Shared under Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const GERMAN2FRENCH: (&'static str, &'static str) = ("marian-mt-de-fr/vocab.json", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-de-fr/vocab.json");
pub const GERMAN2FRENCH: (&'static str, &'static str) = (
"marian-mt-de-fr/vocab.json",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-de-fr/vocab.json",
);
}
impl MarianSpmResources {
/// Shared under Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const ENGLISH2ROMANCE: (&'static str, &'static str) = ("marian-mt-en-ROMANCE/spiece.model", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-ROMANCE/source.spm");
pub const ENGLISH2ROMANCE: (&'static str, &'static str) = (
"marian-mt-en-ROMANCE/spiece.model",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-ROMANCE/source.spm",
);
/// Shared under Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const ROMANCE2ENGLISH: (&'static str, &'static str) = ("marian-mt-ROMANCE-en/spiece.model", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-ROMANCE-en/source.spm");
pub const ROMANCE2ENGLISH: (&'static str, &'static str) = (
"marian-mt-ROMANCE-en/spiece.model",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-ROMANCE-en/source.spm",
);
/// Shared under Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const ENGLISH2GERMAN: (&'static str, &'static str) = ("marian-mt-en-de/spiece.model", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-de/source.spm");
pub const ENGLISH2GERMAN: (&'static str, &'static str) = (
"marian-mt-en-de/spiece.model",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-de/source.spm",
);
/// Shared under Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const GERMAN2ENGLISH: (&'static str, &'static str) = ("marian-mt-de-en/spiece.model", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-de-en/source.spm");
pub const GERMAN2ENGLISH: (&'static str, &'static str) = (
"marian-mt-de-en/spiece.model",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-de-en/source.spm",
);
/// Shared under Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const ENGLISH2RUSSIAN: (&'static str, &'static str) = ("marian-mt-en-ru/spiece.model", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-ru/source.spm");
pub const ENGLISH2RUSSIAN: (&'static str, &'static str) = (
"marian-mt-en-ru/spiece.model",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-ru/source.spm",
);
/// Shared under Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const RUSSIAN2ENGLISH: (&'static str, &'static str) = ("marian-mt-ru-en/spiece.model", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-ru-en/source.spm");
pub const RUSSIAN2ENGLISH: (&'static str, &'static str) = (
"marian-mt-ru-en/spiece.model",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-ru-en/source.spm",
);
/// Shared under Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const FRENCH2GERMAN: (&'static str, &'static str) = ("marian-mt-fr-de/spiece.model", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-fr-de/source.spm");
pub const FRENCH2GERMAN: (&'static str, &'static str) = (
"marian-mt-fr-de/spiece.model",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-fr-de/source.spm",
);
/// Shared under Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const GERMAN2FRENCH: (&'static str, &'static str) = ("marian-mt-de-fr/spiece.model", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-de-fr/source.spm");
pub const GERMAN2FRENCH: (&'static str, &'static str) = (
"marian-mt-de-fr/spiece.model",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-de-fr/source.spm",
);
}
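
Each constant above is a `(cache_subdirectory, remote_url)` pair. A minimal sketch of resolving one into a local file path, inside a `failure::Fallible` context, mirroring the pattern used for the other model families in this crate:

use rust_bert::marian::MarianModelResources;
use rust_bert::resources::{download_resource, RemoteResource, Resource};

// Downloads the English-to-German weights on first use, then reuses the cached copy.
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
    MarianModelResources::ENGLISH2GERMAN,
));
let weights_path = download_resource(&weights_resource)?;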
impl MarianPrefix {
@ -150,23 +246,34 @@ impl MarianForConditionalGeneration {
/// # Example
///
/// ```no_run
/// use tch::{nn, Device};
/// use rust_bert::bart::{BartConfig, BartForConditionalGeneration};
/// use rust_bert::Config;
/// use std::path::Path;
/// use rust_bert::bart::BartConfig;
/// use rust_bert::marian::MarianForConditionalGeneration;
/// use tch::{nn, Device};
///
/// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu;
/// let p = nn::VarStore::new(device);
/// let config = BartConfig::from_file(config_path);
/// let generation_mode = true;
/// let bart: BartForConditionalGeneration = BartForConditionalGeneration::new(&(&p.root() / "bart"), &config, generation_mode);
/// let marian: MarianForConditionalGeneration =
/// MarianForConditionalGeneration::new(&(&p.root() / "marian"), &config, generation_mode);
/// ```
///
pub fn new(p: &nn::Path, config: &BartConfig, generation_mode: bool) -> MarianForConditionalGeneration {
pub fn new(
p: &nn::Path,
config: &BartConfig,
generation_mode: bool,
) -> MarianForConditionalGeneration {
let base_model = BartModel::new(&(p / "model"), config, generation_mode);
let final_logits_bias = p.var("final_logits_bias", &[1, config.vocab_size], Init::Const(0.));
MarianForConditionalGeneration { base_model, final_logits_bias }
let final_logits_bias = p.var(
"final_logits_bias",
&[1, config.vocab_size],
Init::Const(0.),
);
MarianForConditionalGeneration {
base_model,
final_logits_bias,
}
}
/// Forward pass through the model
@ -197,7 +304,7 @@ impl MarianForConditionalGeneration {
/// # use rust_bert::Config;
/// # use std::path::Path;
/// # use tch::kind::Kind::{Int64, Double};
/// use rust_bert::bart::{BartConfig};
/// use rust_bert::bart::BartConfig;
/// use rust_bert::marian::MarianForConditionalGeneration;
/// # let config_path = Path::new("path/to/config.json");
/// # let vocab_path = Path::new("path/to/vocab.txt");
@ -208,49 +315,86 @@ impl MarianForConditionalGeneration {
/// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56);
/// let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device));
/// let target_tensor = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device));
/// let encoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
/// let decoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
/// let encoder_attention_mask =
/// Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
/// let decoder_attention_mask =
/// Tensor::ones(&[batch_size, target_sequence_length], (Int64, device));
///
/// let (decoder_output, encoder_hidden_states, cache,
/// all_encoder_hidden_states, all_encoder_attentions,
/// all_decoder_hidden_states, all_decoder_attentions) = no_grad(|| {
/// marian_model
/// .forward_t(Some(&input_tensor),
/// let (
/// decoder_output,
/// encoder_hidden_states,
/// cache,
/// all_encoder_hidden_states,
/// all_encoder_attentions,
/// all_decoder_hidden_states,
/// all_decoder_attentions,
/// ) = no_grad(|| {
/// marian_model.forward_t(
/// Some(&input_tensor),
/// Some(&encoder_attention_mask),
/// None,
/// Some(&target_tensor),
/// Some(&decoder_attention_mask),
/// None,
/// false)
/// false,
/// )
/// });
///
/// ```
///
pub fn forward_t(&self,
pub fn forward_t(
&self,
input_ids: Option<&Tensor>,
attention_mask: Option<&Tensor>,
encoder_outputs: Option<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>)>,
decoder_input_ids: Option<&Tensor>,
decoder_attention_mask: Option<&Tensor>,
old_layer_states: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
train: bool)
-> (Tensor, Tensor, Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
Option<Vec<Tensor>>, Option<Vec<Tensor>>,
Option<Vec<Tensor>>, Option<Vec<Tensor>>)
{
let (decoder_outputs, encoder_hidden_states, decoder_cache,
all_decoder_hidden_states, all_decoder_attentions,
all_encoder_hidden_states, all_encoder_attentions) =
self.base_model.forward_t(input_ids, attention_mask, decoder_input_ids, encoder_outputs, decoder_attention_mask, old_layer_states, train);
train: bool,
) -> (
Tensor,
Tensor,
Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
) {
let (
decoder_outputs,
encoder_hidden_states,
decoder_cache,
all_decoder_hidden_states,
all_decoder_attentions,
all_encoder_hidden_states,
all_encoder_attentions,
) = self.base_model.forward_t(
input_ids,
attention_mask,
decoder_input_ids,
encoder_outputs,
decoder_attention_mask,
old_layer_states,
train,
);
let lm_logits = decoder_outputs.linear::<Tensor>(&self.base_model.embeddings.ws, None);
(lm_logits, encoder_hidden_states, decoder_cache,
all_decoder_hidden_states, all_decoder_attentions,
all_encoder_hidden_states, all_encoder_attentions)
(
lm_logits,
encoder_hidden_states,
decoder_cache,
all_decoder_hidden_states,
all_decoder_attentions,
all_encoder_hidden_states,
all_encoder_attentions,
)
}
pub fn encode(&self, input_ids: &Tensor, attention_mask: Option<&Tensor>) -> Tensor {
let (encoder_hidden_states, _, _) = self.base_model.encoder.forward_t(input_ids, attention_mask, &self.base_model.embeddings, false);
let (encoder_hidden_states, _, _) = self.base_model.encoder.forward_t(
input_ids,
attention_mask,
&self.base_model.embeddings,
false,
);
encoder_hidden_states
}
}
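
`encode` exists so that generation can run the encoder exactly once: the resulting hidden states are passed back in through `encoder_outputs` at every decoding step, and the decoder output is projected through the shared embedding matrix (`embeddings.ws`) plus `final_logits_bias`. A sketch of that split, assuming the `marian_model`, `input_tensor`, `target_tensor` and `encoder_attention_mask` variables from the example above:

// Run the encoder once over the source sentence...
let encoder_output = marian_model.encode(&input_tensor, Some(&encoder_attention_mask));

// ...then score target tokens against the precomputed encoder states.
let (lm_logits, _encoder_states, _cache, _, _, _, _) = marian_model.forward_t(
    None,                               // source ids no longer needed
    Some(&encoder_attention_mask),
    Some((encoder_output, None, None)), // reuse the cached encoder states
    Some(&target_tensor),
    None,
    None,
    false,
);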
@ -287,7 +431,7 @@ impl LMHeadModel for MarianForConditionalGeneration {
/// # use rust_bert::Config;
/// # use std::path::Path;
/// # use tch::kind::Kind::{Int64, Double};
/// use rust_bert::bart::{BartConfig};
/// use rust_bert::bart::BartConfig;
/// use rust_bert::marian::MarianForConditionalGeneration;
/// # let config_path = Path::new("path/to/config.json");
/// # let vocab_path = Path::new("path/to/vocab.txt");
@ -298,25 +442,33 @@ impl LMHeadModel for MarianForConditionalGeneration {
/// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56);
/// let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device));
/// let target_tensor = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device));
/// let encoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
/// let decoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
/// let encoder_attention_mask =
/// Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
/// let decoder_attention_mask =
/// Tensor::ones(&[batch_size, target_sequence_length], (Int64, device));
///
/// let (decoder_output, encoder_hidden_states, cache,
/// all_encoder_hidden_states, all_encoder_attentions,
/// all_decoder_hidden_states, all_decoder_attentions) = no_grad(|| {
/// marian_model
/// .forward_t(Some(&input_tensor),
/// let (
/// decoder_output,
/// encoder_hidden_states,
/// cache,
/// all_encoder_hidden_states,
/// all_encoder_attentions,
/// all_decoder_hidden_states,
/// all_decoder_attentions,
/// ) = no_grad(|| {
/// marian_model.forward_t(
/// Some(&input_tensor),
/// Some(&encoder_attention_mask),
/// None,
/// Some(&target_tensor),
/// Some(&decoder_attention_mask),
/// None,
/// false)
/// false,
/// )
/// });
///
/// ```
///
fn forward_t(&self,
fn forward_t(
&self,
input_ids: &Option<Tensor>,
cache: Cache,
attention_mask: &Option<Tensor>,
@ -325,26 +477,47 @@ impl LMHeadModel for MarianForConditionalGeneration {
_input_embeds: &Option<Tensor>,
encoder_outputs: Option<&Tensor>,
decoder_input_ids: &Option<Tensor>,
train: bool) -> Result<(Tensor, Option<Tensor>, Cache, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
train: bool,
) -> Result<
(
Tensor,
Option<Tensor>,
Cache,
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
),
&'static str,
> {
let (decoder_output, encoder_hidden_states, new_cache, _, _, _, _) = match cache {
Cache::BARTCache(cached_layer_states) => self.base_model.forward_t(input_ids.as_ref(),
Cache::BARTCache(cached_layer_states) => self.base_model.forward_t(
input_ids.as_ref(),
attention_mask.as_ref(),
decoder_input_ids.as_ref(),
Some((encoder_outputs.as_ref().unwrap().copy(), None, None)),
None,
cached_layer_states,
train),
Cache::None => self.base_model.forward_t(input_ids.as_ref(),
train,
),
Cache::None => self.base_model.forward_t(
input_ids.as_ref(),
attention_mask.as_ref(),
decoder_input_ids.as_ref(),
Some((encoder_outputs.as_ref().unwrap().copy(), None, None)),
None,
None,
train),
_ => Err("Cache not compatible with Marian Model")?
train,
),
_ => Err("Cache not compatible with Marian Model")?,
};
let lm_logits = decoder_output.linear::<Tensor>(&self.base_model.embeddings.ws, None) + &self.final_logits_bias;
Ok((lm_logits, Some(encoder_hidden_states), Cache::BARTCache(new_cache), None, None))
let lm_logits = decoder_output.linear::<Tensor>(&self.base_model.embeddings.ws, None)
+ &self.final_logits_bias;
Ok((
lm_logits,
Some(encoder_hidden_states),
Cache::BARTCache(new_cache),
None,
None,
))
}
}

View File

@ -19,16 +19,24 @@
//! #
//! use tch::{nn, Device};
//! # use std::path::PathBuf;
//! use rust_bert::Config;
//! use rust_bert::bart::{BartConfig, BartModel};
//! use rust_bert::resources::{Resource, download_resource, LocalResource};
//! use rust_tokenizers::preprocessing::tokenizer::marian_tokenizer::MarianTokenizer;
//! use rust_bert::marian::MarianForConditionalGeneration;
//! use rust_bert::resources::{download_resource, LocalResource, Resource};
//! use rust_bert::Config;
//! use rust_tokenizers::preprocessing::tokenizer::marian_tokenizer::MarianTokenizer;
//!
//! let config_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/config.json")});
//! let vocab_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/vocab.json")});
//! let sentence_piece_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/spiece.model")});
//! let weights_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/model.ot")});
//! let config_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/config.json"),
//! });
//! let vocab_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/vocab.json"),
//! });
//! let sentence_piece_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/spiece.model"),
//! });
//! let weights_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/model.ot"),
//! });
//! let config_path = download_resource(&config_resource)?;
//! let vocab_path = download_resource(&vocab_resource)?;
//! let spiece_path = download_resource(&sentence_piece_resource)?;
@ -36,7 +44,11 @@
//!
//! let device = Device::cuda_if_available();
//! let mut vs = nn::VarStore::new(device);
//! let tokenizer = MarianTokenizer::from_files(vocab_path.to_str().unwrap(), spiece_path.to_str().unwrap(), true);
//! let tokenizer = MarianTokenizer::from_files(
//! vocab_path.to_str().unwrap(),
//! spiece_path.to_str().unwrap(),
//! true,
//! );
//! let config = BartConfig::from_file(config_path);
//! let marian_model = MarianForConditionalGeneration::new(&vs.root(), &config, false);
//! vs.load(weights_path)?;
@ -47,4 +59,7 @@
mod marian;
pub use marian::{MarianForConditionalGeneration, MarianModelResources, MarianConfigResources, MarianVocabResources, MarianSpmResources, MarianPrefix};
pub use marian::{
MarianConfigResources, MarianForConditionalGeneration, MarianModelResources, MarianPrefix,
MarianSpmResources, MarianVocabResources,
};

View File

@ -18,15 +18,23 @@
//! use rust_tokenizers::OpenAiGptTokenizer;
//! use tch::{nn, Device};
//! # use std::path::PathBuf;
//! use rust_bert::Config;
//! use rust_bert::gpt2::Gpt2Config;
//! use rust_bert::openai_gpt::OpenAiGptModel;
//! use rust_bert::resources::{Resource, download_resource, LocalResource};
//! use rust_bert::resources::{download_resource, LocalResource, Resource};
//! use rust_bert::Config;
//!
//! let config_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/config.json")});
//! let vocab_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/vocab.txt")});
//! let merges_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/vocab.txt")});
//! let weights_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/model.ot")});
//! let config_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/config.json"),
//! });
//! let vocab_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/vocab.txt"),
//! });
//! let merges_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/vocab.txt"),
//! });
//! let weights_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/model.ot"),
//! });
//! let config_path = download_resource(&config_resource)?;
//! let vocab_path = download_resource(&vocab_resource)?;
//! let merges_path = download_resource(&merges_resource)?;
@ -34,7 +42,11 @@
//!
//! let device = Device::cuda_if_available();
//! let mut vs = nn::VarStore::new(device);
//! let tokenizer: OpenAiGptTokenizer = OpenAiGptTokenizer::from_file(vocab_path.to_str().unwrap(), merges_path.to_str().unwrap(), true);
//! let tokenizer: OpenAiGptTokenizer = OpenAiGptTokenizer::from_file(
//! vocab_path.to_str().unwrap(),
//! merges_path.to_str().unwrap(),
//! true,
//! );
//! let config = Gpt2Config::from_file(config_path);
//! let gpt_model = OpenAiGptModel::new(&vs.root(), &config);
//! vs.load(weights_path)?;
@ -46,5 +58,7 @@
mod openai_gpt;
mod transformer;
pub use openai_gpt::{OpenAiGptModelResources, OpenAiGptConfigResources, OpenAiGptVocabResources, OpenAiGptMergesResources,
OpenAiGptModel, OpenAIGPTLMHeadModel};
pub use openai_gpt::{
OpenAIGPTLMHeadModel, OpenAiGptConfigResources, OpenAiGptMergesResources, OpenAiGptModel,
OpenAiGptModelResources, OpenAiGptVocabResources,
};
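
Unlike GPT-2, the original GPT keeps no key/value cache, which is why its `LMHeadModel` implementation ignores `_layer_past` and always hands back `Cache::None`. A sketch of a single scoring pass, assuming a `gpt_lm_model: OpenAIGPTLMHeadModel` built the same way as the `OpenAiGptModel` in the example above:

use rust_bert::pipelines::generation::{Cache, LMHeadModel};
use tch::{no_grad, Tensor};

// Hypothetical token ids for a short prompt, shaped [batch = 1, sequence = 3].
let input_tensor = Tensor::of_slice(&[40i64, 1842, 318]).unsqueeze(0);
let (lm_logits, _, cache, _, _) = no_grad(|| {
    gpt_lm_model
        .forward_t(&Some(input_tensor), Cache::None, &None, &None, &None, &None, None, &None, false)
        .unwrap()
});
// `cache` is always Cache::None here: each call re-attends over the full sequence.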

View File

@ -12,15 +12,15 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use tch::{nn, Tensor};
use crate::common::dropout::Dropout;
use tch::nn::embedding;
use tch::kind::Kind::Int64;
use std::borrow::BorrowMut;
use crate::common::linear::{LinearNoBias, linear_no_bias};
use crate::openai_gpt::transformer::Block;
use crate::common::linear::{linear_no_bias, LinearNoBias};
use crate::gpt2::Gpt2Config;
use crate::pipelines::generation::{LMHeadModel, Cache};
use crate::openai_gpt::transformer::Block;
use crate::pipelines::generation::{Cache, LMHeadModel};
use std::borrow::BorrowMut;
use tch::kind::Kind::Int64;
use tch::nn::embedding;
use tch::{nn, Tensor};
/// # GPT Pretrained model weight files
pub struct OpenAiGptModelResources;
@ -36,22 +36,34 @@ pub struct OpenAiGptMergesResources;
impl OpenAiGptModelResources {
/// Shared under MIT license by the OpenAI team at https://github.com/openai/finetune-transformer-lm. Modified with conversion to C-array format.
pub const GPT: (&'static str, &'static str) = ("openai-gpt/model.ot", "https://cdn.huggingface.co/openai-gpt-rust_model.ot");
pub const GPT: (&'static str, &'static str) = (
"openai-gpt/model.ot",
"https://cdn.huggingface.co/openai-gpt-rust_model.ot",
);
}
impl OpenAiGptConfigResources {
/// Shared under MIT license by the OpenAI team at https://github.com/openai/finetune-transformer-lm. Modified with conversion to C-array format.
pub const GPT: (&'static str, &'static str) = ("openai-gpt/config.json", "https://cdn.huggingface.co/openai-gpt-config.json");
pub const GPT: (&'static str, &'static str) = (
"openai-gpt/config.json",
"https://cdn.huggingface.co/openai-gpt-config.json",
);
}
impl OpenAiGptVocabResources {
/// Shared under MIT license by the OpenAI team at https://github.com/openai/finetune-transformer-lm. Modified with conversion to C-array format.
pub const GPT: (&'static str, &'static str) = ("openai-gpt/vocab.txt", "https://cdn.huggingface.co/openai-gpt-vocab.json");
pub const GPT: (&'static str, &'static str) = (
"openai-gpt/vocab.txt",
"https://cdn.huggingface.co/openai-gpt-vocab.json",
);
}
impl OpenAiGptMergesResources {
/// Shared under MIT license by the OpenAI team at https://github.com/openai/finetune-transformer-lm. Modified with conversion to C-array format.
pub const GPT: (&'static str, &'static str) = ("openai-gpt/merges.txt", "https://cdn.huggingface.co/openai-gpt-merges.txt");
pub const GPT: (&'static str, &'static str) = (
"openai-gpt/merges.txt",
"https://cdn.huggingface.co/openai-gpt-merges.txt",
);
}
/// # GPT Base model
@ -82,11 +94,11 @@ impl OpenAiGptModel {
/// # Example
///
/// ```no_run
/// use tch::{nn, Device};
/// use rust_bert::Config;
/// use std::path::Path;
/// use rust_bert::gpt2::Gpt2Config;
/// use rust_bert::openai_gpt::OpenAiGptModel;
/// use rust_bert::Config;
/// use std::path::Path;
/// use tch::{nn, Device};
///
/// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu;
@ -94,30 +106,46 @@ impl OpenAiGptModel {
/// let config = Gpt2Config::from_file(config_path);
/// let gpt2: OpenAiGptModel = OpenAiGptModel::new(&(&p.root() / "gpt"), &config);
/// ```
///
pub fn new(p: &nn::Path, config: &Gpt2Config) -> OpenAiGptModel {
let tokens_embed = embedding(&(p / "tokens_embed"), config.vocab_size, config.n_embd, Default::default());
let positions_embed = embedding(&(p / "positions_embed"), config.n_positions, config.n_embd, Default::default());
let tokens_embed = embedding(
&(p / "tokens_embed"),
config.vocab_size,
config.n_embd,
Default::default(),
);
let positions_embed = embedding(
&(p / "positions_embed"),
config.n_positions,
config.n_embd,
Default::default(),
);
let embd_pdrop = match config.embd_pdrop {
Some(value) => value,
None => 0.1
None => 0.1,
};
let drop = Dropout::new(embd_pdrop);
let mut h: Vec<Block> = vec!();
let mut h: Vec<Block> = vec![];
let h_path = &(p / "h");
for layer_index in 0..config.n_layer {
h.push(Block::new(&(h_path / layer_index), config, true));
};
}
let output_attentions = match config.output_attentions {
Some(value) => value,
None => false
None => false,
};
let output_hidden_states = match config.output_hidden_states {
Some(value) => value,
None => false
None => false,
};
OpenAiGptModel { tokens_embed, positions_embed, drop, h, output_hidden_states, output_attentions }
OpenAiGptModel {
tokens_embed,
positions_embed,
drop,
h,
output_hidden_states,
output_attentions,
}
}
/// Forward pass through the model
@ -156,64 +184,83 @@ impl OpenAiGptModel {
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let attention_mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let token_type_ids = Tensor::ones(&[batch_size, sequence_length], (Int64, device));
/// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true);
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
/// .expand(&[batch_size, sequence_length], true);
///
/// let (output, hidden_states, attentions) = no_grad(|| {
/// gpt_model
/// .forward_t(&Some(input_tensor),
/// .forward_t(
/// &Some(input_tensor),
/// &Some(attention_mask),
/// &Some(token_type_ids),
/// &Some(position_ids),
/// &None,
/// false).unwrap()
/// false,
/// )
/// .unwrap()
/// });
///
/// ```
///
pub fn forward_t(&self,
pub fn forward_t(
&self,
input_ids: &Option<Tensor>,
attention_mask: &Option<Tensor>,
token_type_ids: &Option<Tensor>,
position_ids: &Option<Tensor>,
input_embeds: &Option<Tensor>,
train: bool) -> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
train: bool,
) -> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
let (input_embeddings, seq_length) = match input_ids {
Some(input_value) => match input_embeds {
Some(_) => { return Err("Only one of input ids or input embeddings may be set"); }
None => (input_value.apply(&self.tokens_embed), *input_value.size().last().unwrap())
Some(_) => {
return Err("Only one of input ids or input embeddings may be set");
}
None => (
input_value.apply(&self.tokens_embed),
*input_value.size().last().unwrap(),
),
},
None => match input_embeds {
Some(embeds) => (embeds.copy(), embeds.size()[1]),
None => { return Err("At least one of input ids or input embeddings must be set"); }
None => {
return Err("At least one of input ids or input embeddings must be set");
}
},
};
let position_ids = match position_ids {
Some(value) => value.copy(),
None => Tensor::arange(seq_length, (Int64, input_embeddings.device())).unsqueeze(0)
None => Tensor::arange(seq_length, (Int64, input_embeddings.device())).unsqueeze(0),
};
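// Turn the 0/1 padding mask into an additive bias: kept positions map to 0.0 and
// masked positions to -10000.0, which effectively removes them after the softmax.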
let attention_mask: Option<Tensor> = match attention_mask {
Some(value) => {
Some(
Some(value) => Some(
(value
.view((input_embeddings.size()[0], -1))
.unsqueeze(1)
.unsqueeze(2)
- 1.0
) * 10000.0)
}
None => None
- 1.0)
* 10000.0,
),
None => None,
};
let position_embeds = position_ids.apply(&self.positions_embed);
let token_type_embeds = match token_type_ids {
Some(value) => value.apply(&self.tokens_embed),
None => Tensor::zeros_like(&position_embeds)
None => Tensor::zeros_like(&position_embeds),
};
let mut hidden_state: Tensor =
(input_embeddings + position_embeds + token_type_embeds).apply_t(&self.drop, train);
let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states {
Some(vec![])
} else {
None
};
let mut all_attentions: Option<Vec<Tensor>> = if self.output_attentions {
Some(vec![])
} else {
None
};
let mut hidden_state: Tensor = (input_embeddings + position_embeds + token_type_embeds).apply_t(&self.drop, train);
let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states { Some(vec!()) } else { None };
let mut all_attentions: Option<Vec<Tensor>> = if self.output_attentions { Some(vec!()) } else { None };
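// Walk through the transformer blocks, optionally collecting each layer's hidden
// state and attention weights when the corresponding output flags are set.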
let mut layers = self.h.iter();
loop {
@ -229,9 +276,9 @@ impl OpenAiGptModel {
attentions.push(temp.1.as_ref().unwrap().copy());
};
}
None => break
};
None => break,
};
}
Ok((hidden_state, all_hidden_states, all_attentions))
}
@ -258,11 +305,11 @@ impl OpenAIGPTLMHeadModel {
/// # Example
///
/// ```no_run
/// use tch::{nn, Device};
/// use rust_bert::Config;
/// use std::path::Path;
/// use rust_bert::gpt2::Gpt2Config;
/// use rust_bert::openai_gpt::OpenAIGPTLMHeadModel;
/// use rust_bert::Config;
/// use std::path::Path;
/// use tch::{nn, Device};
///
/// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu;
@ -270,11 +317,18 @@ impl OpenAIGPTLMHeadModel {
/// let config = Gpt2Config::from_file(config_path);
/// let gpt2: OpenAIGPTLMHeadModel = OpenAIGPTLMHeadModel::new(&(&p.root() / "gpt"), &config);
/// ```
///
pub fn new(p: &nn::Path, config: &Gpt2Config) -> OpenAIGPTLMHeadModel {
let transformer = OpenAiGptModel::new(&p, config);
let lm_head = linear_no_bias(&(p / "lm_head"), config.n_embd, config.vocab_size, Default::default());
OpenAIGPTLMHeadModel { transformer, lm_head }
let lm_head = linear_no_bias(
&(p / "lm_head"),
config.n_embd,
config.vocab_size,
Default::default(),
);
OpenAIGPTLMHeadModel {
transformer,
lm_head,
}
}
}
@ -336,10 +390,9 @@ impl LMHeadModel for OpenAIGPTLMHeadModel {
/// &None,
/// false).unwrap()
/// });
///
/// ```
///
fn forward_t(&self,
fn forward_t(
&self,
input_ids: &Option<Tensor>,
_layer_past: Cache,
attention_mask: &Option<Tensor>,
@ -348,17 +401,33 @@ impl LMHeadModel for OpenAIGPTLMHeadModel {
input_embeds: &Option<Tensor>,
_encoder_outputs: Option<&Tensor>,
_decoder_input_ids: &Option<Tensor>,
train: bool) -> Result<(Tensor, Option<Tensor>, Cache, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
let (output,
all_hidden_states,
all_attentions) = self.transformer.forward_t(input_ids,
train: bool,
) -> Result<
(
Tensor,
Option<Tensor>,
Cache,
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
),
&'static str,
> {
let (output, all_hidden_states, all_attentions) = self.transformer.forward_t(
input_ids,
attention_mask,
token_type_ids,
position_ids,
input_embeds,
train)?;
train,
)?;
let lm_logits = output.apply(&self.lm_head);
Ok((lm_logits, None, Cache::None, all_hidden_states, all_attentions))
Ok((
lm_logits,
None,
Cache::None,
all_hidden_states,
all_attentions,
))
}
}

View File

@ -13,9 +13,9 @@
// limitations under the License.
use crate::gpt2::attention::Attention;
use tch::{Tensor, nn};
use crate::gpt2::transformer::MLP;
use crate::gpt2::Gpt2Config;
use tch::{nn, Tensor};
pub struct Block {
ln_1: nn::LayerNorm,
@ -26,17 +26,29 @@ pub struct Block {
impl Block {
pub fn new(p: &nn::Path, config: &Gpt2Config, scale: bool) -> Block {
let layer_norm_config = nn::LayerNormConfig { eps: config.layer_norm_epsilon, ..Default::default() };
let layer_norm_config = nn::LayerNormConfig {
eps: config.layer_norm_epsilon,
..Default::default()
};
let ln_1 = nn::layer_norm(p / "ln_1", vec![config.n_embd], layer_norm_config);
let ln_2 = nn::layer_norm(p / "ln_2", vec![config.n_embd], layer_norm_config);
let attn = Attention::new(&(p / "attn"), config, scale);
let mlp = MLP::new(&(p / "mlp"), config);
Block { ln_1, attn, ln_2, mlp }
Block {
ln_1,
attn,
ln_2,
mlp,
}
}
pub fn forward_t(&self, x: &Tensor, attention_mask: &Option<Tensor>, train: bool)
-> (Tensor, Option<Tensor>) {
pub fn forward_t(
&self,
x: &Tensor,
attention_mask: &Option<Tensor>,
train: bool,
) -> (Tensor, Option<Tensor>) {
let (output, _, attentions) = self.attn.forward_t(x, &None, attention_mask, train);
let x = (x + output).apply(&self.ln_1);
let m = self.mlp.forward_t(&x, train);

View File

@ -19,13 +19,13 @@
//!
use crate::bert::BertConfig;
use crate::distilbert::DistilBertConfig;
use rust_tokenizers::{BertTokenizer, RobertaTokenizer, TokenizedInput, TruncationStrategy};
use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::Tokenizer;
use std::path::Path;
use crate::Config;
use std::collections::HashMap;
use serde::{Serialize, Deserialize};
use crate::electra::ElectraConfig;
use crate::Config;
use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::Tokenizer;
use rust_tokenizers::{BertTokenizer, RobertaTokenizer, TokenizedInput, TruncationStrategy};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::Path;
#[derive(Clone, Copy, Serialize, Deserialize)]
/// # Identifies the type of model
@ -60,25 +60,42 @@ impl ConfigOption {
match model_type {
ModelType::Bert | ModelType::Roberta => ConfigOption::Bert(BertConfig::from_file(path)),
ModelType::DistilBert => ConfigOption::DistilBert(DistilBertConfig::from_file(path)),
ModelType::Electra => ConfigOption::Electra(ElectraConfig::from_file(path))
ModelType::Electra => ConfigOption::Electra(ElectraConfig::from_file(path)),
}
}
pub fn get_label_mapping(self) -> HashMap<i64, String> {
match self {
Self::Bert(config) => config.id2label.expect("No label dictionary (id2label) provided in configuration file"),
Self::DistilBert(config) => config.id2label.expect("No label dictionary (id2label) provided in configuration file"),
Self::Electra(config) => config.id2label.expect("No label dictionary (id2label) provided in configuration file"),
Self::Bert(config) => config
.id2label
.expect("No label dictionary (id2label) provided in configuration file"),
Self::DistilBert(config) => config
.id2label
.expect("No label dictionary (id2label) provided in configuration file"),
Self::Electra(config) => config
.id2label
.expect("No label dictionary (id2label) provided in configuration file"),
}
}
}
impl TokenizerOption {
/// Interface method to load a tokenizer from file
pub fn from_file(model_type: ModelType, vocab_path: &str, merges_path: Option<&str>, lower_case: bool) -> Self {
pub fn from_file(
model_type: ModelType,
vocab_path: &str,
merges_path: Option<&str>,
lower_case: bool,
) -> Self {
match model_type {
ModelType::Bert | ModelType::DistilBert | ModelType::Electra => TokenizerOption::Bert(BertTokenizer::from_file(vocab_path, lower_case)),
ModelType::Roberta => TokenizerOption::Roberta(RobertaTokenizer::from_file(vocab_path, merges_path.expect("No merges specified!"), lower_case)),
ModelType::Bert | ModelType::DistilBert | ModelType::Electra => {
TokenizerOption::Bert(BertTokenizer::from_file(vocab_path, lower_case))
}
ModelType::Roberta => TokenizerOption::Roberta(RobertaTokenizer::from_file(
vocab_path,
merges_path.expect("No merges specified!"),
lower_case,
)),
}
}
@ -86,15 +103,25 @@ impl TokenizerOption {
pub fn model_type(&self) -> ModelType {
match *self {
Self::Bert(_) => ModelType::Bert,
Self::Roberta(_) => ModelType::Roberta
Self::Roberta(_) => ModelType::Roberta,
}
}
/// Interface method
pub fn encode_list(&self, text_list: Vec<&str>, max_len: usize, truncation_strategy: &TruncationStrategy, stride: usize) -> Vec<TokenizedInput> {
pub fn encode_list(
&self,
text_list: Vec<&str>,
max_len: usize,
truncation_strategy: &TruncationStrategy,
stride: usize,
) -> Vec<TokenizedInput> {
match *self {
Self::Bert(ref tokenizer) => tokenizer.encode_list(text_list, max_len, truncation_strategy, stride),
Self::Roberta(ref tokenizer) => tokenizer.encode_list(text_list, max_len, truncation_strategy, stride)
Self::Bert(ref tokenizer) => {
tokenizer.encode_list(text_list, max_len, truncation_strategy, stride)
}
Self::Roberta(ref tokenizer) => {
tokenizer.encode_list(text_list, max_len, truncation_strategy, stride)
}
}
}
}
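
These wrappers let pipeline code stay generic over the underlying tokenizer and configuration types. A minimal sketch of dispatching through `TokenizerOption` (paths are placeholders; BERT-style tokenizers need no merges file, so `merges_path` is `None`):

use rust_bert::pipelines::common::{ModelType, TokenizerOption};
use rust_tokenizers::TruncationStrategy;

let tokenizer = TokenizerOption::from_file(ModelType::Bert, "path/to/vocab.txt", None, true);
let encoded = tokenizer.encode_list(
    vec!["Hello world!"],
    128,
    &TruncationStrategy::LongestFirst,
    0,
);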

File diff suppressed because it is too large

View File

@ -6,14 +6,14 @@
//! Extractive question answering from a given question and context. DistilBERT model finetuned on SQuAD (Stanford Question Answering Dataset)
//!
//! ```no_run
//! use rust_bert::pipelines::question_answering::{QuestionAnsweringModel, QaInput};
//! use rust_bert::pipelines::question_answering::{QaInput, QuestionAnsweringModel};
//! # fn main() -> failure::Fallible<()> {
//! let qa_model = QuestionAnsweringModel::new(Default::default())?;
//!
//! let question = String::from("Where does Amy live ?");
//! let context = String::from("Amy lives in Amsterdam");
//!
//! let answers = qa_model.predict(&vec!(QaInput { question, context }), 1, 32);
//! let answers = qa_model.predict(&vec![QaInput { question, context }], 1, 32);
//! # Ok(())
//! # }
//! ```
@ -22,15 +22,12 @@
//! ```no_run
//! # use rust_bert::pipelines::question_answering::Answer;
//! # let output =
//! [
//! Answer {
//! [Answer {
//! score: 0.9976,
//! start: 13,
//! end: 21,
//! answer: "Amsterdam"
//!# .to_owned()
//! }
//! ]
//! answer: "Amsterdam"
//! # .to_owned()
//! }]
//! # ;
//! ```
//!
@ -48,9 +45,10 @@
//! ```no_run
//! # fn main() -> failure::Fallible<()> {
//! # use rust_bert::pipelines::generation::LanguageGenerator;
//! use rust_bert::pipelines::translation::{TranslationModel, TranslationConfig, Language};
//! use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
//! use tch::Device;
//! let translation_config = TranslationConfig::new(Language::EnglishToFrench, Device::cuda_if_available());
//! let translation_config =
//! TranslationConfig::new(Language::EnglishToFrench, Device::cuda_if_available());
//! let mut model = TranslationModel::new(translation_config)?;
//!
//! let input = ["This is a sentence to be translated"];
@ -129,7 +127,7 @@
//! let mut model = GPT2Generator::new(Default::default())?;
//! let input_context_1 = "The dog";
//! let input_context_2 = "The cat was";
//! let output = model.generate(Some(vec!(input_context_1, input_context_2)), None);
//! let output = model.generate(Some(vec![input_context_1, input_context_2]), None);
//! # Ok(())
//! # }
//! ```
@ -170,9 +168,18 @@
//! # use rust_bert::pipelines::sentiment::SentimentPolarity::{Positive, Negative};
//! # let output =
//! [
//! Sentiment { polarity: Positive, score: 0.998 },
//! Sentiment { polarity: Negative, score: 0.992 },
//! Sentiment { polarity: Positive, score: 0.999 }
//! Sentiment {
//! polarity: Positive,
//! score: 0.998,
//! },
//! Sentiment {
//! polarity: Negative,
//! score: 0.992,
//! },
//! Sentiment {
//! polarity: Positive,
//! score: 0.999,
//! },
//! ]
//! # ;
//! ```
@ -185,7 +192,7 @@
//! let ner_model = NERModel::new(Default::default())?;
//! let input = [
//! "My name is Amy. I live in Paris.",
//! "Paris is a city in France."
//! "Paris is a city in France.",
//! ];
//! let output = ner_model.predict(&input);
//! # Ok(())
@ -197,21 +204,37 @@
//! # use rust_bert::pipelines::ner::Entity;
//! # let output =
//! [
//! Entity { word: String::from("Amy"), score: 0.9986, label: String::from("I-PER") },
//! Entity { word: String::from("Paris"), score: 0.9985, label: String::from("I-LOC") },
//! Entity { word: String::from("Paris"), score: 0.9988, label: String::from("I-LOC") },
//! Entity { word: String::from("France"), score: 0.9993, label: String::from("I-LOC") },
//! Entity {
//! word: String::from("Amy"),
//! score: 0.9986,
//! label: String::from("I-PER"),
//! },
//! Entity {
//! word: String::from("Paris"),
//! score: 0.9985,
//! label: String::from("I-LOC"),
//! },
//! Entity {
//! word: String::from("Paris"),
//! score: 0.9988,
//! label: String::from("I-LOC"),
//! },
//! Entity {
//! word: String::from("France"),
//! score: 0.9993,
//! label: String::from("I-LOC"),
//! },
//! ]
//! # ;
//! ```
//!
pub mod sentiment;
pub mod common;
pub mod token_classification;
pub mod sequence_classification;
pub mod generation;
pub mod ner;
pub mod question_answering;
pub mod generation;
pub mod sentiment;
pub mod sequence_classification;
pub mod summarization;
pub mod token_classification;
pub mod translation;

View File

@ -25,7 +25,7 @@
//!
//! let input = [
//! "My name is Amy. I live in Paris.",
//! "Paris is a city in France."
//! "Paris is a city in France.",
//! ];
//! let output = ner_model.predict(&input);
//! # Ok(())
@ -37,16 +37,31 @@
//! # use rust_bert::pipelines::ner::Entity;
//! # let output =
//! [
//! Entity { word: String::from("Amy"), score: 0.9986, label: String::from("I-PER") },
//! Entity { word: String::from("Paris"), score: 0.9985, label: String::from("I-LOC") },
//! Entity { word: String::from("Paris"), score: 0.9988, label: String::from("I-LOC") },
//! Entity { word: String::from("France"), score: 0.9993, label: String::from("I-LOC") },
//! Entity {
//! word: String::from("Amy"),
//! score: 0.9986,
//! label: String::from("I-PER"),
//! },
//! Entity {
//! word: String::from("Paris"),
//! score: 0.9985,
//! label: String::from("I-LOC"),
//! },
//! Entity {
//! word: String::from("Paris"),
//! score: 0.9988,
//! label: String::from("I-LOC"),
//! },
//! Entity {
//! word: String::from("France"),
//! score: 0.9993,
//! label: String::from("I-LOC"),
//! },
//! ]
//! # ;
//! ```
use crate::pipelines::token_classification::{TokenClassificationModel, TokenClassificationConfig};
use crate::pipelines::token_classification::{TokenClassificationConfig, TokenClassificationModel};
#[derive(Debug)]
/// # Entity generated by a `NERModel`
@ -64,7 +79,7 @@ type NERConfig = TokenClassificationConfig;
/// # NERModel to extract named entities
pub struct NERModel {
token_classification_model: TokenClassificationModel
token_classification_model: TokenClassificationModel,
}
impl NERModel {
@ -84,10 +99,11 @@ impl NERModel {
/// # Ok(())
/// # }
/// ```
///
pub fn new(ner_config: NERConfig) -> failure::Fallible<NERModel> {
let model = TokenClassificationModel::new(ner_config)?;
Ok(NERModel { token_classification_model: model })
Ok(NERModel {
token_classification_model: model,
})
}
/// Extract entities from a text
@ -109,24 +125,22 @@ impl NERModel {
/// let ner_model = NERModel::new(Default::default())?;
/// let input = [
/// "My name is Amy. I live in Paris.",
/// "Paris is a city in France."
/// "Paris is a city in France.",
/// ];
/// let output = ner_model.predict(&input);
/// # Ok(())
/// # }
/// ```
///
pub fn predict(&self, input: &[&str]) -> Vec<Entity> {
self.token_classification_model
.predict(input, true, false)
.into_iter()
.filter(|token| token.label != "O")
.map(|token| {
Entity {
.map(|token| Entity {
word: token.text,
score: token.score,
label: token.label,
}
}).collect()
})
.collect()
}
}

View File

@ -17,7 +17,7 @@
//! The dependencies will be downloaded to the user's home directory, under ~/.cache/.rustbert/distilbert-qa
//!
//! ```no_run
//! use rust_bert::pipelines::question_answering::{QuestionAnsweringModel, QaInput};
//! use rust_bert::pipelines::question_answering::{QaInput, QuestionAnsweringModel};
//!
//! # fn main() -> failure::Fallible<()> {
//! let qa_model = QuestionAnsweringModel::new(Default::default())?;
@ -25,7 +25,7 @@
//! let question = String::from("Where does Amy live ?");
//! let context = String::from("Amy lives in Amsterdam");
//!
//! let answers = qa_model.predict(&vec!(QaInput { question, context }), 1, 32);
//! let answers = qa_model.predict(&vec![QaInput { question, context }], 1, 32);
//! # Ok(())
//! # }
//! ```
@ -34,31 +34,31 @@
//! ```no_run
//! # use rust_bert::pipelines::question_answering::Answer;
//! # let output =
//! [
//! Answer {
//! [Answer {
//! score: 0.9976,
//! start: 13,
//! end: 21,
//! answer: "Amsterdam"
//!# .to_owned()
//! }
//! ]
//! answer: "Amsterdam", //#### # .to_owned()
//! }]
//! # ;
//! ```
use rust_tokenizers::{BertTokenizer, Tokenizer, TruncationStrategy, TokenizedInput};
use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::Mask;
use tch::{Device, Tensor, no_grad};
use std::path::PathBuf;
use rust_tokenizers::tokenization_utils::truncate_sequences;
use std::collections::HashMap;
use std::cmp::min;
use tch::nn::VarStore;
use tch::kind::Kind::Float;
use std::fs;
use crate::common::resources::{download_resource, RemoteResource, Resource};
use crate::distilbert::{
DistilBertConfig, DistilBertConfigResources, DistilBertForQuestionAnswering,
DistilBertModelResources, DistilBertVocabResources,
};
use crate::Config;
use crate::distilbert::{DistilBertForQuestionAnswering, DistilBertConfig, DistilBertConfigResources, DistilBertModelResources, DistilBertVocabResources};
use crate::common::resources::{Resource, RemoteResource, download_resource};
use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::Mask;
use rust_tokenizers::tokenization_utils::truncate_sequences;
use rust_tokenizers::{BertTokenizer, TokenizedInput, Tokenizer, TruncationStrategy};
use std::cmp::min;
use std::collections::HashMap;
use std::fs;
use std::path::PathBuf;
use tch::kind::Kind::Float;
use tch::nn::VarStore;
use tch::{no_grad, Device, Tensor};
/// # Input for Question Answering
/// Includes a context (containing the answer) and question strings
@ -84,7 +84,6 @@ struct QaFeature {
pub token_to_orig_map: HashMap<i64, i64>,
pub p_mask: Vec<i8>,
pub example_index: i64,
}
#[derive(Debug, Clone)]
@ -102,34 +101,38 @@ pub struct Answer {
impl PartialEq for Answer {
fn eq(&self, other: &Self) -> bool {
(self.start == other.start) &&
(self.end == other.end) &&
(self.answer == other.answer)
(self.start == other.start) && (self.end == other.end) && (self.answer == other.answer)
}
}
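// Deduplicates the vector in place, keeping the first occurrence of each
// value: e.g. vec![1, 2, 2, 1] becomes vec![1, 2]. Only PartialEq is
// required, so the scan is quadratic; answer lists here are short.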
fn remove_duplicates<T: PartialEq + Clone>(vector: &mut Vec<T>) -> &mut Vec<T> {
let mut potential_duplicates = vec!();
vector.retain(|item| if potential_duplicates.contains(item) {
let mut potential_duplicates = vec![];
vector.retain(|item| {
if potential_duplicates.contains(item) {
false
} else {
potential_duplicates.push(item.clone());
true
}
});
vector
}
impl QaExample {
pub fn new(question: &str, context: &str) -> QaExample {
let question = question.to_owned();
let (doc_tokens, char_to_word_offset) = QaExample::split_context(context);
QaExample { question, context: context.to_owned(), doc_tokens, char_to_word_offset }
QaExample {
question,
context: context.to_owned(),
doc_tokens,
char_to_word_offset,
}
}
fn split_context(context: &str) -> (Vec<String>, Vec<i64>) {
let mut doc_tokens: Vec<String> = vec!();
let mut char_to_word_offset: Vec<i64> = vec!();
let mut doc_tokens: Vec<String> = vec![];
let mut char_to_word_offset: Vec<i64> = vec![];
let max_length = context.len();
let mut current_word = String::with_capacity(max_length);
let mut previous_whitespace = false;
@ -158,11 +161,11 @@ impl QaExample {
}
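// Treats space, tab, carriage return, newline and the narrow no-break
// space (U+202F) as separators when splitting the context into words.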
fn is_whitespace(character: &char) -> bool {
(character == &' ') |
(character == &'\t') |
(character == &'\r') |
(character == &'\n') |
(*character as u32 == 0x202F)
(character == &' ')
| (character == &'\t')
| (character == &'\r')
| (character == &'\n')
| (*character as u32 == 0x202F)
}
}
@ -182,9 +185,15 @@ pub struct QuestionAnsweringConfig {
impl Default for QuestionAnsweringConfig {
fn default() -> QuestionAnsweringConfig {
QuestionAnsweringConfig {
model_resource: Resource::Remote(RemoteResource::from_pretrained(DistilBertModelResources::DISTIL_BERT_SQUAD)),
config_resource: Resource::Remote(RemoteResource::from_pretrained(DistilBertConfigResources::DISTIL_BERT_SQUAD)),
vocab_resource: Resource::Remote(RemoteResource::from_pretrained(DistilBertVocabResources::DISTIL_BERT_SQUAD)),
model_resource: Resource::Remote(RemoteResource::from_pretrained(
DistilBertModelResources::DISTIL_BERT_SQUAD,
)),
config_resource: Resource::Remote(RemoteResource::from_pretrained(
DistilBertConfigResources::DISTIL_BERT_SQUAD,
)),
vocab_resource: Resource::Remote(RemoteResource::from_pretrained(
DistilBertVocabResources::DISTIL_BERT_SQUAD,
)),
device: Device::cuda_if_available(),
}
}
@ -220,16 +229,23 @@ impl QuestionAnsweringModel {
/// # Ok(())
/// # }
/// ```
///
pub fn new(question_answering_config: QuestionAnsweringConfig) -> failure::Fallible<QuestionAnsweringModel> {
pub fn new(
question_answering_config: QuestionAnsweringConfig,
) -> failure::Fallible<QuestionAnsweringModel> {
let config_path = download_resource(&question_answering_config.config_resource)?;
let vocab_path = download_resource(&question_answering_config.vocab_resource)?;
let weights_path = download_resource(&question_answering_config.model_resource)?;
let device = question_answering_config.device;
let tokenizer = BertTokenizer::from_file(vocab_path.to_str().unwrap(), false);
let pad_idx = *Tokenizer::vocab(&tokenizer).special_values.get("[PAD]").expect("[PAD] token not found in vocabulary");
let sep_idx = *Tokenizer::vocab(&tokenizer).special_values.get("[SEP]").expect("[SEP] token not found in vocabulary");
let pad_idx = *Tokenizer::vocab(&tokenizer)
.special_values
.get("[PAD]")
.expect("[PAD] token not found in vocabulary");
let sep_idx = *Tokenizer::vocab(&tokenizer)
.special_values
.get("[SEP]")
.expect("[SEP] token not found in vocabulary");
let mut var_store = VarStore::new(device);
let mut config = DistilBertConfig::from_file(config_path);
// The config for the current pre-trained question answering model indicates position embeddings, which does not seem accurate
@ -249,7 +265,6 @@ impl QuestionAnsweringModel {
})
}
/// Perform extractive question answering given a list of `QaInputs`
///
/// # Arguments
@ -265,7 +280,7 @@ impl QuestionAnsweringModel {
///
/// ```no_run
/// # fn main() -> failure::Fallible<()> {
/// use rust_bert::pipelines::question_answering::{QuestionAnsweringModel, QaInput};
/// use rust_bert::pipelines::question_answering::{QaInput, QuestionAnsweringModel};
///
/// let qa_model = QuestionAnsweringModel::new(Default::default())?;
///
@ -274,15 +289,25 @@ impl QuestionAnsweringModel {
/// let question_2 = String::from("Where does Eric live");
/// let context_2 = String::from("While Amy lives in Amsterdam, Eric is in The Hague.");
///
/// let qa_input_1 = QaInput { question: question_1, context: context_1 };
/// let qa_input_2 = QaInput { question: question_2, context: context_2 };
/// let qa_input_1 = QaInput {
/// question: question_1,
/// context: context_1,
/// };
/// let qa_input_2 = QaInput {
/// question: question_2,
/// context: context_2,
/// };
/// let answers = qa_model.predict(&[qa_input_1, qa_input_2], 1, 32);
///
/// # Ok(())
/// # }
/// ```
///
pub fn predict(&self, qa_inputs: &[QaInput], top_k: i64, batch_size: usize) -> Vec<Vec<Answer>> {
pub fn predict(
&self,
qa_inputs: &[QaInput],
top_k: i64,
batch_size: usize,
) -> Vec<Vec<Answer>> {
let examples: Vec<QaExample> = qa_inputs
.iter()
.map(|qa_input| QaExample::new(&qa_input.question, &qa_input.context))
@ -290,7 +315,15 @@ impl QuestionAnsweringModel {
let features: Vec<QaFeature> = examples
.iter()
.enumerate()
.map(|(example_index, qa_example)| self.generate_features(&qa_example, self.max_seq_len, self.doc_stride, self.max_query_length, example_index as i64))
.map(|(example_index, qa_example)| {
self.generate_features(
&qa_example,
self.max_seq_len,
self.doc_stride,
self.max_query_length,
example_index as i64,
)
})
.flatten()
.collect();
@ -310,28 +343,36 @@ impl QuestionAnsweringModel {
}
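// Stack the per-feature token ids and attention masks into 2-D batch
// tensors and move them to the device holding the model weights.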
let input_ids = Tensor::stack(&input_ids, 0).to(self.var_store.device());
let attention_masks = Tensor::stack(&attention_masks, 0).to(self.var_store.device());
let attention_masks =
Tensor::stack(&attention_masks, 0).to(self.var_store.device());
let (start_logits, end_logits, _, _) = self.distilbert_qa.forward_t(Some(input_ids), Some(attention_masks), None, false).unwrap();
let (start_logits, end_logits, _, _) = self
.distilbert_qa
.forward_t(Some(input_ids), Some(attention_masks), None, false)
.unwrap();
let start_logits = start_logits.detach();
let end_logits = end_logits.detach();
let example_index_to_feature_end_position: Vec<(usize, i64)> = batch_features
.iter()
.enumerate()
.map(|(feature_index, feature)| (feature.example_index as usize, feature_index as i64 + 1))
.map(|(feature_index, feature)| {
(feature.example_index as usize, feature_index as i64 + 1)
})
.collect();
let mut feature_id_start = 0;
for (example_id, max_feature_id) in example_index_to_feature_end_position {
let mut answers: Vec<Answer> = vec!();
let mut answers: Vec<Answer> = vec![];
let example = &examples[example_id];
for feature_idx in feature_id_start..max_feature_id {
let feature = &batch_features[feature_idx as usize];
let start = start_logits.get(feature_idx);
let end = end_logits.get(feature_idx);
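// feature.p_mask is 1 for positions that cannot hold an answer (question
// and special tokens); (p_mask - 1).abs() flips it so that valid context
// positions are 1, and the masked softmax below zeroes everything else.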
let p_mask = (Tensor::of_slice(&feature.p_mask) - 1).abs().to_device(start.device());
let p_mask = (Tensor::of_slice(&feature.p_mask) - 1)
.abs()
.to_device(start.device());
let start: Tensor = start.exp() / start.exp().sum(Float) * &p_mask;
let end: Tensor = end.exp() / end.exp().sum(Float) * &p_mask;
@ -343,33 +384,42 @@ impl QuestionAnsweringModel {
let end_pos = feature.token_to_orig_map[&ends[idx]] as usize;
let answer = example.doc_tokens[start_pos..end_pos + 1].join(" ");
let start = example.char_to_word_offset
let start = example
.char_to_word_offset
.iter()
.position(|&v| v as usize == start_pos)
.unwrap();
let end = example.char_to_word_offset
let end = example
.char_to_word_offset
.iter()
.rposition(|&v| v as usize == end_pos)
.unwrap();
answers.push(Answer { score: scores[idx], start, end, answer });
answers.push(Answer {
score: scores[idx],
start,
end,
answer,
});
}
}
feature_id_start = max_feature_id;
let example_answers = example_top_k_answers_map.entry(example_id).or_insert(vec!());
let example_answers = example_top_k_answers_map
.entry(example_id)
.or_insert(vec![]);
example_answers.extend(answers);
}
});
start = end;
}
let mut all_answers = vec!();
let mut all_answers = vec![];
for example_id in 0..examples.len() {
if let Some(answers) = example_top_k_answers_map.get_mut(&example_id) {
remove_duplicates(answers).sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
all_answers.push(answers[..min(answers.len(), top_k as usize)].to_vec());
} else {
all_answers.push(vec!());
all_answers.push(vec![]);
}
}
all_answers
@ -379,7 +429,10 @@ impl QuestionAnsweringModel {
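// outer[i][j] = start[i] * end[j] scores the span running from token i to
// token j; triu(0) drops spans whose end precedes their start, and
// tril(max_answer_len - 1) drops spans longer than the allowed maximum.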
let outer = start.unsqueeze(-1).matmul(&end.unsqueeze(0));
let start_dim = start.size()[0];
let end_dim = end.size()[0];
let candidates = outer.triu(0).tril(self.max_answer_len as i64 - 1).flatten(0, -1);
let candidates = outer
.triu(0)
.tril(self.max_answer_len as i64 - 1)
.flatten(0, -1);
let idx_sort = if top_k == 1 {
candidates.argmax(0, true)
} else if candidates.size()[0] < top_k {
@ -387,9 +440,9 @@ impl QuestionAnsweringModel {
} else {
candidates.argsort(0, true).slice(0, 0, top_k, 1)
};
let mut start: Vec<i64> = vec!();
let mut end: Vec<i64> = vec!();
let mut scores: Vec<f64> = vec!();
let mut start: Vec<i64> = vec![];
let mut end: Vec<i64> = vec![];
let mut scores: Vec<f64> = vec![];
for flat_index_position in 0..idx_sort.size()[0] {
let flat_index = idx_sort.int64_value(&[flat_index_position]);
scores.push(candidates.double_value(&[flat_index]));
@ -399,10 +452,16 @@ impl QuestionAnsweringModel {
(start, end, scores)
}
fn generate_features(&self, qa_example: &QaExample, max_seq_length: usize, doc_stride: usize, max_query_length: usize, example_index: i64) -> Vec<QaFeature> {
let mut tok_to_orig_index: Vec<i64> = vec!();
let mut all_doc_tokens: Vec<String> = vec!();
fn generate_features(
&self,
qa_example: &QaExample,
max_seq_length: usize,
doc_stride: usize,
max_query_length: usize,
example_index: i64,
) -> Vec<QaFeature> {
let mut tok_to_orig_index: Vec<i64> = vec![];
let mut all_doc_tokens: Vec<String> = vec![];
for (idx, token) in qa_example.doc_tokens.iter().enumerate() {
let sub_tokens = self.tokenizer.tokenize(token);
@ -414,28 +473,61 @@ impl QuestionAnsweringModel {
let truncated_query = self.prepare_query(&qa_example.question, max_query_length);
let sequence_added_tokens = self.tokenizer.build_input_with_special_tokens(vec!(), None, vec!(), None, vec!(), None, vec!(), None).0.len();
let sequence_pair_added_tokens = self.tokenizer.build_input_with_special_tokens(vec!(), Some(vec!()), vec!(), Some(vec!()), vec!(), Some(vec!()), vec!(), Some(vec!())).0.len();
let sequence_added_tokens = self
.tokenizer
.build_input_with_special_tokens(vec![], None, vec![], None, vec![], None, vec![], None)
.0
.len();
let sequence_pair_added_tokens = self
.tokenizer
.build_input_with_special_tokens(
vec![],
Some(vec![]),
vec![],
Some(vec![]),
vec![],
Some(vec![]),
vec![],
Some(vec![]),
)
.0
.len();
let mut spans: Vec<QaFeature> = vec!();
let mut spans: Vec<QaFeature> = vec![];
let mut remaining_tokens = self.tokenizer.convert_tokens_to_ids(&all_doc_tokens);
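// Slide a window over the document with a stride of `doc_stride` tokens:
// each iteration pairs the (truncated) query with the remaining context,
// producing overlapping spans until the whole document has been covered.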
while (spans.len() * doc_stride as usize) < all_doc_tokens.len() {
let (encoded_span, attention_mask) = self.encode_qa_pair(&truncated_query, &remaining_tokens, max_seq_length, doc_stride, sequence_pair_added_tokens);
let (encoded_span, attention_mask) = self.encode_qa_pair(
&truncated_query,
&remaining_tokens,
max_seq_length,
doc_stride,
sequence_pair_added_tokens,
);
let paragraph_len = min(
all_doc_tokens.len() - spans.len() * doc_stride,
max_seq_length - truncated_query.len() - sequence_pair_added_tokens);
max_seq_length - truncated_query.len() - sequence_pair_added_tokens,
);
let mut token_to_orig_map = HashMap::new();
for i in 0..paragraph_len {
let index = truncated_query.len() + sequence_added_tokens + i;
token_to_orig_map.insert(index as i64, tok_to_orig_index[spans.len() * doc_stride + i] as i64);
token_to_orig_map.insert(
index as i64,
tok_to_orig_index[spans.len() * doc_stride + i] as i64,
);
}
let p_mask = self.get_mask(&encoded_span);
let qa_feature = QaFeature { input_ids: encoded_span.token_ids, attention_mask, token_to_orig_map, p_mask, example_index };
let qa_feature = QaFeature {
input_ids: encoded_span.token_ids,
attention_mask,
token_to_orig_map,
p_mask,
example_index,
};
spans.push(qa_feature);
if encoded_span.num_truncated_tokens == 0 {
@ -447,59 +539,81 @@ impl QuestionAnsweringModel {
}
fn prepare_query(&self, query: &str, max_query_length: usize) -> Vec<i64> {
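// Tokenize the question and truncate it from the end so that it never
// exceeds `max_query_length` tokens before being paired with the context.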
let truncated_query = self.tokenizer.convert_tokens_to_ids(&self.tokenizer.tokenize(&query));
let num_query_tokens_to_remove = if truncated_query.len() > max_query_length as usize { truncated_query.len() - max_query_length } else { 0 };
let (truncated_query, _, _, _, _, _, _, _, _, _) = truncate_sequences(truncated_query,
let truncated_query = self
.tokenizer
.convert_tokens_to_ids(&self.tokenizer.tokenize(&query));
let num_query_tokens_to_remove = if truncated_query.len() > max_query_length as usize {
truncated_query.len() - max_query_length
} else {
0
};
let (truncated_query, _, _, _, _, _, _, _, _, _) = truncate_sequences(
truncated_query,
None,
vec!(),
vec![],
None,
vec!(),
vec![],
None,
vec!(),
vec![],
None,
num_query_tokens_to_remove,
&TruncationStrategy::OnlyFirst,
0).unwrap();
0,
)
.unwrap();
truncated_query
}
fn encode_qa_pair(&self,
fn encode_qa_pair(
&self,
truncated_query: &Vec<i64>,
spans_token_ids: &Vec<i64>,
max_seq_length: usize,
doc_stride: usize,
sequence_pair_added_tokens: usize) -> (TokenizedInput, Vec<i64>) {
sequence_pair_added_tokens: usize,
) -> (TokenizedInput, Vec<i64>) {
let len_1 = truncated_query.len();
let len_2 = spans_token_ids.len();
let total_len = len_1 + len_2 + sequence_pair_added_tokens;
let num_truncated_tokens = if total_len > max_seq_length { total_len - max_seq_length } else { 0 };
let num_truncated_tokens = if total_len > max_seq_length {
total_len - max_seq_length
} else {
0
};
let (truncated_query, truncated_context, _, _, _, _, _, _, overflowing_tokens, _)
= truncate_sequences(truncated_query.clone(),
let (truncated_query, truncated_context, _, _, _, _, _, _, overflowing_tokens, _) =
truncate_sequences(
truncated_query.clone(),
Some(spans_token_ids.clone()),
vec!(),
vec![],
None,
vec!(),
vec![],
None,
vec!(),
vec![],
None,
num_truncated_tokens,
&TruncationStrategy::OnlySecond,
max_seq_length - doc_stride - len_1 - sequence_pair_added_tokens).unwrap();
max_seq_length - doc_stride - len_1 - sequence_pair_added_tokens,
)
.unwrap();
let (mut token_ids,
let (
mut token_ids,
mut segment_ids,
special_tokens_mask,
mut token_offsets,
mut reference_offsets,
mut mask) = self.tokenizer.build_input_with_special_tokens(truncated_query,
mut mask,
) = self.tokenizer.build_input_with_special_tokens(
truncated_query,
truncated_context,
vec!(),
vec![],
None,
vec!(),
vec![],
None,
vec!(),
None);
vec![],
None,
);
let mut attention_mask = vec![1; token_ids.len()];
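// Pad every parallel sequence to `max_seq_length` so that all spans in a
// batch can later be stacked into rectangular tensors.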
if token_ids.len() < max_seq_length {
token_ids.append(&mut vec![self.pad_idx; max_seq_length - token_ids.len()]);
@ -509,18 +623,32 @@ impl QuestionAnsweringModel {
reference_offsets.append(&mut vec![vec!(); max_seq_length - token_offsets.len()]);
mask.append(&mut vec![Mask::Special; max_seq_length - mask.len()]);
}
(TokenizedInput { token_ids, segment_ids, special_tokens_mask, overflowing_tokens, num_truncated_tokens, token_offsets, reference_offsets, mask }, attention_mask)
(
TokenizedInput {
token_ids,
segment_ids,
special_tokens_mask,
overflowing_tokens,
num_truncated_tokens,
token_offsets,
reference_offsets,
mask,
},
attention_mask,
)
}
fn get_mask(&self, encoded_span: &TokenizedInput) -> Vec<i8> {
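// Build the prediction mask: question tokens (segment id 0) map to 1 and
// are excluded from answer selection, while context tokens (segment id 1)
// map to 0; the [SEP] positions collected below are masked as well.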
let sep_indices: Vec<usize> = encoded_span.token_ids
let sep_indices: Vec<usize> = encoded_span
.token_ids
.iter()
.enumerate()
.filter(|(_, &value)| value == self.sep_idx)
.map(|(position, _)| position)
.collect();
let mut p_mask: Vec<i8> = encoded_span.segment_ids
let mut p_mask: Vec<i8> = encoded_span
.segment_ids
.iter()
.map(|v| min(v, &1i8))
.map(|&v| 1i8 - v)
@ -534,10 +662,13 @@ impl QuestionAnsweringModel {
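/// Reads question/context pairs from a SQuAD-format JSON file: a top-level
/// `data` array of documents, each containing `paragraphs` with a `context`
/// string and a `qas` array of questions.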
pub fn squad_processor(file_path: PathBuf) -> Vec<QaInput> {
let file = fs::File::open(file_path).expect("unable to open file");
let json: serde_json::Value = serde_json::from_reader(file).expect("JSON not properly formatted");
let json: serde_json::Value =
serde_json::from_reader(file).expect("JSON not properly formatted");
let data = json
.get("data").expect("SQuAD file does not contain data field")
.as_array().expect("Data array not properly formatted");
.get("data")
.expect("SQuAD file does not contain data field")
.as_array()
.expect("Data array not properly formatted");
let mut qa_inputs: Vec<QaInput> = Vec::with_capacity(data.len());
for qa_input in data.iter() {
@ -548,8 +679,17 @@ pub fn squad_processor(file_path: PathBuf) -> Vec<QaInput> {
let context = paragraph.get("context").unwrap().as_str().unwrap();
let qas = paragraph.get("qas").unwrap().as_array().unwrap();
for qa in qas.iter() {
let question = qa.as_object().unwrap().get("question").unwrap().as_str().unwrap();
qa_inputs.push(QaInput { question: question.to_owned(), context: context.to_owned() });
let question = qa
.as_object()
.unwrap()
.get("question")
.unwrap()
.as_str()
.unwrap();
qa_inputs.push(QaInput {
question: question.to_owned(),
context: context.to_owned(),
});
}
}
}

View File

@ -38,18 +38,29 @@
//! # use rust_bert::pipelines::sentiment::SentimentPolarity::{Positive, Negative};
//! # let output =
//! [
//! Sentiment { polarity: Positive, score: 0.998 },
//! Sentiment { polarity: Negative, score: 0.992 },
//! Sentiment { polarity: Positive, score: 0.999 }
//! Sentiment {
//! polarity: Positive,
//! score: 0.998,
//! },
//! Sentiment {
//! polarity: Negative,
//! score: 0.992,
//! },
//! Sentiment {
//! polarity: Positive,
//! score: 0.999,
//! },
//! ]
//! # ;
//! ```
use std::path::PathBuf;
use std::fs;
use crate::pipelines::sequence_classification::{
SequenceClassificationConfig, SequenceClassificationModel,
};
use serde::Deserialize;
use std::error::Error;
use crate::pipelines::sequence_classification::{SequenceClassificationConfig, SequenceClassificationModel};
use std::fs;
use std::path::PathBuf;
#[derive(Debug, PartialEq)]
/// Enum with the possible sentiment polarities. Note that the pre-trained SST2 model does not include neutral sentiment.
@ -71,7 +82,7 @@ type SentimentConfig = SequenceClassificationConfig;
/// # SentimentClassifier to perform sentiment analysis
pub struct SentimentModel {
sequence_classification_model: SequenceClassificationModel
sequence_classification_model: SequenceClassificationModel,
}
impl SentimentModel {
@ -91,10 +102,11 @@ impl SentimentModel {
/// # Ok(())
/// # }
/// ```
///
pub fn new(sentiment_config: SentimentConfig) -> failure::Fallible<SentimentModel> {
let sequence_classification_model = SequenceClassificationModel::new(sentiment_config)?;
Ok(SentimentModel { sequence_classification_model })
Ok(SentimentModel {
sequence_classification_model,
})
}
/// Extract sentiment from an array of text inputs
@ -124,14 +136,20 @@ impl SentimentModel {
/// # Ok(())
/// # }
/// ```
///
pub fn predict(&self, input: &[&str]) -> Vec<Sentiment> {
let labels = self.sequence_classification_model.predict(input);
let mut sentiments = Vec::with_capacity(labels.len());
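// The pre-trained SST2 classifier maps label id 1 to a positive sentiment;
// any other id is treated as negative.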
for label in labels {
let polarity = if label.id == 1 { SentimentPolarity::Positive } else { SentimentPolarity::Negative };
sentiments.push(Sentiment { polarity, score: label.score })
let polarity = if label.id == 1 {
SentimentPolarity::Positive
} else {
SentimentPolarity::Negative
};
sentiments.push(Sentiment {
polarity,
score: label.score,
})
}
sentiments
}
}

View File

@ -56,17 +56,21 @@
//! # ;
//! ```
//!
use tch::nn::VarStore;
use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::{TokenizedInput, TruncationStrategy};
use std::collections::HashMap;
use tch::{Tensor, no_grad, Device, Kind};
use crate::bert::BertForSequenceClassification;
use crate::common::resources::{download_resource, RemoteResource, Resource};
use crate::distilbert::{
DistilBertConfigResources, DistilBertModelClassifier, DistilBertModelResources,
DistilBertVocabResources,
};
use crate::pipelines::common::{ConfigOption, ModelType, TokenizerOption};
use crate::roberta::RobertaForSequenceClassification;
use crate::distilbert::{DistilBertModelResources, DistilBertConfigResources, DistilBertVocabResources, DistilBertModelClassifier};
use crate::common::resources::{Resource, RemoteResource, download_resource};
use serde::{Serialize, Deserialize};
use crate::pipelines::common::{ModelType, ConfigOption, TokenizerOption};
use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::{
TokenizedInput, TruncationStrategy,
};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use tch::nn::VarStore;
use tch::{no_grad, Device, Kind, Tensor};
#[derive(Debug, Serialize, Deserialize)]
/// # Label generated by a `SequenceClassificationModel`
@ -112,8 +116,14 @@ impl SequenceClassificationConfig {
/// * vocab - The `Resource` pointing to the tokenizer's vocabulary to load (e.g. vocab.txt/vocab.json)
/// * merges - An optional `Resource` (`Option<Resource>`) pointing to the tokenizer's merge file to load (e.g. merges.txt), needed only for Roberta.
/// * lower_case - A `bool` indicating whether the tokenizer should lower case all input (in case of a lower-cased model)
///
pub fn new(model_type: ModelType, model_resource: Resource, config_resource: Resource, vocab_resource: Resource, merges_resource: Option<Resource>, lower_case: bool) -> SequenceClassificationConfig {
pub fn new(
model_type: ModelType,
model_resource: Resource,
config_resource: Resource,
vocab_resource: Resource,
merges_resource: Option<Resource>,
lower_case: bool,
) -> SequenceClassificationConfig {
SequenceClassificationConfig {
model_type,
model_resource,
@ -131,9 +141,15 @@ impl Default for SequenceClassificationConfig {
fn default() -> SequenceClassificationConfig {
SequenceClassificationConfig {
model_type: ModelType::DistilBert,
model_resource: Resource::Remote(RemoteResource::from_pretrained(DistilBertModelResources::DISTIL_BERT_SST2)),
config_resource: Resource::Remote(RemoteResource::from_pretrained(DistilBertConfigResources::DISTIL_BERT_SST2)),
vocab_resource: Resource::Remote(RemoteResource::from_pretrained(DistilBertVocabResources::DISTIL_BERT_SST2)),
model_resource: Resource::Remote(RemoteResource::from_pretrained(
DistilBertModelResources::DISTIL_BERT_SST2,
)),
config_resource: Resource::Remote(RemoteResource::from_pretrained(
DistilBertConfigResources::DISTIL_BERT_SST2,
)),
vocab_resource: Resource::Remote(RemoteResource::from_pretrained(
DistilBertVocabResources::DISTIL_BERT_SST2,
)),
merges_resource: None,
lower_case: true,
device: Device::cuda_if_available(),
@ -160,31 +176,38 @@ impl SequenceClassificationOption {
/// * `p` - `tch::nn::Path` path to the model file to load (e.g. model.ot)
/// * `config` - A configuration (the model type of the configuration must be compatible with the value for
/// `model_type`)
///
pub fn new(model_type: ModelType, p: &tch::nn::Path, config: &ConfigOption) -> Self {
match model_type {
ModelType::Bert => {
if let ConfigOption::Bert(config) = config {
SequenceClassificationOption::Bert(BertForSequenceClassification::new(p, config))
SequenceClassificationOption::Bert(BertForSequenceClassification::new(
p, config,
))
} else {
panic!("You can only supply a BertConfig for Bert!");
}
}
ModelType::DistilBert => {
if let ConfigOption::DistilBert(config) = config {
SequenceClassificationOption::DistilBert(DistilBertModelClassifier::new(p, config))
SequenceClassificationOption::DistilBert(DistilBertModelClassifier::new(
p, config,
))
} else {
panic!("You can only supply a DistilBertConfig for DistilBert!");
}
}
ModelType::Roberta => {
if let ConfigOption::Bert(config) = config {
SequenceClassificationOption::Roberta(RobertaForSequenceClassification::new(p, config))
SequenceClassificationOption::Roberta(RobertaForSequenceClassification::new(
p, config,
))
} else {
panic!("You can only supply a BertConfig for Roberta!");
}
}
ModelType::Electra => { panic!("SequenceClassification not implemented for Electra!"); }
ModelType::Electra => {
panic!("SequenceClassification not implemented for Electra!");
}
}
}
@ -193,27 +216,44 @@ impl SequenceClassificationOption {
match *self {
Self::Bert(_) => ModelType::Bert,
Self::Roberta(_) => ModelType::Roberta,
Self::DistilBert(_) => ModelType::DistilBert
Self::DistilBert(_) => ModelType::DistilBert,
}
}
/// Interface method to forward_t() of the particular models.
pub fn forward_t(&self,
pub fn forward_t(
&self,
input_ids: Option<Tensor>,
mask: Option<Tensor>,
token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
train: bool,
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
match *self {
Self::Bert(ref model) => model.forward_t(input_ids, mask, token_type_ids, position_ids, input_embeds, train),
Self::DistilBert(ref model) => model.forward_t(input_ids, mask, input_embeds, train).expect("Error in distilbert forward_t"),
Self::Roberta(ref model) => model.forward_t(input_ids, mask, token_type_ids, position_ids, input_embeds, train),
Self::Bert(ref model) => model.forward_t(
input_ids,
mask,
token_type_ids,
position_ids,
input_embeds,
train,
),
Self::DistilBert(ref model) => model
.forward_t(input_ids, mask, input_embeds, train)
.expect("Error in distilbert forward_t"),
Self::Roberta(ref model) => model.forward_t(
input_ids,
mask,
token_type_ids,
position_ids,
input_embeds,
train,
),
}
}
}
/// # SequenceClassificationModel for Classification (e.g. Sentiment Analysis)
pub struct SequenceClassificationModel {
tokenizer: TokenizerOption,
@ -239,8 +279,9 @@ impl SequenceClassificationModel {
/// # Ok(())
/// # }
/// ```
///
pub fn new(config: SequenceClassificationConfig) -> failure::Fallible<SequenceClassificationModel> {
pub fn new(
config: SequenceClassificationConfig,
) -> failure::Fallible<SequenceClassificationModel> {
let config_path = download_resource(&config.config_resource)?;
let vocab_path = download_resource(&config.vocab_resource)?;
let weights_path = download_resource(&config.model_resource)?;
@ -251,31 +292,44 @@ impl SequenceClassificationModel {
};
let device = config.device;
let tokenizer = TokenizerOption::from_file(config.model_type, vocab_path.to_str().unwrap(), merges_path.map(|path| path.to_str().unwrap()), config.lower_case);
let tokenizer = TokenizerOption::from_file(
config.model_type,
vocab_path.to_str().unwrap(),
merges_path.map(|path| path.to_str().unwrap()),
config.lower_case,
);
let mut var_store = VarStore::new(device);
let model_config = ConfigOption::from_file(config.model_type, config_path);
let sequence_classifier = SequenceClassificationOption::new(config.model_type, &var_store.root(), &model_config);
let sequence_classifier =
SequenceClassificationOption::new(config.model_type, &var_store.root(), &model_config);
let label_mapping = model_config.get_label_mapping();
var_store.load(weights_path)?;
Ok(SequenceClassificationModel { tokenizer, sequence_classifier, label_mapping, var_store })
Ok(SequenceClassificationModel {
tokenizer,
sequence_classifier,
label_mapping,
var_store,
})
}
fn prepare_for_model(&self, input: Vec<&str>) -> Tensor {
let tokenized_input: Vec<TokenizedInput> = self.tokenizer.encode_list(input.to_vec(),
128,
&TruncationStrategy::LongestFirst,
0);
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
let tokenized_input_tensors: Vec<tch::Tensor> = tokenized_input.
iter().
map(|input| input.token_ids.clone()).
map(|mut input| {
let tokenized_input: Vec<TokenizedInput> =
self.tokenizer
.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let tokenized_input_tensors: Vec<tch::Tensor> = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
}).
map(|input|
Tensor::of_slice(&(input))).
collect::<Vec<_>>();
})
.map(|input| Tensor::of_slice(&(input)))
.collect::<Vec<_>>();
Tensor::stack(tokenized_input_tensors.as_slice(), 0).to(self.var_store.device())
}
@ -308,23 +362,30 @@ impl SequenceClassificationModel {
pub fn predict(&self, input: &[&str]) -> Vec<Label> {
let input_tensor = self.prepare_for_model(input.to_vec());
let output = no_grad(|| {
let (output, _, _) = self.sequence_classifier
.forward_t(Some(input_tensor.copy()),
let (output, _, _) = self.sequence_classifier.forward_t(
Some(input_tensor.copy()),
None,
None,
None,
None,
false);
false,
);
output.softmax(-1, Kind::Float).detach().to(Device::Cpu)
});
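// argmax picks the most likely class for each sequence; gather then reads
// the softmax probability at that index to use as the confidence score.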
let label_indices = output.as_ref().argmax(-1, true).squeeze1(1);
let scores = output.gather(1, &label_indices.unsqueeze(-1), false).squeeze1(1);
let scores = output
.gather(1, &label_indices.unsqueeze(-1), false)
.squeeze1(1);
let label_indices = label_indices.iter::<i64>().unwrap().collect::<Vec<i64>>();
let scores = scores.iter::<f64>().unwrap().collect::<Vec<f64>>();
let mut labels: Vec<Label> = vec!();
let mut labels: Vec<Label> = vec![];
for sentence_idx in 0..label_indices.len() {
let label_string = self.label_mapping.get(&label_indices[sentence_idx]).unwrap().clone();
let label_string = self
.label_mapping
.get(&label_indices[sentence_idx])
.unwrap()
.clone();
let label = Label {
text: label_string,
score: scores[sentence_idx],
@ -366,28 +427,31 @@ impl SequenceClassificationModel {
pub fn predict_multilabel(&self, input: &[&str], threshold: f64) -> Vec<Vec<Label>> {
let input_tensor = self.prepare_for_model(input.to_vec());
let output = no_grad(|| {
let (output, _, _) = self.sequence_classifier
.forward_t(Some(input_tensor.copy()),
let (output, _, _) = self.sequence_classifier.forward_t(
Some(input_tensor.copy()),
None,
None,
None,
None,
false);
false,
);
output.sigmoid().detach().to(Device::Cpu)
});
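// With sigmoid activations every label is scored independently; nonzero()
// yields the (sentence, label) index pairs whose score exceeds the threshold.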
let label_indices = output.as_ref().ge(threshold).nonzero();
let mut labels: Vec<Vec<Label>> = vec!();
let mut sequence_labels: Vec<Label> = vec!();
let mut labels: Vec<Vec<Label>> = vec![];
let mut sequence_labels: Vec<Label> = vec![];
for sentence_idx in 0..label_indices.size()[0] {
let label_index_tensor = label_indices.get(sentence_idx);
let sentence_label = label_index_tensor.iter::<i64>().unwrap().collect::<Vec<i64>>();
let sentence_label = label_index_tensor
.iter::<i64>()
.unwrap()
.collect::<Vec<i64>>();
let (sentence, id) = (sentence_label[0], sentence_label[1]);
if sentence as usize > labels.len() {
labels.push(sequence_labels);
sequence_labels = vec!();
sequence_labels = vec![];
}
let score = output.double_value(sentence_label.as_slice());
let label_string = self.label_mapping.get(&id).unwrap().to_owned();
@ -398,7 +462,6 @@ impl SequenceClassificationModel {
sentence: sentence as usize,
};
sequence_labels.push(label);
}
if sequence_labels.len() > 0 {
labels.push(sequence_labels);

View File

@ -11,7 +11,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
//! # Summarization pipeline
//! Abstractive summarization of texts based on the BART encoder-decoder architecture
//! Includes techniques such as beam search, top-k and nucleus sampling, temperature setting and repetition penalty.
@ -63,10 +62,12 @@
//! # ;
//! ```
use crate::bart::{
BartConfigResources, BartMergesResources, BartModelResources, BartVocabResources,
};
use crate::common::resources::{RemoteResource, Resource};
use crate::pipelines::generation::{BartGenerator, GenerateConfig, LanguageGenerator};
use tch::Device;
use crate::common::resources::{Resource, RemoteResource};
use crate::bart::{BartModelResources, BartConfigResources, BartVocabResources, BartMergesResources};
/// # Configuration for text summarization
/// Contains information regarding the model to load, mirrors the GenerationConfig, with a
@ -111,10 +112,18 @@ pub struct SummarizationConfig {
impl Default for SummarizationConfig {
fn default() -> SummarizationConfig {
SummarizationConfig {
model_resource: Resource::Remote(RemoteResource::from_pretrained(BartModelResources::BART_CNN)),
config_resource: Resource::Remote(RemoteResource::from_pretrained(BartConfigResources::BART_CNN)),
vocab_resource: Resource::Remote(RemoteResource::from_pretrained(BartVocabResources::BART_CNN)),
merges_resource: Resource::Remote(RemoteResource::from_pretrained(BartMergesResources::BART_CNN)),
model_resource: Resource::Remote(RemoteResource::from_pretrained(
BartModelResources::BART_CNN,
)),
config_resource: Resource::Remote(RemoteResource::from_pretrained(
BartConfigResources::BART_CNN,
)),
vocab_resource: Resource::Remote(RemoteResource::from_pretrained(
BartVocabResources::BART_CNN,
)),
merges_resource: Resource::Remote(RemoteResource::from_pretrained(
BartMergesResources::BART_CNN,
)),
min_length: 56,
max_length: 142,
do_sample: false,
@ -134,7 +143,7 @@ impl Default for SummarizationConfig {
/// # SummarizationModel to perform summarization
pub struct SummarizationModel {
model: BartGenerator
model: BartGenerator,
}
impl SummarizationModel {
@ -154,9 +163,7 @@ impl SummarizationModel {
/// # Ok(())
/// # }
/// ```
///
pub fn new(summarization_config: SummarizationConfig)
-> failure::Fallible<SummarizationModel> {
pub fn new(summarization_config: SummarizationConfig) -> failure::Fallible<SummarizationModel> {
let generate_config = GenerateConfig {
model_resource: summarization_config.model_resource,
config_resource: summarization_config.config_resource,
@ -226,7 +233,6 @@ impl SummarizationModel {
/// # }
/// ```
/// (New sample credits: [WikiNews](https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b))
///
pub fn summarize(&self, texts: &[&str]) -> Vec<String> {
self.model.generate(Some(texts.to_vec()), None)
}

View File

@ -47,34 +47,87 @@
//! ```no_run
//! # use rust_bert::pipelines::token_classification::Token;
//! use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::Mask::Special;
//! use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::{Offset, Mask};
//! use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::{Mask, Offset};
//! # let output =
//! [
//! Token { text: String::from("[CLS]"), score: 0.9995001554489136, label: String::from("O"), label_index: 0, sentence: 0, index: 0, word_index: 0, offset: None, mask: Special },
//! Token { text: String::from("My"), score: 0.9980450868606567, label: String::from("O"), label_index: 0, sentence: 0, index: 1, word_index: 1, offset: Some(Offset { begin: 0, end: 2 }), mask: Mask::None },
//! Token { text: String::from("name"), score: 0.9995062351226807, label: String::from("O"), label_index: 0, sentence: 0, index: 2, word_index: 2, offset: Some(Offset { begin: 3, end: 7 }), mask: Mask::None },
//! Token { text: String::from("is"), score: 0.9997343420982361, label: String::from("O"), label_index: 0, sentence: 0, index: 3, word_index: 3, offset: Some(Offset { begin: 8, end: 10 }), mask: Mask::None },
//! Token { text: String::from("Amélie"), score: 0.9913727683112525, label: String::from("I-PER"), label_index: 4, sentence: 0, index: 4, word_index: 4, offset: Some(Offset { begin: 11, end: 17 }), mask: Mask::None }
//! // ...
//! Token {
//! text: String::from("[CLS]"),
//! score: 0.9995001554489136,
//! label: String::from("O"),
//! label_index: 0,
//! sentence: 0,
//! index: 0,
//! word_index: 0,
//! offset: None,
//! mask: Special,
//! },
//! Token {
//! text: String::from("My"),
//! score: 0.9980450868606567,
//! label: String::from("O"),
//! label_index: 0,
//! sentence: 0,
//! index: 1,
//! word_index: 1,
//! offset: Some(Offset { begin: 0, end: 2 }),
//! mask: Mask::None,
//! },
//! Token {
//! text: String::from("name"),
//! score: 0.9995062351226807,
//! label: String::from("O"),
//! label_index: 0,
//! sentence: 0,
//! index: 2,
//! word_index: 2,
//! offset: Some(Offset { begin: 3, end: 7 }),
//! mask: Mask::None,
//! },
//! Token {
//! text: String::from("is"),
//! score: 0.9997343420982361,
//! label: String::from("O"),
//! label_index: 0,
//! sentence: 0,
//! index: 3,
//! word_index: 3,
//! offset: Some(Offset { begin: 8, end: 10 }),
//! mask: Mask::None,
//! },
//! Token {
//! text: String::from("Amélie"),
//! score: 0.9913727683112525,
//! label: String::from("I-PER"),
//! label_index: 4,
//! sentence: 0,
//! index: 4,
//! word_index: 4,
//! offset: Some(Offset { begin: 11, end: 17 }),
//! mask: Mask::None,
//! }, // ...
//! ]
//! # ;
//! ```
use tch::nn::VarStore;
use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::{Tokenizer, TokenizedInput, TruncationStrategy, Mask, Offset, ConsolidatableTokens, ConsolidatedTokenIterator, TokenTrait};
use std::collections::HashMap;
use tch::{Tensor, no_grad, Device};
use tch::kind::Kind::Float;
use crate::bert::{BertForTokenClassification, BertModelResources, BertConfigResources, BertVocabResources};
use crate::roberta::RobertaForTokenClassification;
use crate::bert::{
BertConfigResources, BertForTokenClassification, BertModelResources, BertVocabResources,
};
use crate::common::resources::{download_resource, RemoteResource, Resource};
use crate::distilbert::DistilBertForTokenClassification;
use crate::common::resources::{Resource, RemoteResource, download_resource};
use crate::pipelines::common::{ModelType, ConfigOption, TokenizerOption};
use crate::electra::ElectraForTokenClassification;
use crate::pipelines::common::{ConfigOption, ModelType, TokenizerOption};
use crate::roberta::RobertaForTokenClassification;
use itertools::Itertools;
use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::{
ConsolidatableTokens, ConsolidatedTokenIterator, Mask, Offset, TokenTrait, TokenizedInput,
Tokenizer, TruncationStrategy,
};
use serde::{Deserialize, Serialize};
use std::cmp::min;
use serde::{Serialize, Deserialize};
use std::collections::HashMap;
use tch::kind::Kind::Float;
use tch::nn::VarStore;
use tch::{no_grad, Device, Tensor};
#[derive(Debug, Clone, Serialize, Deserialize)]
/// # Token generated by a `TokenClassificationModel`
@ -140,7 +193,6 @@ pub enum LabelAggregationOption {
Custom(Box<dyn Fn(&[Token]) -> (i64, String)>),
}
/// # Configuration for TokenClassificationModel
/// Contains information regarding the model to load and device to place the model on.
pub struct TokenClassificationConfig {
@ -173,14 +225,15 @@ impl TokenClassificationConfig {
/// * vocab - The `Resource` pointing to the tokenizer's vocabulary to load (e.g. vocab.txt/vocab.json)
/// * merges - An optional `Resource` (`Option<Resource>`) pointing to the tokenizer's merge file to load (e.g. merges.txt), needed only for Roberta.
/// * lower_case - A `bool` indicating whether the tokenizer should lower case all input (in case of a lower-cased model)
///
pub fn new(model_type: ModelType,
pub fn new(
model_type: ModelType,
model_resource: Resource,
config_resource: Resource,
vocab_resource: Resource,
merges_resource: Option<Resource>,
lower_case: bool,
label_aggregation_function: LabelAggregationOption) -> TokenClassificationConfig {
label_aggregation_function: LabelAggregationOption,
) -> TokenClassificationConfig {
TokenClassificationConfig {
model_type,
model_resource,
@ -199,9 +252,15 @@ impl Default for TokenClassificationConfig {
fn default() -> TokenClassificationConfig {
TokenClassificationConfig {
model_type: ModelType::Bert,
model_resource: Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT_NER)),
config_resource: Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT_NER)),
vocab_resource: Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT_NER)),
model_resource: Resource::Remote(RemoteResource::from_pretrained(
BertModelResources::BERT_NER,
)),
config_resource: Resource::Remote(RemoteResource::from_pretrained(
BertConfigResources::BERT_NER,
)),
vocab_resource: Resource::Remote(RemoteResource::from_pretrained(
BertVocabResources::BERT_NER,
)),
merges_resource: None,
lower_case: false,
device: Device::cuda_if_available(),
@ -231,7 +290,6 @@ impl TokenClassificationOption {
/// * `p` - `tch::nn::Path` path to the model file to load (e.g. model.ot)
/// * `config` - A configuration (the model type of the configuration must be compatible with the value for
/// `model_type`)
///
pub fn new(model_type: ModelType, p: &tch::nn::Path, config: &ConfigOption) -> Self {
match model_type {
ModelType::Bert => {
@ -243,21 +301,27 @@ impl TokenClassificationOption {
}
ModelType::DistilBert => {
if let ConfigOption::DistilBert(config) = config {
TokenClassificationOption::DistilBert(DistilBertForTokenClassification::new(p, config))
TokenClassificationOption::DistilBert(DistilBertForTokenClassification::new(
p, config,
))
} else {
panic!("You can only supply a DistilBertConfig for DistilBert!");
}
}
ModelType::Roberta => {
if let ConfigOption::Bert(config) = config {
TokenClassificationOption::Roberta(RobertaForTokenClassification::new(p, config))
TokenClassificationOption::Roberta(RobertaForTokenClassification::new(
p, config,
))
} else {
panic!("You can only supply a BertConfig for Roberta!");
}
}
ModelType::Electra => {
if let ConfigOption::Electra(config) = config {
TokenClassificationOption::Electra(ElectraForTokenClassification::new(p, config))
TokenClassificationOption::Electra(ElectraForTokenClassification::new(
p, config,
))
} else {
panic!("You can only supply a BertConfig for Roberta!");
}
@ -271,27 +335,51 @@ impl TokenClassificationOption {
Self::Bert(_) => ModelType::Bert,
Self::Roberta(_) => ModelType::Roberta,
Self::DistilBert(_) => ModelType::DistilBert,
Self::Electra(_) => ModelType::Electra
Self::Electra(_) => ModelType::Electra,
}
}
fn forward_t(&self,
fn forward_t(
&self,
input_ids: Option<Tensor>,
mask: Option<Tensor>,
token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
train: bool,
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
match *self {
Self::Bert(ref model) => model.forward_t(input_ids, mask, token_type_ids, position_ids, input_embeds, train),
Self::DistilBert(ref model) => model.forward_t(input_ids, mask, input_embeds, train).expect("Error in distilbert forward_t"),
Self::Roberta(ref model) => model.forward_t(input_ids, mask, token_type_ids, position_ids, input_embeds, train),
Self::Electra(ref model) => model.forward_t(input_ids, mask, token_type_ids, position_ids, input_embeds, train),
Self::Bert(ref model) => model.forward_t(
input_ids,
mask,
token_type_ids,
position_ids,
input_embeds,
train,
),
Self::DistilBert(ref model) => model
.forward_t(input_ids, mask, input_embeds, train)
.expect("Error in distilbert forward_t"),
Self::Roberta(ref model) => model.forward_t(
input_ids,
mask,
token_type_ids,
position_ids,
input_embeds,
train,
),
Self::Electra(ref model) => model.forward_t(
input_ids,
mask,
token_type_ids,
position_ids,
input_embeds,
train,
),
}
}
}
/// # TokenClassificationModel for Named Entity Recognition or Part-of-Speech tagging
pub struct TokenClassificationModel {
tokenizer: TokenizerOption,
@ -318,7 +406,6 @@ impl TokenClassificationModel {
/// # Ok(())
/// # }
/// ```
///
pub fn new(config: TokenClassificationConfig) -> failure::Fallible<TokenClassificationModel> {
let config_path = download_resource(&config.config_resource)?;
let vocab_path = download_resource(&config.vocab_resource)?;
@ -331,32 +418,49 @@ impl TokenClassificationModel {
let device = config.device;
let label_aggregation_function = config.label_aggregation_function;
let tokenizer = TokenizerOption::from_file(config.model_type, vocab_path.to_str().unwrap(), merges_path.map(|path| path.to_str().unwrap()), config.lower_case);
let tokenizer = TokenizerOption::from_file(
config.model_type,
vocab_path.to_str().unwrap(),
merges_path.map(|path| path.to_str().unwrap()),
config.lower_case,
);
let mut var_store = VarStore::new(device);
let model_config = ConfigOption::from_file(config.model_type, config_path);
let token_sequence_classifier = TokenClassificationOption::new(config.model_type, &var_store.root(), &model_config);
let token_sequence_classifier =
TokenClassificationOption::new(config.model_type, &var_store.root(), &model_config);
let label_mapping = model_config.get_label_mapping();
var_store.load(weights_path)?;
Ok(TokenClassificationModel { tokenizer, token_sequence_classifier, label_mapping, var_store, label_aggregation_function })
Ok(TokenClassificationModel {
tokenizer,
token_sequence_classifier,
label_mapping,
var_store,
label_aggregation_function,
})
}
fn prepare_for_model(&self, input: Vec<&str>) -> (Vec<TokenizedInput>, Tensor) {
let tokenized_input: Vec<TokenizedInput> = self.tokenizer.encode_list(input.to_vec(),
128,
&TruncationStrategy::LongestFirst,
0);
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
let tokenized_input_tensors: Vec<tch::Tensor> = tokenized_input.
iter().
map(|input| input.token_ids.clone()).
map(|mut input| {
let tokenized_input: Vec<TokenizedInput> =
self.tokenizer
.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let tokenized_input_tensors: Vec<tch::Tensor> = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
}).
map(|input|
Tensor::of_slice(&(input))).
collect::<Vec<_>>();
(tokenized_input, Tensor::stack(tokenized_input_tensors.as_slice(), 0).to(self.var_store.device()))
})
.map(|input| Tensor::of_slice(&(input)))
.collect::<Vec<_>>();
(
tokenized_input,
Tensor::stack(tokenized_input_tensors.as_slice(), 0).to(self.var_store.device()),
)
}
/// Classify tokens in a text sequence
@ -380,27 +484,33 @@ impl TokenClassificationModel {
/// let ner_model = TokenClassificationModel::new(Default::default())?;
/// let input = [
/// "My name is Amy. I live in Paris.",
/// "Paris is a city in France."
/// "Paris is a city in France.",
/// ];
/// let output = ner_model.predict(&input, true, true);
/// # Ok(())
/// # }
/// ```
pub fn predict(&self, input: &[&str], consolidate_sub_tokens: bool, return_special: bool) -> Vec<Token> {
pub fn predict(
&self,
input: &[&str],
consolidate_sub_tokens: bool,
return_special: bool,
) -> Vec<Token> {
let (tokenized_input, input_tensor) = self.prepare_for_model(input.to_vec());
let (output, _, _) = no_grad(|| {
self.token_sequence_classifier
.forward_t(Some(input_tensor.copy()),
self.token_sequence_classifier.forward_t(
Some(input_tensor.copy()),
None,
None,
None,
None,
false)
false,
)
});
let output = output.detach().to(Device::Cpu);
let score: Tensor = output.exp() / output.exp().sum1(&[-1], true, Float);
let labels_idx = &score.argmax(-1, true);
let mut tokens: Vec<Token> = vec!();
let mut tokens: Vec<Token> = vec![];
for sentence_idx in 0..labels_idx.size()[0] {
let labels = labels_idx.get(sentence_idx);
let sentence_tokens = &tokenized_input[sentence_idx as usize];
@ -415,7 +525,16 @@ impl TokenClassificationModel {
word_idx += 1;
}
let token = {
self.decode_token(&original_chars, sentence_tokens, &input_tensor, &labels, &score, sentence_idx, position_idx as i64, word_idx - 1)
self.decode_token(
&original_chars,
sentence_tokens,
&input_tensor,
&labels,
&score,
sentence_idx,
position_idx as i64,
word_idx - 1,
)
};
tokens.push(token);
}
@ -426,8 +545,17 @@ impl TokenClassificationModel {
tokens
}
fn decode_token(&self, original_sentence_chars: &Vec<char>, sentence_tokens: &TokenizedInput, input_tensor: &Tensor,
labels: &Tensor, score: &Tensor, sentence_idx: i64, position_idx: i64, word_index: u16) -> Token {
fn decode_token(
&self,
original_sentence_chars: &Vec<char>,
sentence_tokens: &TokenizedInput,
input_tensor: &Tensor,
labels: &Tensor,
score: &Tensor,
sentence_idx: i64,
position_idx: i64,
word_index: u16,
) -> Token {
let label_id = labels.int64_value(&[position_idx as i64]);
let token_id = input_tensor.int64_value(&[sentence_idx, position_idx as i64]);
@ -435,13 +563,19 @@ impl TokenClassificationModel {
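// When the tokenizer recorded no character offsets (special tokens), fall
// back to decoding the token id; otherwise slice the original sentence
// characters using the stored offset.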
let text = match offsets {
None => match self.tokenizer {
TokenizerOption::Bert(ref tokenizer) => Tokenizer::decode(tokenizer, vec!(token_id), false, false),
TokenizerOption::Roberta(ref tokenizer) => Tokenizer::decode(tokenizer, vec!(token_id), false, false),
TokenizerOption::Bert(ref tokenizer) => {
Tokenizer::decode(tokenizer, vec![token_id], false, false)
}
TokenizerOption::Roberta(ref tokenizer) => {
Tokenizer::decode(tokenizer, vec![token_id], false, false)
}
},
Some(offsets) => {
let (start_char, end_char) = (offsets.begin as usize, offsets.end as usize);
let end_char = min(end_char, original_sentence_chars.len());
let text = original_sentence_chars[start_char..end_char].iter().collect();
let text = original_sentence_chars[start_char..end_char]
.iter()
.collect();
text
}
};
@ -449,7 +583,11 @@ impl TokenClassificationModel {
Token {
text,
score: score.double_value(&[sentence_idx, position_idx, label_id]),
label: self.label_mapping.get(&label_id).expect("Index out of vocabulary bounds.").to_owned(),
label: self
.label_mapping
.get(&label_id)
.expect("Index out of vocabulary bounds.")
.to_owned(),
label_index: label_id,
sentence: sentence_idx as usize,
index: position_idx as u16,
@ -459,24 +597,29 @@ impl TokenClassificationModel {
}
}
fn consolidate_tokens(&self, tokens: &mut Vec<Token>, label_aggregation_function: &LabelAggregationOption) {
let mut tokens_to_replace = vec!();
fn consolidate_tokens(
&self,
tokens: &mut Vec<Token>,
label_aggregation_function: &LabelAggregationOption,
) {
let mut tokens_to_replace = vec![];
let mut token_iter = tokens.iter_consolidate_tokens();
let mut cursor = 0;
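// Walk over groups of sub-tokens belonging to the same original word;
// groups with more than one sub-token are merged into a single token whose
// label is chosen by the configured aggregation function.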
while let Some(sub_tokens) = token_iter.next() {
if sub_tokens.len() > 1 {
let (label_index, label) = self.consolidate_labels(sub_tokens, label_aggregation_function);
let (label_index, label) =
self.consolidate_labels(sub_tokens, label_aggregation_function);
let sentence = (&sub_tokens[0]).sentence;
let index = (&sub_tokens[0]).index;
let word_index = (&sub_tokens[0]).word_index;
let offset_start = match &sub_tokens.first().unwrap().offset {
Some(offset) => Some(offset.begin),
None => None
None => None,
};
let offset_end = match &sub_tokens.last().unwrap().offset {
Some(offset) => Some(offset.end),
None => None
None => None,
};
let offset = if offset_start.is_some() & offset_end.is_some() {
Some(Offset::new(offset_start.unwrap(), offset_end.unwrap()))
@ -513,7 +656,11 @@ impl TokenClassificationModel {
}
}
fn consolidate_labels(&self, tokens: &[Token], aggregation: &LabelAggregationOption) -> (i64, String) {
fn consolidate_labels(
&self,
tokens: &[Token],
aggregation: &LabelAggregationOption,
) -> (i64, String) {
match aggregation {
LabelAggregationOption::First => {
let token = tokens.first().unwrap();
@ -524,22 +671,17 @@ impl TokenClassificationModel {
(token.label_index, token.label.clone())
}
LabelAggregationOption::Mode => {
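// Count how often each (label_index, label) pair occurs across the
// sub-tokens and keep the most frequent one as the word-level label.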
let counts = tokens
.iter()
.fold(
HashMap::new(),
|mut m, c| {
let counts = tokens.iter().fold(HashMap::new(), |mut m, c| {
*m.entry((c.label_index, c.label.as_str())).or_insert(0) += 1;
m
},
);
});
counts
.into_iter()
.max_by(|a, b| a.1.cmp(&b.1))
.map(|((label_index, label), _)| (label_index, label.to_owned()))
.unwrap()
}
LabelAggregationOption::Custom(function) => function(tokens)
LabelAggregationOption::Custom(function) => function(tokens),
}
}
}
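The `Mode` branch above is a plain frequency count over sub-token labels. A self-contained sketch, with `(label_index, label)` tuples as hypothetical stand-ins for the real `Token` fields:

```rust
use std::collections::HashMap;

// Pick the most frequent (label_index, label) pair among sub-tokens; ties are
// resolved arbitrarily, as in the fold/max_by combination above.
fn mode_label(sub_tokens: &[(i64, &str)]) -> (i64, String) {
    let counts = sub_tokens
        .iter()
        .fold(HashMap::new(), |mut m, &(index, label)| {
            *m.entry((index, label)).or_insert(0) += 1;
            m
        });
    counts
        .into_iter()
        .max_by(|a, b| a.1.cmp(&b.1))
        .map(|((index, label), _)| (index, label.to_owned()))
        .unwrap()
}

fn main() {
    let sub_tokens = [(2, "I-LOC"), (2, "I-LOC"), (5, "I-ORG")];
    assert_eq!(mode_label(&sub_tokens), (2, String::from("I-LOC")));
}
```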


@ -11,7 +11,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
//! # Translation pipeline
//! Translation based on the Marian encoder-decoder architecture
//! Includes techniques such as beam search, top-k and nucleus sampling, temperature setting and repetition penalty.
@ -35,9 +34,10 @@
//! ```no_run
//! # fn main() -> failure::Fallible<()> {
//! # use rust_bert::pipelines::generation::LanguageGenerator;
//! use rust_bert::pipelines::translation::{TranslationModel, TranslationConfig, Language};
//! use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
//! use tch::Device;
//! let translation_config = TranslationConfig::new(Language::EnglishToFrench, Device::cuda_if_available());
//! let translation_config =
//! TranslationConfig::new(Language::EnglishToFrench, Device::cuda_if_available());
//! let mut model = TranslationModel::new(translation_config)?;
//!
//! let input = ["This is a sentence to be translated"];
@ -54,10 +54,13 @@
//! # ;
//! ```
use crate::pipelines::generation::{MarianGenerator, GenerateConfig, LanguageGenerator};
use crate::common::resources::{RemoteResource, Resource};
use crate::marian::{
MarianConfigResources, MarianModelResources, MarianPrefix, MarianSpmResources,
MarianVocabResources,
};
use crate::pipelines::generation::{GenerateConfig, LanguageGenerator, MarianGenerator};
use tch::Device;
use crate::common::resources::{Resource, RemoteResource};
use crate::marian::{MarianModelResources, MarianConfigResources, MarianVocabResources, MarianSpmResources, MarianPrefix};
/// Pretrained languages available for direct use
pub enum Language {
@ -84,47 +87,244 @@ pub enum Language {
struct RemoteTranslationResources;
impl RemoteTranslationResources {
pub const ENGLISH2FRENCH: ((&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), Option<&'static str>) =
(MarianModelResources::ENGLISH2ROMANCE, MarianConfigResources::ENGLISH2ROMANCE, MarianVocabResources::ENGLISH2ROMANCE, MarianSpmResources::ENGLISH2ROMANCE, MarianPrefix::ENGLISH2FRENCH);
pub const ENGLISH2CATALAN: ((&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), Option<&'static str>) =
(MarianModelResources::ENGLISH2ROMANCE, MarianConfigResources::ENGLISH2ROMANCE, MarianVocabResources::ENGLISH2ROMANCE, MarianSpmResources::ENGLISH2ROMANCE, MarianPrefix::ENGLISH2CATALAN);
pub const ENGLISH2SPANISH: ((&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), Option<&'static str>) =
(MarianModelResources::ENGLISH2ROMANCE, MarianConfigResources::ENGLISH2ROMANCE, MarianVocabResources::ENGLISH2ROMANCE, MarianSpmResources::ENGLISH2ROMANCE, MarianPrefix::ENGLISH2SPANISH);
pub const ENGLISH2PORTUGUESE: ((&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), Option<&'static str>) =
(MarianModelResources::ENGLISH2ROMANCE, MarianConfigResources::ENGLISH2ROMANCE, MarianVocabResources::ENGLISH2ROMANCE, MarianSpmResources::ENGLISH2ROMANCE, MarianPrefix::ENGLISH2PORTUGUESE);
pub const ENGLISH2ITALIAN: ((&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), Option<&'static str>) =
(MarianModelResources::ENGLISH2ROMANCE, MarianConfigResources::ENGLISH2ROMANCE, MarianVocabResources::ENGLISH2ROMANCE, MarianSpmResources::ENGLISH2ROMANCE, MarianPrefix::ENGLISH2ITALIAN);
pub const ENGLISH2ROMANIAN: ((&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), Option<&'static str>) =
(MarianModelResources::ENGLISH2ROMANCE, MarianConfigResources::ENGLISH2ROMANCE, MarianVocabResources::ENGLISH2ROMANCE, MarianSpmResources::ENGLISH2ROMANCE, MarianPrefix::ENGLISH2ROMANIAN);
pub const ENGLISH2GERMAN: ((&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), Option<&'static str>) =
(MarianModelResources::ENGLISH2GERMAN, MarianConfigResources::ENGLISH2GERMAN, MarianVocabResources::ENGLISH2GERMAN, MarianSpmResources::ENGLISH2GERMAN, MarianPrefix::ENGLISH2GERMAN);
pub const ENGLISH2RUSSIAN: ((&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), Option<&'static str>) =
(MarianModelResources::ENGLISH2RUSSIAN, MarianConfigResources::ENGLISH2RUSSIAN, MarianVocabResources::ENGLISH2RUSSIAN, MarianSpmResources::ENGLISH2RUSSIAN, MarianPrefix::ENGLISH2RUSSIAN);
pub const ENGLISH2FRENCH: (
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
) = (
MarianModelResources::ENGLISH2ROMANCE,
MarianConfigResources::ENGLISH2ROMANCE,
MarianVocabResources::ENGLISH2ROMANCE,
MarianSpmResources::ENGLISH2ROMANCE,
MarianPrefix::ENGLISH2FRENCH,
);
pub const ENGLISH2CATALAN: (
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
) = (
MarianModelResources::ENGLISH2ROMANCE,
MarianConfigResources::ENGLISH2ROMANCE,
MarianVocabResources::ENGLISH2ROMANCE,
MarianSpmResources::ENGLISH2ROMANCE,
MarianPrefix::ENGLISH2CATALAN,
);
pub const ENGLISH2SPANISH: (
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
) = (
MarianModelResources::ENGLISH2ROMANCE,
MarianConfigResources::ENGLISH2ROMANCE,
MarianVocabResources::ENGLISH2ROMANCE,
MarianSpmResources::ENGLISH2ROMANCE,
MarianPrefix::ENGLISH2SPANISH,
);
pub const ENGLISH2PORTUGUESE: (
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
) = (
MarianModelResources::ENGLISH2ROMANCE,
MarianConfigResources::ENGLISH2ROMANCE,
MarianVocabResources::ENGLISH2ROMANCE,
MarianSpmResources::ENGLISH2ROMANCE,
MarianPrefix::ENGLISH2PORTUGUESE,
);
pub const ENGLISH2ITALIAN: (
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
) = (
MarianModelResources::ENGLISH2ROMANCE,
MarianConfigResources::ENGLISH2ROMANCE,
MarianVocabResources::ENGLISH2ROMANCE,
MarianSpmResources::ENGLISH2ROMANCE,
MarianPrefix::ENGLISH2ITALIAN,
);
pub const ENGLISH2ROMANIAN: (
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
) = (
MarianModelResources::ENGLISH2ROMANCE,
MarianConfigResources::ENGLISH2ROMANCE,
MarianVocabResources::ENGLISH2ROMANCE,
MarianSpmResources::ENGLISH2ROMANCE,
MarianPrefix::ENGLISH2ROMANIAN,
);
pub const ENGLISH2GERMAN: (
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
) = (
MarianModelResources::ENGLISH2GERMAN,
MarianConfigResources::ENGLISH2GERMAN,
MarianVocabResources::ENGLISH2GERMAN,
MarianSpmResources::ENGLISH2GERMAN,
MarianPrefix::ENGLISH2GERMAN,
);
pub const ENGLISH2RUSSIAN: (
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
) = (
MarianModelResources::ENGLISH2RUSSIAN,
MarianConfigResources::ENGLISH2RUSSIAN,
MarianVocabResources::ENGLISH2RUSSIAN,
MarianSpmResources::ENGLISH2RUSSIAN,
MarianPrefix::ENGLISH2RUSSIAN,
);
pub const FRENCH2ENGLISH: ((&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), Option<&'static str>) =
(MarianModelResources::ROMANCE2ENGLISH, MarianConfigResources::ROMANCE2ENGLISH, MarianVocabResources::ROMANCE2ENGLISH, MarianSpmResources::ROMANCE2ENGLISH, MarianPrefix::FRENCH2ENGLISH);
pub const CATALAN2ENGLISH: ((&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), Option<&'static str>) =
(MarianModelResources::ROMANCE2ENGLISH, MarianConfigResources::ROMANCE2ENGLISH, MarianVocabResources::ROMANCE2ENGLISH, MarianSpmResources::ROMANCE2ENGLISH, MarianPrefix::CATALAN2ENGLISH);
pub const SPANISH2ENGLISH: ((&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), Option<&'static str>) =
(MarianModelResources::ROMANCE2ENGLISH, MarianConfigResources::ROMANCE2ENGLISH, MarianVocabResources::ROMANCE2ENGLISH, MarianSpmResources::ROMANCE2ENGLISH, MarianPrefix::SPANISH2ENGLISH);
pub const PORTUGUESE2ENGLISH: ((&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), Option<&'static str>) =
(MarianModelResources::ROMANCE2ENGLISH, MarianConfigResources::ROMANCE2ENGLISH, MarianVocabResources::ROMANCE2ENGLISH, MarianSpmResources::ROMANCE2ENGLISH, MarianPrefix::PORTUGUESE2ENGLISH);
pub const ITALIAN2ENGLISH: ((&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), Option<&'static str>) =
(MarianModelResources::ROMANCE2ENGLISH, MarianConfigResources::ROMANCE2ENGLISH, MarianVocabResources::ROMANCE2ENGLISH, MarianSpmResources::ROMANCE2ENGLISH, MarianPrefix::ITALIAN2ENGLISH);
pub const ROMANIAN2ENGLISH: ((&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), Option<&'static str>) =
(MarianModelResources::ROMANCE2ENGLISH, MarianConfigResources::ROMANCE2ENGLISH, MarianVocabResources::ROMANCE2ENGLISH, MarianSpmResources::ROMANCE2ENGLISH, MarianPrefix::ROMANIAN2ENGLISH);
pub const GERMAN2ENGLISH: ((&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), Option<&'static str>) =
(MarianModelResources::GERMAN2ENGLISH, MarianConfigResources::GERMAN2ENGLISH, MarianVocabResources::GERMAN2ENGLISH, MarianSpmResources::GERMAN2ENGLISH, MarianPrefix::GERMAN2ENGLISH);
pub const RUSSIAN2ENGLISH: ((&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), Option<&'static str>) =
(MarianModelResources::RUSSIAN2ENGLISH, MarianConfigResources::RUSSIAN2ENGLISH, MarianVocabResources::RUSSIAN2ENGLISH, MarianSpmResources::RUSSIAN2ENGLISH, MarianPrefix::RUSSIAN2ENGLISH);
pub const FRENCH2ENGLISH: (
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
) = (
MarianModelResources::ROMANCE2ENGLISH,
MarianConfigResources::ROMANCE2ENGLISH,
MarianVocabResources::ROMANCE2ENGLISH,
MarianSpmResources::ROMANCE2ENGLISH,
MarianPrefix::FRENCH2ENGLISH,
);
pub const CATALAN2ENGLISH: (
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
) = (
MarianModelResources::ROMANCE2ENGLISH,
MarianConfigResources::ROMANCE2ENGLISH,
MarianVocabResources::ROMANCE2ENGLISH,
MarianSpmResources::ROMANCE2ENGLISH,
MarianPrefix::CATALAN2ENGLISH,
);
pub const SPANISH2ENGLISH: (
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
) = (
MarianModelResources::ROMANCE2ENGLISH,
MarianConfigResources::ROMANCE2ENGLISH,
MarianVocabResources::ROMANCE2ENGLISH,
MarianSpmResources::ROMANCE2ENGLISH,
MarianPrefix::SPANISH2ENGLISH,
);
pub const PORTUGUESE2ENGLISH: (
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
) = (
MarianModelResources::ROMANCE2ENGLISH,
MarianConfigResources::ROMANCE2ENGLISH,
MarianVocabResources::ROMANCE2ENGLISH,
MarianSpmResources::ROMANCE2ENGLISH,
MarianPrefix::PORTUGUESE2ENGLISH,
);
pub const ITALIAN2ENGLISH: (
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
) = (
MarianModelResources::ROMANCE2ENGLISH,
MarianConfigResources::ROMANCE2ENGLISH,
MarianVocabResources::ROMANCE2ENGLISH,
MarianSpmResources::ROMANCE2ENGLISH,
MarianPrefix::ITALIAN2ENGLISH,
);
pub const ROMANIAN2ENGLISH: (
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
) = (
MarianModelResources::ROMANCE2ENGLISH,
MarianConfigResources::ROMANCE2ENGLISH,
MarianVocabResources::ROMANCE2ENGLISH,
MarianSpmResources::ROMANCE2ENGLISH,
MarianPrefix::ROMANIAN2ENGLISH,
);
pub const GERMAN2ENGLISH: (
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
) = (
MarianModelResources::GERMAN2ENGLISH,
MarianConfigResources::GERMAN2ENGLISH,
MarianVocabResources::GERMAN2ENGLISH,
MarianSpmResources::GERMAN2ENGLISH,
MarianPrefix::GERMAN2ENGLISH,
);
pub const RUSSIAN2ENGLISH: (
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
) = (
MarianModelResources::RUSSIAN2ENGLISH,
MarianConfigResources::RUSSIAN2ENGLISH,
MarianVocabResources::RUSSIAN2ENGLISH,
MarianSpmResources::RUSSIAN2ENGLISH,
MarianPrefix::RUSSIAN2ENGLISH,
);
pub const FRENCH2GERMAN: ((&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), Option<&'static str>) =
(MarianModelResources::FRENCH2GERMAN, MarianConfigResources::FRENCH2GERMAN, MarianVocabResources::FRENCH2GERMAN, MarianSpmResources::FRENCH2GERMAN, MarianPrefix::FRENCH2GERMAN);
pub const GERMAN2FRENCH: ((&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), Option<&'static str>) =
(MarianModelResources::GERMAN2FRENCH, MarianConfigResources::GERMAN2FRENCH, MarianVocabResources::GERMAN2FRENCH, MarianSpmResources::GERMAN2FRENCH, MarianPrefix::GERMAN2FRENCH);
pub const FRENCH2GERMAN: (
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
) = (
MarianModelResources::FRENCH2GERMAN,
MarianConfigResources::FRENCH2GERMAN,
MarianVocabResources::FRENCH2GERMAN,
MarianSpmResources::FRENCH2GERMAN,
MarianPrefix::FRENCH2GERMAN,
);
pub const GERMAN2FRENCH: (
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
) = (
MarianModelResources::GERMAN2FRENCH,
MarianConfigResources::GERMAN2FRENCH,
MarianVocabResources::GERMAN2FRENCH,
MarianSpmResources::GERMAN2FRENCH,
MarianPrefix::GERMAN2FRENCH,
);
}
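Each constant above bundles four `(cache_path, url)` resource pairs (model, config, vocab, SentencePiece model) with an optional target-language prefix. A sketch of the shape with a hypothetical stand-in tuple (paths and URLs below are placeholders, not real endpoints):

```rust
type ResourcePair = (&'static str, &'static str);

// Same shape as the constants above: model, config, vocab and SentencePiece
// resources, plus an optional Marian target-language prefix.
const EXAMPLE: (
    ResourcePair,
    ResourcePair,
    ResourcePair,
    ResourcePair,
    Option<&'static str>,
) = (
    ("marian-mt-en-fr/model.ot", "https://example.invalid/model.ot"),
    ("marian-mt-en-fr/config.json", "https://example.invalid/config.json"),
    ("marian-mt-en-fr/vocab.json", "https://example.invalid/vocab.json"),
    ("marian-mt-en-fr/spiece.model", "https://example.invalid/spiece.model"),
    Some(">>fr<<"),
);

fn main() {
    let (model, _config, _vocab, _spm, prefix) = EXAMPLE;
    println!("model cached at {:?}, prefix {:?}", model.0, prefix);
}
```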
/// # Configuration for text translation
/// Contains information regarding the model to load. Mirrors `GenerateConfig` with a
/// different set of default parameters, and sets the device to place the model on.
@ -179,16 +379,17 @@ impl TranslationConfig {
///
/// ```no_run
/// # fn main() -> failure::Fallible<()> {
/// use rust_bert::pipelines::translation::{TranslationConfig, Language};
/// use rust_bert::pipelines::translation::{Language, TranslationConfig};
/// use tch::Device;
///
/// let translation_config = TranslationConfig::new(Language::FrenchToEnglish, Device::cuda_if_available());
/// let translation_config =
/// TranslationConfig::new(Language::FrenchToEnglish, Device::cuda_if_available());
/// # Ok(())
/// # }
/// ```
///
pub fn new(language: Language, device: Device) -> TranslationConfig {
let (model_resource, config_resource, vocab_resource, merges_resource, prefix) = match language {
let (model_resource, config_resource, vocab_resource, merges_resource, prefix) =
match language {
Language::EnglishToFrench => RemoteTranslationResources::ENGLISH2FRENCH,
Language::EnglishToCatalan => RemoteTranslationResources::ENGLISH2CATALAN,
Language::EnglishToSpanish => RemoteTranslationResources::ENGLISH2SPANISH,
@ -216,7 +417,7 @@ impl TranslationConfig {
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(merges_resource));
let prefix = match prefix {
Some(value) => Some(value.to_string()),
None => None
None => None,
};
TranslationConfig {
model_resource,
@ -255,31 +456,42 @@ impl TranslationConfig {
/// ```no_run
/// # fn main() -> failure::Fallible<()> {
/// use rust_bert::pipelines::translation::TranslationConfig;
/// use tch::Device;
/// use rust_bert::resources::{Resource, LocalResource};
/// use rust_bert::resources::{LocalResource, Resource};
/// use std::path::PathBuf;
/// use tch::Device;
///
/// let config_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/config.json") });
/// let model_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/model.ot") });
/// let vocab_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/vocab.json") });
/// let sentence_piece_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/spiece.model") });
/// let config_resource = Resource::Local(LocalResource {
/// local_path: PathBuf::from("path/to/config.json"),
/// });
/// let model_resource = Resource::Local(LocalResource {
/// local_path: PathBuf::from("path/to/model.ot"),
/// });
/// let vocab_resource = Resource::Local(LocalResource {
/// local_path: PathBuf::from("path/to/vocab.json"),
/// });
/// let sentence_piece_resource = Resource::Local(LocalResource {
/// local_path: PathBuf::from("path/to/spiece.model"),
/// });
///
/// let translation_config = TranslationConfig::new_from_resources(model_resource,
/// let translation_config = TranslationConfig::new_from_resources(
/// model_resource,
/// config_resource,
/// vocab_resource,
/// sentence_piece_resource,
/// Some(">>fr<<".to_string()),
/// Device::cuda_if_available());
/// Device::cuda_if_available(),
/// );
/// # Ok(())
/// # }
/// ```
///
pub fn new_from_resources(model_resource: Resource,
pub fn new_from_resources(
model_resource: Resource,
config_resource: Resource,
vocab_resource: Resource,
sentence_piece_resource: Resource,
prefix: Option<String>,
device: Device) -> TranslationConfig {
device: Device,
) -> TranslationConfig {
TranslationConfig {
model_resource,
config_resource,
@ -320,17 +532,16 @@ impl TranslationModel {
///
/// ```no_run
/// # fn main() -> failure::Fallible<()> {
/// use rust_bert::pipelines::translation::{TranslationModel, TranslationConfig, Language};
/// use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
/// use tch::Device;
///
/// let translation_config = TranslationConfig::new(Language::FrenchToEnglish, Device::cuda_if_available());
/// let translation_config =
/// TranslationConfig::new(Language::FrenchToEnglish, Device::cuda_if_available());
/// let mut translation_model = TranslationModel::new(translation_config)?;
/// # Ok(())
/// # }
/// ```
///
pub fn new(translation_config: TranslationConfig)
-> failure::Fallible<TranslationModel> {
pub fn new(translation_config: TranslationConfig) -> failure::Fallible<TranslationModel> {
let generate_config = GenerateConfig {
model_resource: translation_config.model_resource,
config_resource: translation_config.config_resource,
@ -353,7 +564,10 @@ impl TranslationModel {
let model = MarianGenerator::new(generate_config)?;
Ok(TranslationModel { model, prefix: translation_config.prefix })
Ok(TranslationModel {
model,
prefix: translation_config.prefix,
})
}
/// Translates the texts provided
@ -370,10 +584,11 @@ impl TranslationModel {
/// ```no_run
/// # fn main() -> failure::Fallible<()> {
/// use rust_bert::pipelines::generation::LanguageGenerator;
/// use rust_bert::pipelines::translation::{TranslationModel, TranslationConfig, Language};
/// use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
/// use tch::Device;
///
/// let translation_config = TranslationConfig::new(Language::EnglishToFrench, Device::cuda_if_available());
/// let translation_config =
/// TranslationConfig::new(Language::EnglishToFrench, Device::cuda_if_available());
/// let model = TranslationModel::new(translation_config)?;
///
/// let input = ["This is a sentence to be translated"];
@ -382,17 +597,17 @@ impl TranslationModel {
/// # Ok(())
/// # }
/// ```
///
pub fn translate(&self, texts: &[&str]) -> Vec<String> {
match &self.prefix {
Some(value) => {
let texts: Vec<String> = texts
.into_iter()
.map(|&v| { format!("{} {}", value, v) })
.map(|&v| format!("{} {}", value, v))
.collect();
self.model.generate(Some(texts.iter().map(AsRef::as_ref).collect()), None)
self.model
.generate(Some(texts.iter().map(AsRef::as_ref).collect()), None)
}
None => self.model.generate(Some(texts.to_vec()), None)
None => self.model.generate(Some(texts.to_vec()), None),
}
}
}
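The prefix handling in `translate` above, in isolation: multi-target Marian checkpoints select the output language by prepending a marker such as `>>fr<<` to every source sentence before generation. A std-only sketch:

```rust
// Prepend the target-language marker when one is configured; otherwise pass
// the inputs through unchanged.
fn apply_prefix(prefix: Option<&str>, texts: &[&str]) -> Vec<String> {
    match prefix {
        Some(value) => texts.iter().map(|v| format!("{} {}", value, v)).collect(),
        None => texts.iter().map(|v| v.to_string()).collect(),
    }
}

fn main() {
    let out = apply_prefix(Some(">>fr<<"), &["This is a sentence to be translated"]);
    assert_eq!(out[0], ">>fr<< This is a sentence to be translated");
}
```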


@ -11,10 +11,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use tch::{nn, Tensor, Kind};
use crate::common::dropout::Dropout;
use tch::nn::{EmbeddingConfig, embedding};
use crate::bert::{BertConfig, BertEmbedding};
use crate::common::dropout::Dropout;
use tch::nn::{embedding, EmbeddingConfig};
use tch::{nn, Kind, Tensor};
#[derive(Debug)]
/// # BertEmbeddings implementation for RoBERTa model
@ -36,8 +36,12 @@ impl RobertaEmbeddings {
fn create_position_ids_from_embeddings(&self, x: &Tensor) -> Tensor {
let input_shape = x.size();
let input_shape = vec!(input_shape[0], input_shape[1]);
let position_ids: Tensor = Tensor::arange1(self.padding_index + 1, input_shape[0], (Kind::Int64, x.device()));
let input_shape = vec![input_shape[0], input_shape[1]];
let position_ids: Tensor = Tensor::arange1(
self.padding_index + 1,
input_shape[0],
(Kind::Int64, x.device()),
);
position_ids.unsqueeze(0).expand(&input_shape, true)
}
}
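On the `padding_index + 1` above: RoBERTa inherits fairseq's convention that the pad token owns the lowest position indices, so content positions start one past the padding index. A std-only sketch of the intended numbering (an assumption about intent; this diff only reformats the call):

```rust
// Positions for a sequence of length `seq_len` start at padding_index + 1.
fn position_ids(seq_len: i64, padding_index: i64) -> Vec<i64> {
    (padding_index + 1..=padding_index + seq_len).collect()
}

fn main() {
    // With RoBERTa's padding_index = 1, a 4-token input is numbered [2, 3, 4, 5].
    assert_eq!(position_ids(4, 1), vec![2, 3, 4, 5]);
}
```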
@ -54,10 +58,10 @@ impl BertEmbedding for RobertaEmbeddings {
///
/// ```no_run
/// use rust_bert::bert::{BertConfig, BertEmbedding};
/// use tch::{nn, Device};
/// use rust_bert::roberta::RobertaEmbeddings;
/// use rust_bert::Config;
/// use std::path::Path;
/// use rust_bert::roberta::RobertaEmbeddings;
/// use tch::{nn, Device};
///
/// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu;
@ -65,29 +69,48 @@ impl BertEmbedding for RobertaEmbeddings {
/// let config = BertConfig::from_file(config_path);
/// let robert_embeddings = RobertaEmbeddings::new(&(&p.root() / "bert_embeddings"), &config);
/// ```
///
fn new(p: &nn::Path, config: &BertConfig) -> RobertaEmbeddings {
let embedding_config = EmbeddingConfig { padding_idx: 1, ..Default::default() };
let embedding_config = EmbeddingConfig {
padding_idx: 1,
..Default::default()
};
let word_embeddings: nn::Embedding = embedding(p / "word_embeddings",
let word_embeddings: nn::Embedding = embedding(
p / "word_embeddings",
config.vocab_size,
config.hidden_size,
embedding_config);
embedding_config,
);
let position_embeddings: nn::Embedding = embedding(p / "position_embeddings",
let position_embeddings: nn::Embedding = embedding(
p / "position_embeddings",
config.max_position_embeddings,
config.hidden_size,
Default::default());
Default::default(),
);
let token_type_embeddings: nn::Embedding = embedding(p / "token_type_embeddings",
let token_type_embeddings: nn::Embedding = embedding(
p / "token_type_embeddings",
config.type_vocab_size,
config.hidden_size,
Default::default());
Default::default(),
);
let layer_norm_config = nn::LayerNormConfig { eps: 1e-12, ..Default::default() };
let layer_norm: nn::LayerNorm = nn::layer_norm(p / "LayerNorm", vec![config.hidden_size], layer_norm_config);
let layer_norm_config = nn::LayerNormConfig {
eps: 1e-12,
..Default::default()
};
let layer_norm: nn::LayerNorm =
nn::layer_norm(p / "LayerNorm", vec![config.hidden_size], layer_norm_config);
let dropout: Dropout = Dropout::new(config.hidden_dropout_prob);
RobertaEmbeddings { word_embeddings, position_embeddings, token_type_embeddings, layer_norm, dropout, padding_index: 1 }
RobertaEmbeddings {
word_embeddings,
position_embeddings,
token_type_embeddings,
layer_norm,
dropout,
padding_index: 1,
}
}
/// Forward pass through the embedding layer.
@ -123,53 +146,67 @@ impl BertEmbedding for RobertaEmbeddings {
/// let (batch_size, sequence_length) = (64, 128);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true);
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
/// .expand(&[batch_size, sequence_length], true);
///
/// let embedded_output = no_grad(|| {
/// roberta_embeddings
/// .forward_t(Some(input_tensor),
/// .forward_t(
/// Some(input_tensor),
/// Some(token_type_ids),
/// Some(position_ids),
/// None,
/// false).unwrap()
/// false,
/// )
/// .unwrap()
/// });
/// ```
///
fn forward_t(&self,
fn forward_t(
&self,
input_ids: Option<Tensor>,
token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool) -> Result<Tensor, &'static str> {
train: bool,
) -> Result<Tensor, &'static str> {
let (input_embeddings, input_shape) = match &input_ids {
Some(input_value) => match &input_embeds {
Some(_) => { return Err("Only one of input ids or input embeddings may be set"); }
None => (input_value.apply_t(&self.word_embeddings, train), input_value.size())
Some(_) => {
return Err("Only one of input ids or input embeddings may be set");
}
None => (
input_value.apply_t(&self.word_embeddings, train),
input_value.size(),
),
},
None => match &input_embeds {
Some(embeds) => (embeds.copy(), vec!(embeds.size()[0], embeds.size()[1])),
None => { return Err("Only one of input ids or input embeddings may be set"); }
Some(embeds) => (embeds.copy(), vec![embeds.size()[0], embeds.size()[1]]),
None => {
return Err("Only one of input ids or input embeddings may be set");
}
},
};
let position_ids = match position_ids {
Some(value) => value,
None => match input_ids {
Some(value) => self.create_position_ids_from_input_ids(&value),
None => self.create_position_ids_from_embeddings(&input_embeds.unwrap())
}
None => self.create_position_ids_from_embeddings(&input_embeds.unwrap()),
},
};
let token_type_ids = match token_type_ids {
Some(value) => value,
None => Tensor::zeros(&input_shape, (Kind::Int64, input_embeddings.device()))
None => Tensor::zeros(&input_shape, (Kind::Int64, input_embeddings.device())),
};
let position_embeddings = position_ids.apply(&self.position_embeddings);
let token_type_embeddings = token_type_ids.apply(&self.token_type_embeddings);
let input_embeddings: Tensor = input_embeddings + position_embeddings + token_type_embeddings;
Ok(input_embeddings.apply(&self.layer_norm).apply_t(&self.dropout, train))
let input_embeddings: Tensor =
input_embeddings + position_embeddings + token_type_embeddings;
Ok(input_embeddings
.apply(&self.layer_norm)
.apply_t(&self.dropout, train))
}
}


@ -25,14 +25,22 @@
//! use tch::{nn, Device};
//! # use std::path::PathBuf;
//! use rust_bert::bert::BertConfig;
//! use rust_bert::Config;
//! use rust_bert::resources::{download_resource, LocalResource, Resource};
//! use rust_bert::roberta::RobertaForMaskedLM;
//! use rust_bert::resources::{Resource, download_resource, LocalResource};
//! use rust_bert::Config;
//!
//! let config_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/config.json")});
//! let vocab_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/vocab.txt")});
//! let merges_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/vocab.txt")});
//! let weights_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/model.ot")});
//! let config_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/config.json"),
//! });
//! let vocab_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/vocab.txt"),
//! });
//! let merges_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/vocab.txt"),
//! });
//! let weights_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/model.ot"),
//! });
//! let config_path = download_resource(&config_resource)?;
//! let vocab_path = download_resource(&vocab_resource)?;
//! let merges_path = download_resource(&merges_resource)?;
@ -40,7 +48,11 @@
//!
//! let device = Device::cuda_if_available();
//! let mut vs = nn::VarStore::new(device);
//! let tokenizer: RobertaTokenizer = RobertaTokenizer::from_file(vocab_path.to_str().unwrap(), merges_path.to_str().unwrap(), true);
//! let tokenizer: RobertaTokenizer = RobertaTokenizer::from_file(
//! vocab_path.to_str().unwrap(),
//! merges_path.to_str().unwrap(),
//! true,
//! );
//! let config = BertConfig::from_file(config_path);
//! let bert_model = RobertaForMaskedLM::new(&vs.root(), &config);
//! vs.load(weights_path)?;
@ -49,10 +61,12 @@
//! # }
//! ```
mod embeddings;
mod roberta;
pub use roberta::{RobertaModelResources, RobertaConfigResources, RobertaVocabResources, RobertaMergesResources,
RobertaForMaskedLM, RobertaForMultipleChoice, RobertaForTokenClassification, RobertaForQuestionAnswering, RobertaForSequenceClassification};
pub use embeddings::RobertaEmbeddings;
pub use roberta::{
RobertaConfigResources, RobertaForMaskedLM, RobertaForMultipleChoice,
RobertaForQuestionAnswering, RobertaForSequenceClassification, RobertaForTokenClassification,
RobertaMergesResources, RobertaModelResources, RobertaVocabResources,
};


@ -11,13 +11,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use tch::{nn, Tensor};
use crate::common::linear::{linear_no_bias, LinearNoBias};
use tch::nn::Init;
use crate::common::activations::_gelu;
use crate::roberta::embeddings::RobertaEmbeddings;
use crate::common::dropout::Dropout;
use crate::bert::{BertConfig, BertModel};
use crate::common::activations::_gelu;
use crate::common::dropout::Dropout;
use crate::common::linear::{linear_no_bias, LinearNoBias};
use crate::roberta::embeddings::RobertaEmbeddings;
use tch::nn::Init;
use tch::{nn, Tensor};
/// # RoBERTa Pretrained model weight files
pub struct RobertaModelResources;
@ -33,22 +33,34 @@ pub struct RobertaMergesResources;
impl RobertaModelResources {
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
pub const ROBERTA: (&'static str, &'static str) = ("roberta/model.ot", "https://cdn.huggingface.co/roberta-base-rust_model.ot");
pub const ROBERTA: (&'static str, &'static str) = (
"roberta/model.ot",
"https://cdn.huggingface.co/roberta-base-rust_model.ot",
);
}
impl RobertaConfigResources {
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
pub const ROBERTA: (&'static str, &'static str) = ("roberta/config.json", "https://cdn.huggingface.co/roberta-base-config.json");
pub const ROBERTA: (&'static str, &'static str) = (
"roberta/config.json",
"https://cdn.huggingface.co/roberta-base-config.json",
);
}
impl RobertaVocabResources {
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
pub const ROBERTA: (&'static str, &'static str) = ("roberta/vocab.txt", "https://cdn.huggingface.co/roberta-base-vocab.json");
pub const ROBERTA: (&'static str, &'static str) = (
"roberta/vocab.txt",
"https://cdn.huggingface.co/roberta-base-vocab.json",
);
}
impl RobertaMergesResources {
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
pub const ROBERTA: (&'static str, &'static str) = ("roberta/merges.txt", "https://cdn.huggingface.co/roberta-base-merges.txt");
pub const ROBERTA: (&'static str, &'static str) = (
"roberta/merges.txt",
"https://cdn.huggingface.co/roberta-base-merges.txt",
);
}
pub struct RobertaLMHead {
@ -60,17 +72,42 @@ pub struct RobertaLMHead {
impl RobertaLMHead {
pub fn new(p: &nn::Path, config: &BertConfig) -> RobertaLMHead {
let dense = nn::linear(p / "dense", config.hidden_size, config.hidden_size, Default::default());
let layer_norm_config = nn::LayerNormConfig { eps: 1e-12, ..Default::default() };
let layer_norm = nn::layer_norm(p / "layer_norm", vec![config.hidden_size], layer_norm_config);
let decoder = linear_no_bias(&(p / "decoder"), config.hidden_size, config.vocab_size, Default::default());
let dense = nn::linear(
p / "dense",
config.hidden_size,
config.hidden_size,
Default::default(),
);
let layer_norm_config = nn::LayerNormConfig {
eps: 1e-12,
..Default::default()
};
let layer_norm = nn::layer_norm(
p / "layer_norm",
vec![config.hidden_size],
layer_norm_config,
);
let decoder = linear_no_bias(
&(p / "decoder"),
config.hidden_size,
config.vocab_size,
Default::default(),
);
let bias = p.var("bias", &[config.vocab_size], Init::KaimingUniform);
RobertaLMHead { dense, decoder, layer_norm, bias }
RobertaLMHead {
dense,
decoder,
layer_norm,
bias,
}
}
pub fn forward(&self, hidden_states: &Tensor) -> Tensor {
(_gelu(&hidden_states.apply(&self.dense))).apply(&self.layer_norm).apply(&self.decoder) + &self.bias
(_gelu(&hidden_states.apply(&self.dense)))
.apply(&self.layer_norm)
.apply(&self.decoder)
+ &self.bias
}
}
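`_gelu` above is the Gaussian Error Linear Unit used by BERT-family LM heads; the exact form is x * 0.5 * (1 + erf(x / sqrt(2))). The std library has no erf, so this sketch uses the common tanh approximation (an assumption; the diff does not show which variant rust-bert implements):

```rust
// Tanh approximation of GELU: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))).
fn gelu_approx(x: f64) -> f64 {
    let c = (2.0 / std::f64::consts::PI).sqrt();
    0.5 * x * (1.0 + (c * (x + 0.044715 * x.powi(3))).tanh())
}

fn main() {
    // GELU gates large negative inputs toward zero and passes large positive ones through.
    assert!(gelu_approx(-3.0).abs() < 0.01);
    assert!((gelu_approx(3.0) - 3.0).abs() < 0.01);
}
```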
@ -96,10 +133,10 @@ impl RobertaForMaskedLM {
///
/// ```no_run
/// use rust_bert::bert::BertConfig;
/// use tch::{nn, Device};
/// use rust_bert::roberta::RobertaForMaskedLM;
/// use rust_bert::Config;
/// use std::path::Path;
/// use rust_bert::roberta::RobertaForMaskedLM;
/// use tch::{nn, Device};
///
/// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu;
@ -107,7 +144,6 @@ impl RobertaForMaskedLM {
/// let config = BertConfig::from_file(config_path);
/// let roberta = RobertaForMaskedLM::new(&(&p.root() / "roberta"), &config);
/// ```
///
pub fn new(p: &nn::Path, config: &BertConfig) -> RobertaForMaskedLM {
let roberta = BertModel::<RobertaEmbeddings>::new(&(p / "roberta"), config);
let lm_head = RobertaLMHead::new(&(p / "lm_head"), config);
@ -153,23 +189,24 @@ impl RobertaForMaskedLM {
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true);
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
/// .expand(&[batch_size, sequence_length], true);
///
/// let (output, all_hidden_states, all_attentions) = no_grad(|| {
/// roberta_model
/// .forward_t(Some(input_tensor),
/// roberta_model.forward_t(
/// Some(input_tensor),
/// Some(mask),
/// Some(token_type_ids),
/// Some(position_ids),
/// None,
/// &None,
/// &None,
/// false)
/// false,
/// )
/// });
///
/// ```
///
pub fn forward_t(&self,
pub fn forward_t(
&self,
input_ids: Option<Tensor>,
mask: Option<Tensor>,
token_type_ids: Option<Tensor>,
@ -177,9 +214,21 @@ impl RobertaForMaskedLM {
input_embeds: Option<Tensor>,
encoder_hidden_states: &Option<Tensor>,
encoder_mask: &Option<Tensor>,
train: bool) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
let (hidden_state, _, all_hidden_states, all_attentions) = self.roberta.forward_t(input_ids, mask, token_type_ids, position_ids,
input_embeds, encoder_hidden_states, encoder_mask, train).unwrap();
train: bool,
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
let (hidden_state, _, all_hidden_states, all_attentions) = self
.roberta
.forward_t(
input_ids,
mask,
token_type_ids,
position_ids,
input_embeds,
encoder_hidden_states,
encoder_mask,
train,
)
.unwrap();
let prediction_scores = self.lm_head.forward(&hidden_state);
(prediction_scores, all_hidden_states, all_attentions)
@ -194,12 +243,30 @@ pub struct RobertaClassificationHead {
impl RobertaClassificationHead {
pub fn new(p: &nn::Path, config: &BertConfig) -> RobertaClassificationHead {
let dense = nn::linear(p / "dense", config.hidden_size, config.hidden_size, Default::default());
let num_labels = config.id2label.as_ref().expect("num_labels not provided in configuration").len() as i64;
let out_proj = nn::linear(p / "out_proj", config.hidden_size, num_labels, Default::default());
let dense = nn::linear(
p / "dense",
config.hidden_size,
config.hidden_size,
Default::default(),
);
let num_labels = config
.id2label
.as_ref()
.expect("num_labels not provided in configuration")
.len() as i64;
let out_proj = nn::linear(
p / "out_proj",
config.hidden_size,
num_labels,
Default::default(),
);
let dropout = Dropout::new(config.hidden_dropout_prob);
RobertaClassificationHead { dense, dropout, out_proj }
RobertaClassificationHead {
dense,
dropout,
out_proj,
}
}
pub fn forward_t(&self, hidden_states: &Tensor, train: bool) -> Tensor {
@ -235,10 +302,10 @@ impl RobertaForSequenceClassification {
///
/// ```no_run
/// use rust_bert::bert::BertConfig;
/// use tch::{nn, Device};
/// use rust_bert::roberta::RobertaForSequenceClassification;
/// use rust_bert::Config;
/// use std::path::Path;
/// use rust_bert::roberta::RobertaForSequenceClassification;
/// use tch::{nn, Device};
///
/// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu;
@ -246,12 +313,14 @@ impl RobertaForSequenceClassification {
/// let config = BertConfig::from_file(config_path);
/// let roberta = RobertaForSequenceClassification::new(&(&p.root() / "roberta"), &config);
/// ```
///
pub fn new(p: &nn::Path, config: &BertConfig) -> RobertaForSequenceClassification {
let roberta = BertModel::<RobertaEmbeddings>::new(&(p / "roberta"), config);
let classifier = RobertaClassificationHead::new(&(p / "classifier"), config);
RobertaForSequenceClassification { roberta, classifier }
RobertaForSequenceClassification {
roberta,
classifier,
}
}
/// Forward pass through the model
@ -290,29 +359,42 @@ impl RobertaForSequenceClassification {
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true);
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
/// .expand(&[batch_size, sequence_length], true);
///
/// let (labels, all_hidden_states, all_attentions) = no_grad(|| {
/// roberta_model
/// .forward_t(Some(input_tensor),
/// roberta_model.forward_t(
/// Some(input_tensor),
/// Some(mask),
/// Some(token_type_ids),
/// Some(position_ids),
/// None,
/// false)
/// false,
/// )
/// });
///
/// ```
///
pub fn forward_t(&self,
pub fn forward_t(
&self,
input_ids: Option<Tensor>,
mask: Option<Tensor>,
token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
let (hidden_state, _, all_hidden_states, all_attentions) = self.roberta.forward_t(input_ids, mask, token_type_ids, position_ids,
input_embeds, &None, &None, train).unwrap();
train: bool,
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
let (hidden_state, _, all_hidden_states, all_attentions) = self
.roberta
.forward_t(
input_ids,
mask,
token_type_ids,
position_ids,
input_embeds,
&None,
&None,
train,
)
.unwrap();
let output = self.classifier.forward_t(&hidden_state, train);
(output, all_hidden_states, all_attentions)
@ -344,10 +426,10 @@ impl RobertaForMultipleChoice {
///
/// ```no_run
/// use rust_bert::bert::BertConfig;
/// use tch::{nn, Device};
/// use rust_bert::roberta::RobertaForMultipleChoice;
/// use rust_bert::Config;
/// use std::path::Path;
/// use rust_bert::roberta::RobertaForMultipleChoice;
/// use tch::{nn, Device};
///
/// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu;
@ -355,13 +437,16 @@ impl RobertaForMultipleChoice {
/// let config = BertConfig::from_file(config_path);
/// let roberta = RobertaForMultipleChoice::new(&(&p.root() / "roberta"), &config);
/// ```
///
pub fn new(p: &nn::Path, config: &BertConfig) -> RobertaForMultipleChoice {
let roberta = BertModel::<RobertaEmbeddings>::new(&(p / "roberta"), config);
let dropout = Dropout::new(config.hidden_dropout_prob);
let classifier = nn::linear(p / "classifier", config.hidden_size, 1, Default::default());
RobertaForMultipleChoice { roberta, dropout, classifier }
RobertaForMultipleChoice {
roberta,
dropout,
classifier,
}
}
/// Forward pass through the model
@ -399,45 +484,61 @@ impl RobertaForMultipleChoice {
/// let input_tensor = Tensor::rand(&[num_choices, sequence_length], (Int64, device));
/// let mask = Tensor::zeros(&[num_choices, sequence_length], (Int64, device));
/// let token_type_ids = Tensor::zeros(&[num_choices, sequence_length], (Int64, device));
/// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[num_choices, sequence_length], true);
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
/// .expand(&[num_choices, sequence_length], true);
///
/// let (choices, all_hidden_states, all_attentions) = no_grad(|| {
/// roberta_model
/// .forward_t(input_tensor,
/// roberta_model.forward_t(
/// input_tensor,
/// Some(mask),
/// Some(token_type_ids),
/// Some(position_ids),
/// false)
/// false,
/// )
/// });
///
/// ```
///
pub fn forward_t(&self,
pub fn forward_t(
&self,
input_ids: Tensor,
mask: Option<Tensor>,
token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>,
train: bool) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
train: bool,
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
let num_choices = input_ids.size()[1];
let flat_input_ids = Some(input_ids.view((-1i64, *input_ids.size().last().unwrap())));
let flat_position_ids = match position_ids {
Some(value) => Some(value.view((-1i64, *value.size().last().unwrap()))),
None => None
None => None,
};
let flat_token_type_ids = match token_type_ids {
Some(value) => Some(value.view((-1i64, *value.size().last().unwrap()))),
None => None
None => None,
};
let flat_mask = match mask {
Some(value) => Some(value.view((-1i64, *value.size().last().unwrap()))),
None => None
None => None,
};
let (_, pooled_output, all_hidden_states, all_attentions) = self.roberta.forward_t(flat_input_ids, flat_mask, flat_token_type_ids, flat_position_ids,
None, &None, &None, train).unwrap();
let (_, pooled_output, all_hidden_states, all_attentions) = self
.roberta
.forward_t(
flat_input_ids,
flat_mask,
flat_token_type_ids,
flat_position_ids,
None,
&None,
&None,
train,
)
.unwrap();
let output = pooled_output.apply_t(&self.dropout, train).apply(&self.classifier).view((-1, num_choices));
let output = pooled_output
.apply_t(&self.dropout, train)
.apply(&self.classifier)
.view((-1, num_choices));
(output, all_hidden_states, all_attentions)
}
}
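The reshaping in the multiple-choice forward pass above, in isolation: inputs of shape (batch, num_choices, seq_len) are flattened to (batch * num_choices, seq_len) for the encoder, and the single logit per flattened row is viewed back as (batch, num_choices). A sketch assuming the `tch` crate:

```rust
use tch::{Device, Kind, Tensor};

fn main() {
    let (batch, num_choices, seq_len) = (2, 4, 8);
    let input_ids = Tensor::zeros(&[batch, num_choices, seq_len], (Kind::Int64, Device::Cpu));
    // Flatten choices into the batch dimension for the encoder.
    let flat_input_ids = input_ids.view((-1i64, *input_ids.size().last().unwrap()));
    assert_eq!(flat_input_ids.size(), &[batch * num_choices, seq_len]);
    // The classifier yields one logit per flattened row; regroup per example.
    let logits = Tensor::zeros(&[batch * num_choices, 1], (Kind::Float, Device::Cpu));
    assert_eq!(logits.view((-1, num_choices)).size(), &[batch, num_choices]);
}
```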
@ -466,10 +567,10 @@ impl RobertaForTokenClassification {
///
/// ```no_run
/// use rust_bert::bert::BertConfig;
/// use tch::{nn, Device};
/// use rust_bert::roberta::RobertaForTokenClassification;
/// use rust_bert::Config;
/// use std::path::Path;
/// use rust_bert::roberta::RobertaForTokenClassification;
/// use tch::{nn, Device};
///
/// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu;
@ -477,14 +578,26 @@ impl RobertaForTokenClassification {
/// let config = BertConfig::from_file(config_path);
/// let roberta = RobertaForTokenClassification::new(&(&p.root() / "roberta"), &config);
/// ```
///
pub fn new(p: &nn::Path, config: &BertConfig) -> RobertaForTokenClassification {
let roberta = BertModel::<RobertaEmbeddings>::new(&(p / "roberta"), config);
let dropout = Dropout::new(config.hidden_dropout_prob);
let num_labels = config.id2label.as_ref().expect("num_labels not provided in configuration").len() as i64;
let classifier = nn::linear(p / "classifier", config.hidden_size, num_labels, Default::default());
let num_labels = config
.id2label
.as_ref()
.expect("num_labels not provided in configuration")
.len() as i64;
let classifier = nn::linear(
p / "classifier",
config.hidden_size,
num_labels,
Default::default(),
);
RobertaForTokenClassification { roberta, dropout, classifier }
RobertaForTokenClassification {
roberta,
dropout,
classifier,
}
}
/// Forward pass through the model
@ -523,31 +636,46 @@ impl RobertaForTokenClassification {
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true);
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
/// .expand(&[batch_size, sequence_length], true);
///
/// let (token_labels, all_hidden_states, all_attentions) = no_grad(|| {
/// roberta_model
/// .forward_t(Some(input_tensor),
/// roberta_model.forward_t(
/// Some(input_tensor),
/// Some(mask),
/// Some(token_type_ids),
/// Some(position_ids),
/// None,
/// false)
/// false,
/// )
/// });
///
/// ```
///
pub fn forward_t(&self,
pub fn forward_t(
&self,
input_ids: Option<Tensor>,
mask: Option<Tensor>,
token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
let (hidden_state, _, all_hidden_states, all_attentions) = self.roberta.forward_t(input_ids, mask, token_type_ids, position_ids,
input_embeds, &None, &None, train).unwrap();
train: bool,
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
let (hidden_state, _, all_hidden_states, all_attentions) = self
.roberta
.forward_t(
input_ids,
mask,
token_type_ids,
position_ids,
input_embeds,
&None,
&None,
train,
)
.unwrap();
let sequence_output = hidden_state.apply_t(&self.dropout, train).apply(&self.classifier);
let sequence_output = hidden_state
.apply_t(&self.dropout, train)
.apply(&self.classifier);
(sequence_output, all_hidden_states, all_attentions)
}
}
@ -576,10 +704,10 @@ impl RobertaForQuestionAnswering {
///
/// ```no_run
/// use rust_bert::bert::BertConfig;
/// use tch::{nn, Device};
/// use rust_bert::roberta::RobertaForQuestionAnswering;
/// use rust_bert::Config;
/// use std::path::Path;
/// use rust_bert::roberta::RobertaForQuestionAnswering;
/// use tch::{nn, Device};
///
/// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu;
@ -587,13 +715,20 @@ impl RobertaForQuestionAnswering {
/// let config = BertConfig::from_file(config_path);
/// let roberta = RobertaForQuestionAnswering::new(&(&p.root() / "roberta"), &config);
/// ```
///
pub fn new(p: &nn::Path, config: &BertConfig) -> RobertaForQuestionAnswering {
let roberta = BertModel::<RobertaEmbeddings>::new(&(p / "roberta"), config);
let num_labels = 2;
let qa_outputs = nn::linear(p / "qa_outputs", config.hidden_size, num_labels, Default::default());
let qa_outputs = nn::linear(
p / "qa_outputs",
config.hidden_size,
num_labels,
Default::default(),
);
RobertaForQuestionAnswering { roberta, qa_outputs }
RobertaForQuestionAnswering {
roberta,
qa_outputs,
}
}
/// Forward pass through the model
@ -633,29 +768,42 @@ impl RobertaForQuestionAnswering {
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true);
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
/// .expand(&[batch_size, sequence_length], true);
///
/// let (start_scores, end_scores, all_hidden_states, all_attentions) = no_grad(|| {
/// roberta_model
/// .forward_t(Some(input_tensor),
/// roberta_model.forward_t(
/// Some(input_tensor),
/// Some(mask),
/// Some(token_type_ids),
/// Some(position_ids),
/// None,
/// false)
/// false,
/// )
/// });
///
/// ```
///
pub fn forward_t(&self,
pub fn forward_t(
&self,
input_ids: Option<Tensor>,
mask: Option<Tensor>,
token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool) -> (Tensor, Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
let (hidden_state, _, all_hidden_states, all_attentions) = self.roberta.forward_t(input_ids, mask, token_type_ids, position_ids,
input_embeds, &None, &None, train).unwrap();
train: bool,
) -> (Tensor, Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
let (hidden_state, _, all_hidden_states, all_attentions) = self
.roberta
.forward_t(
input_ids,
mask,
token_type_ids,
position_ids,
input_embeds,
&None,
&None,
train,
)
.unwrap();
let sequence_output = hidden_state.apply(&self.qa_outputs);
let logits = sequence_output.split(1, -1);
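What the `split` above produces: `qa_outputs` maps every token to two logits, and `split(1, -1)` cuts that last dimension into a start-score tensor and an end-score tensor, each of shape (batch, seq_len, 1). A sketch assuming the `tch` crate:

```rust
use tch::{Device, Kind, Tensor};

fn main() {
    // Stand-in for `hidden_state.apply(&self.qa_outputs)`: (batch, seq_len, 2).
    let sequence_output = Tensor::zeros(&[2, 128, 2], (Kind::Float, Device::Cpu));
    let logits = sequence_output.split(1, -1);
    let (start_logits, end_logits) = (&logits[0], &logits[1]);
    assert_eq!(start_logits.size(), &[2, 128, 1]);
    assert_eq!(end_logits.size(), &[2, 128, 1]);
}
```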


@ -1,20 +1,29 @@
extern crate failure;
extern crate dirs;
extern crate failure;
use tch::{Device, nn, Tensor, no_grad};
use rust_tokenizers::{TruncationStrategy, Tokenizer, Vocab, AlbertTokenizer};
use rust_bert::albert::{
AlbertConfig, AlbertConfigResources, AlbertForMaskedLM, AlbertForMultipleChoice,
AlbertForQuestionAnswering, AlbertForSequenceClassification, AlbertForTokenClassification,
AlbertModelResources, AlbertVocabResources,
};
use rust_bert::resources::{download_resource, RemoteResource, Resource};
use rust_bert::Config;
use rust_bert::resources::{Resource, RemoteResource, download_resource};
use rust_bert::albert::{AlbertConfigResources, AlbertVocabResources, AlbertModelResources, AlbertConfig, AlbertForMaskedLM, AlbertForSequenceClassification, AlbertForMultipleChoice, AlbertForTokenClassification, AlbertForQuestionAnswering};
use rust_tokenizers::{AlbertTokenizer, Tokenizer, TruncationStrategy, Vocab};
use std::collections::HashMap;
use tch::{nn, no_grad, Device, Tensor};
#[test]
fn albert_masked_lm() -> failure::Fallible<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(AlbertConfigResources::ALBERT_BASE_V2));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(AlbertVocabResources::ALBERT_BASE_V2));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(AlbertModelResources::ALBERT_BASE_V2));
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
AlbertConfigResources::ALBERT_BASE_V2,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
AlbertVocabResources::ALBERT_BASE_V2,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
AlbertModelResources::ALBERT_BASE_V2,
));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let weights_path = download_resource(&weights_resource)?;
@ -22,37 +31,38 @@ fn albert_masked_lm() -> failure::Fallible<()> {
// Set-up masked LM model
let device = Device::Cpu;
let mut vs = nn::VarStore::new(device);
let tokenizer: AlbertTokenizer = AlbertTokenizer::from_file(vocab_path.to_str().unwrap(), true, false);
let tokenizer: AlbertTokenizer =
AlbertTokenizer::from_file(vocab_path.to_str().unwrap(), true, false);
let config = AlbertConfig::from_file(config_path);
let albert_model = AlbertForMaskedLM::new(&vs.root(), &config);
vs.load(weights_path)?;
// Define input
let input = ["Looks like one [MASK] is missing", "It\'s like comparing [MASK] to apples"];
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
let tokenized_input = tokenized_input.
iter().
map(|input| input.token_ids.clone()).
map(|mut input| {
let input = [
"Looks like one [MASK] is missing",
"It\'s like comparing [MASK] to apples",
];
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let tokenized_input = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
}).
map(|input|
Tensor::of_slice(&(input))).
collect::<Vec<_>>();
})
.map(|input| Tensor::of_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let (output, _, _) = no_grad(|| {
albert_model
.forward_t(Some(input_tensor),
None,
None,
None,
None,
false)
});
let (output, _, _) =
no_grad(|| albert_model.forward_t(Some(input_tensor), None, None, None, None, false));
// Print masked tokens
let index_1 = output.get(0).get(4).argmax(0, false);
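The pad-and-stack chain used throughout these tests, in isolation: every tokenized sequence is right-padded with id 0 to the longest length in the batch, turned into a 1-D tensor, and stacked into a (batch, max_len) input. A sketch assuming the `tch` crate (padding id 0 mirrors the tests above):

```rust
use tch::Tensor;

fn pad_and_stack(token_ids: Vec<Vec<i64>>) -> Tensor {
    let max_len = token_ids.iter().map(|ids| ids.len()).max().unwrap();
    let padded = token_ids
        .into_iter()
        .map(|mut ids| {
            // Right-pad with 0 so every row has the same length.
            ids.extend(vec![0; max_len - ids.len()]);
            Tensor::of_slice(&ids)
        })
        .collect::<Vec<_>>();
    Tensor::stack(padded.as_slice(), 0)
}

fn main() {
    let batch = pad_and_stack(vec![vec![2, 13, 9, 3], vec![2, 7, 3]]);
    assert_eq!(batch.size(), &[2, 4]);
}
```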
@ -69,15 +79,20 @@ fn albert_masked_lm() -> failure::Fallible<()> {
#[test]
fn albert_for_sequence_classification() -> failure::Fallible<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(AlbertConfigResources::ALBERT_BASE_V2));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(AlbertVocabResources::ALBERT_BASE_V2));
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
AlbertConfigResources::ALBERT_BASE_V2,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
AlbertVocabResources::ALBERT_BASE_V2,
));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
// Set-up model
let device = Device::Cpu;
let vs = nn::VarStore::new(device);
let tokenizer: AlbertTokenizer = AlbertTokenizer::from_file(vocab_path.to_str().unwrap(), true, false);
let tokenizer: AlbertTokenizer =
AlbertTokenizer::from_file(vocab_path.to_str().unwrap(), true, false);
let mut config = AlbertConfig::from_file(config_path);
let mut dummy_label_mapping = HashMap::new();
dummy_label_mapping.insert(0, String::from("Positive"));
@ -88,37 +103,42 @@ fn albert_for_sequence_classification() -> failure::Fallible<()> {
config.output_hidden_states = Some(true);
let albert_model = AlbertForSequenceClassification::new(&vs.root(), &config);
// Define input
let input = ["Looks like one thing is missing", "It\'s like comparing oranges to apples"];
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
let tokenized_input = tokenized_input.
iter().
map(|input| input.token_ids.clone()).
map(|mut input| {
let input = [
"Looks like one thing is missing",
"It\'s like comparing oranges to apples",
];
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let tokenized_input = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
}).
map(|input|
Tensor::of_slice(&(input))).
collect::<Vec<_>>();
})
.map(|input| Tensor::of_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let (output, all_hidden_states, all_attentions) = no_grad(|| {
albert_model
.forward_t(Some(input_tensor),
None,
None,
None,
None,
false)
});
let (output, all_hidden_states, all_attentions) =
no_grad(|| albert_model.forward_t(Some(input_tensor), None, None, None, None, false));
assert_eq!(output.size(), &[2, 3]);
assert_eq!(config.num_hidden_layers as usize, all_hidden_states.unwrap().len());
assert_eq!(config.num_hidden_layers as usize, all_attentions.unwrap().len());
assert_eq!(
config.num_hidden_layers as usize,
all_hidden_states.unwrap().len()
);
assert_eq!(
config.num_hidden_layers as usize,
all_attentions.unwrap().len()
);
Ok(())
}
@ -126,50 +146,66 @@ fn albert_for_sequence_classification() -> failure::Fallible<()> {
#[test]
fn albert_for_multiple_choice() -> failure::Fallible<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(AlbertConfigResources::ALBERT_BASE_V2));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(AlbertVocabResources::ALBERT_BASE_V2));
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
AlbertConfigResources::ALBERT_BASE_V2,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
AlbertVocabResources::ALBERT_BASE_V2,
));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
// Set-up model
let device = Device::Cpu;
let vs = nn::VarStore::new(device);
let tokenizer: AlbertTokenizer = AlbertTokenizer::from_file(vocab_path.to_str().unwrap(), true, false);
let tokenizer: AlbertTokenizer =
AlbertTokenizer::from_file(vocab_path.to_str().unwrap(), true, false);
let mut config = AlbertConfig::from_file(config_path);
config.output_attentions = Some(true);
config.output_hidden_states = Some(true);
let albert_model = AlbertForMultipleChoice::new(&vs.root(), &config);
// Define input
let input = ["Looks like one thing is missing", "It\'s like comparing oranges to apples"];
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
let tokenized_input = tokenized_input.
iter().
map(|input| input.token_ids.clone()).
map(|mut input| {
let input = [
"Looks like one thing is missing",
"It\'s like comparing oranges to apples",
];
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let tokenized_input = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
}).
map(|input|
Tensor::of_slice(&(input))).
collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device).unsqueeze(0);
})
.map(|input| Tensor::of_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0)
.to(device)
.unsqueeze(0);
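// The stacked pair is [2, max_len]; unsqueeze(0) prepends a batch dimension so
// the multiple-choice head sees [1, 2, max_len], i.e. one example whose two
// sentences act as its two answer choices.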
// Forward pass
let (output, all_hidden_states, all_attentions) = no_grad(|| {
albert_model
.forward_t(Some(input_tensor),
None,
None,
None,
None,
false).unwrap()
.forward_t(Some(input_tensor), None, None, None, None, false)
.unwrap()
});
assert_eq!(output.size(), &[1, 2]);
assert_eq!(config.num_hidden_layers as usize, all_hidden_states.unwrap().len());
assert_eq!(config.num_hidden_layers as usize, all_attentions.unwrap().len());
assert_eq!(
config.num_hidden_layers as usize,
all_hidden_states.unwrap().len()
);
assert_eq!(
config.num_hidden_layers as usize,
all_attentions.unwrap().len()
);
Ok(())
}
@ -177,15 +213,20 @@ fn albert_for_multiple_choice() -> failure::Fallible<()> {
#[test]
fn albert_for_token_classification() -> failure::Fallible<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(AlbertConfigResources::ALBERT_BASE_V2));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(AlbertVocabResources::ALBERT_BASE_V2));
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
AlbertConfigResources::ALBERT_BASE_V2,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
AlbertVocabResources::ALBERT_BASE_V2,
));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
// Set-up model
let device = Device::Cpu;
let vs = nn::VarStore::new(device);
let tokenizer: AlbertTokenizer = AlbertTokenizer::from_file(vocab_path.to_str().unwrap(), true, false);
let tokenizer: AlbertTokenizer =
AlbertTokenizer::from_file(vocab_path.to_str().unwrap(), true, false);
let mut config = AlbertConfig::from_file(config_path);
let mut dummy_label_mapping = HashMap::new();
dummy_label_mapping.insert(0, String::from("O"));
@ -197,37 +238,42 @@ fn albert_for_token_classification() -> failure::Fallible<()> {
config.output_hidden_states = Some(true);
let bert_model = AlbertForTokenClassification::new(&vs.root(), &config);
// Define input
let input = ["Looks like one thing is missing", "It\'s like comparing oranges to apples"];
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
let tokenized_input = tokenized_input.
iter().
map(|input| input.token_ids.clone()).
map(|mut input| {
let input = [
"Looks like one thing is missing",
"It\'s like comparing oranges to apples",
];
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let tokenized_input = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
}).
map(|input|
Tensor::of_slice(&(input))).
collect::<Vec<_>>();
})
.map(|input| Tensor::of_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let (output, all_hidden_states, all_attentions) = no_grad(|| {
bert_model
.forward_t(Some(input_tensor),
None,
None,
None,
None,
false)
});
let (output, all_hidden_states, all_attentions) =
no_grad(|| bert_model.forward_t(Some(input_tensor), None, None, None, None, false));
assert_eq!(output.size(), &[2, 12, 4]);
assert_eq!(config.num_hidden_layers as usize, all_hidden_states.unwrap().len());
assert_eq!(config.num_hidden_layers as usize, all_attentions.unwrap().len());
assert_eq!(
config.num_hidden_layers as usize,
all_hidden_states.unwrap().len()
);
assert_eq!(
config.num_hidden_layers as usize,
all_attentions.unwrap().len()
);
Ok(())
}
@ -235,51 +281,62 @@ fn albert_for_token_classification() -> failure::Fallible<()> {
#[test]
fn albert_for_question_answering() -> failure::Fallible<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(AlbertConfigResources::ALBERT_BASE_V2));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(AlbertVocabResources::ALBERT_BASE_V2));
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
AlbertConfigResources::ALBERT_BASE_V2,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
AlbertVocabResources::ALBERT_BASE_V2,
));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
// Set-up model
let device = Device::Cpu;
let vs = nn::VarStore::new(device);
let tokenizer: AlbertTokenizer = AlbertTokenizer::from_file(vocab_path.to_str().unwrap(), true, false);
let tokenizer: AlbertTokenizer =
AlbertTokenizer::from_file(vocab_path.to_str().unwrap(), true, false);
let mut config = AlbertConfig::from_file(config_path);
config.output_attentions = Some(true);
config.output_hidden_states = Some(true);
let albert_model = AlbertForQuestionAnswering::new(&vs.root(), &config);
// Define input
let input = ["Looks like one thing is missing", "It\'s like comparing oranges to apples"];
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
let tokenized_input = tokenized_input.
iter().
map(|input| input.token_ids.clone()).
map(|mut input| {
let input = [
"Looks like one thing is missing",
"It\'s like comparing oranges to apples",
];
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let tokenized_input = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
}).
map(|input|
Tensor::of_slice(&(input))).
collect::<Vec<_>>();
})
.map(|input| Tensor::of_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let (start_scores, end_scores, all_hidden_states, all_attentions) = no_grad(|| {
albert_model
.forward_t(Some(input_tensor),
None,
None,
None,
None,
false)
});
let (start_scores, end_scores, all_hidden_states, all_attentions) =
no_grad(|| albert_model.forward_t(Some(input_tensor), None, None, None, None, false));
assert_eq!(start_scores.size(), &[2, 12]);
assert_eq!(end_scores.size(), &[2, 12]);
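// start_scores and end_scores are per-token span logits of shape
// [batch, sequence_length]; a greedy answer span for row i can be read off as
// start_scores.get(i).argmax(0, false) paired with the matching end argmax
// (a sketch that ignores the usual start <= end constraint).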
assert_eq!(config.num_hidden_layers as usize, all_hidden_states.unwrap().len());
assert_eq!(config.num_hidden_layers as usize, all_attentions.unwrap().len());
assert_eq!(
config.num_hidden_layers as usize,
all_hidden_states.unwrap().len()
);
assert_eq!(
config.num_hidden_layers as usize,
all_attentions.unwrap().len()
);
Ok(())
}

View File

@ -1,18 +1,25 @@
use tch::{Device, nn, Tensor};
use rust_tokenizers::{TruncationStrategy, Tokenizer, RobertaTokenizer};
use rust_bert::Config;
use rust_bert::bart::{BartConfig, BartConfigResources, BartVocabResources, BartModelResources, BartMergesResources, BartModel};
use rust_bert::bart::{
BartConfig, BartConfigResources, BartMergesResources, BartModel, BartModelResources,
BartVocabResources,
};
use rust_bert::pipelines::summarization::{SummarizationConfig, SummarizationModel};
use rust_bert::resources::{Resource, RemoteResource, download_resource};
use rust_bert::resources::{download_resource, RemoteResource, Resource};
use rust_bert::Config;
use rust_tokenizers::{RobertaTokenizer, Tokenizer, TruncationStrategy};
use tch::{nn, Device, Tensor};
#[test]
#[cfg_attr(not(feature = "all-tests"), ignore)]
fn bart_lm_model() -> failure::Fallible<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(BartConfigResources::BART));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(BartVocabResources::BART));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(BartMergesResources::BART));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(BartModelResources::BART));
let config_resource =
Resource::Remote(RemoteResource::from_pretrained(BartConfigResources::BART));
let vocab_resource =
Resource::Remote(RemoteResource::from_pretrained(BartVocabResources::BART));
let merges_resource =
Resource::Remote(RemoteResource::from_pretrained(BartMergesResources::BART));
let weights_resource =
Resource::Remote(RemoteResource::from_pretrained(BartModelResources::BART));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let merges_path = download_resource(&merges_resource)?;
@ -21,36 +28,38 @@ fn bart_lm_model() -> failure::Fallible<()> {
// Set-up BART model
let device = Device::Cpu;
let mut vs = nn::VarStore::new(device);
let tokenizer: RobertaTokenizer = RobertaTokenizer::from_file(vocab_path.to_str().unwrap(), merges_path.to_str().unwrap(), false);
let tokenizer: RobertaTokenizer = RobertaTokenizer::from_file(
vocab_path.to_str().unwrap(),
merges_path.to_str().unwrap(),
false,
);
let config = BartConfig::from_file(config_path);
let bart_model = BartModel::new(&vs.root(), &config, false);
vs.load(weights_path)?;
// Define input
let input = ["One two three four"];
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
let tokenized_input = tokenized_input.
iter().
map(|input| input.token_ids.clone()).
map(|mut input| {
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let tokenized_input = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
}).
map(|input|
Tensor::of_slice(&(input))).
collect::<Vec<_>>();
})
.map(|input| Tensor::of_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let (output, encoder_outputs, _, _, _, _, _) = bart_model.forward_t(
Some(&input_tensor),
None,
None,
None,
None,
None,
false);
let (output, encoder_outputs, _, _, _, _, _) =
bart_model.forward_t(Some(&input_tensor), None, None, None, None, None, false);
assert_eq!(output.size(), vec!(1, 6, 1024));
assert_eq!(encoder_outputs.size(), vec!(1, 6, 1024));
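// Both the decoder output and the encoder output are hidden states of shape
// [batch, sequence_length, hidden_size]; 1024 matches the d_model of the
// bart-large configuration used by these resources.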
@ -58,11 +67,9 @@ fn bart_lm_model() -> failure::Fallible<()> {
Ok(())
}
#[test]
#[cfg_attr(not(feature = "all-tests"), ignore)]
fn bart_summarization_greedy() -> failure::Fallible<()> {
// Set-up summarization model
let summarization_config = SummarizationConfig {
num_beams: 1,
@ -107,7 +114,6 @@ about exoplanets like K2-18b."];
#[test]
#[cfg_attr(not(feature = "all-tests"), ignore)]
fn bart_summarization_beam_search() -> failure::Fallible<()> {
// Set-up summarization model
let summarization_config = SummarizationConfig {
num_beams: 3,

View File

@ -1,22 +1,27 @@
extern crate failure;
extern crate dirs;
extern crate failure;
use tch::{Device, nn, Tensor, no_grad};
use rust_tokenizers::{BertTokenizer, TruncationStrategy, Tokenizer, Vocab};
use rust_bert::Config;
use rust_bert::bert::{BertConfig, BertForMaskedLM, BertForSequenceClassification, BertForMultipleChoice, BertForTokenClassification, BertForQuestionAnswering,
BertConfigResources, BertVocabResources, BertModelResources};
use rust_bert::bert::{
BertConfig, BertConfigResources, BertForMaskedLM, BertForMultipleChoice,
BertForQuestionAnswering, BertForSequenceClassification, BertForTokenClassification,
BertModelResources, BertVocabResources,
};
use rust_bert::pipelines::ner::NERModel;
use rust_bert::resources::{Resource, RemoteResource, download_resource};
use rust_bert::resources::{download_resource, RemoteResource, Resource};
use rust_bert::Config;
use rust_tokenizers::{BertTokenizer, Tokenizer, TruncationStrategy, Vocab};
use std::collections::HashMap;
use tch::{nn, no_grad, Device, Tensor};
#[test]
fn bert_masked_lm() -> failure::Fallible<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT));
let config_resource =
Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT));
let vocab_resource =
Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
let weights_resource =
Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let weights_path = download_resource(&weights_resource)?;
@ -30,39 +35,47 @@ fn bert_masked_lm() -> failure::Fallible<()> {
vs.load(weights_path)?;
// Define input
let input = ["Looks like one thing is missing", "It\'s like comparing oranges to apples"];
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
let mut tokenized_input = tokenized_input.
iter().
map(|input| input.token_ids.clone()).
map(|mut input| {
let input = [
"Looks like one thing is missing",
"It\'s like comparing oranges to apples",
];
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let mut tokenized_input = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
}).
collect::<Vec<_>>();
})
.collect::<Vec<_>>();
// Masking the token [thing] of sentence 1 and [oranges] of sentence 2
tokenized_input[0][4] = 103;
tokenized_input[1][6] = 103;
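// (103 is the id of the [MASK] token in the bert-base-uncased vocabulary)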
let tokenized_input = tokenized_input.
iter().
map(|input|
Tensor::of_slice(&(input))).
collect::<Vec<_>>();
let tokenized_input = tokenized_input
.iter()
.map(|input| Tensor::of_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let (output, _, _) = no_grad(|| {
bert_model
.forward_t(Some(input_tensor),
bert_model.forward_t(
Some(input_tensor),
None,
None,
None,
None,
&None,
&None,
false)
false,
)
});
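// Decoding the predictions is an argmax over the vocabulary dimension at each
// masked position, mapped back to a token string, e.g. (sketch mirroring the
// ELECTRA test further below):
// let index = output.get(0).get(4).argmax(0, false);
// let word = tokenizer.vocab().id_to_token(&index.int64_value(&[]));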
// Print masked tokens
@ -80,8 +93,10 @@ fn bert_masked_lm() -> failure::Fallible<()> {
#[test]
fn bert_for_sequence_classification() -> failure::Fallible<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
let config_resource =
Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT));
let vocab_resource =
Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
@ -99,37 +114,42 @@ fn bert_for_sequence_classification() -> failure::Fallible<()> {
config.output_hidden_states = Some(true);
let bert_model = BertForSequenceClassification::new(&vs.root(), &config);
// Define input
let input = ["Looks like one thing is missing", "It\'s like comparing oranges to apples"];
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
let tokenized_input = tokenized_input.
iter().
map(|input| input.token_ids.clone()).
map(|mut input| {
let input = [
"Looks like one thing is missing",
"It\'s like comparing oranges to apples",
];
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let tokenized_input = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
}).
map(|input|
Tensor::of_slice(&(input))).
collect::<Vec<_>>();
})
.map(|input| Tensor::of_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let (output, all_hidden_states, all_attentions) = no_grad(|| {
bert_model
.forward_t(Some(input_tensor),
None,
None,
None,
None,
false)
});
let (output, all_hidden_states, all_attentions) =
no_grad(|| bert_model.forward_t(Some(input_tensor), None, None, None, None, false));
assert_eq!(output.size(), &[2, 3]);
assert_eq!(config.num_hidden_layers as usize, all_hidden_states.unwrap().len());
assert_eq!(config.num_hidden_layers as usize, all_attentions.unwrap().len());
assert_eq!(
config.num_hidden_layers as usize,
all_hidden_states.unwrap().len()
);
assert_eq!(
config.num_hidden_layers as usize,
all_attentions.unwrap().len()
);
Ok(())
}
@ -137,8 +157,10 @@ fn bert_for_sequence_classification() -> failure::Fallible<()> {
#[test]
fn bert_for_multiple_choice() -> failure::Fallible<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
let config_resource =
Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT));
let vocab_resource =
Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
@ -152,34 +174,43 @@ fn bert_for_multiple_choice() -> failure::Fallible<()> {
let bert_model = BertForMultipleChoice::new(&vs.root(), &config);
// Define input
let input = ["Looks like one thing is missing", "It\'s like comparing oranges to apples"];
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
let tokenized_input = tokenized_input.
iter().
map(|input| input.token_ids.clone()).
map(|mut input| {
let input = [
"Looks like one thing is missing",
"It\'s like comparing oranges to apples",
];
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let tokenized_input = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
}).
map(|input|
Tensor::of_slice(&(input))).
collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device).unsqueeze(0);
})
.map(|input| Tensor::of_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0)
.to(device)
.unsqueeze(0);
// Forward pass
let (output, all_hidden_states, all_attentions) = no_grad(|| {
bert_model
.forward_t(input_tensor,
None,
None,
None,
false)
});
let (output, all_hidden_states, all_attentions) =
no_grad(|| bert_model.forward_t(input_tensor, None, None, None, false));
assert_eq!(output.size(), &[1, 2]);
assert_eq!(config.num_hidden_layers as usize, all_hidden_states.unwrap().len());
assert_eq!(config.num_hidden_layers as usize, all_attentions.unwrap().len());
assert_eq!(
config.num_hidden_layers as usize,
all_hidden_states.unwrap().len()
);
assert_eq!(
config.num_hidden_layers as usize,
all_attentions.unwrap().len()
);
Ok(())
}
@ -187,8 +218,10 @@ fn bert_for_multiple_choice() -> failure::Fallible<()> {
#[test]
fn bert_for_token_classification() -> failure::Fallible<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
let config_resource =
Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT));
let vocab_resource =
Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
@ -207,37 +240,42 @@ fn bert_for_token_classification() -> failure::Fallible<()> {
config.output_hidden_states = Some(true);
let bert_model = BertForTokenClassification::new(&vs.root(), &config);
// Define input
let input = ["Looks like one thing is missing", "It\'s like comparing oranges to apples"];
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
let tokenized_input = tokenized_input.
iter().
map(|input| input.token_ids.clone()).
map(|mut input| {
let input = [
"Looks like one thing is missing",
"It\'s like comparing oranges to apples",
];
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let tokenized_input = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
}).
map(|input|
Tensor::of_slice(&(input))).
collect::<Vec<_>>();
})
.map(|input| Tensor::of_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let (output, all_hidden_states, all_attentions) = no_grad(|| {
bert_model
.forward_t(Some(input_tensor),
None,
None,
None,
None,
false)
});
let (output, all_hidden_states, all_attentions) =
no_grad(|| bert_model.forward_t(Some(input_tensor), None, None, None, None, false));
assert_eq!(output.size(), &[2, 11, 4]);
assert_eq!(config.num_hidden_layers as usize, all_hidden_states.unwrap().len());
assert_eq!(config.num_hidden_layers as usize, all_attentions.unwrap().len());
assert_eq!(
config.num_hidden_layers as usize,
all_hidden_states.unwrap().len()
);
assert_eq!(
config.num_hidden_layers as usize,
all_attentions.unwrap().len()
);
Ok(())
}
@ -245,8 +283,10 @@ fn bert_for_token_classification() -> failure::Fallible<()> {
#[test]
fn bert_for_question_answering() -> failure::Fallible<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
let config_resource =
Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT));
let vocab_resource =
Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
@ -260,36 +300,42 @@ fn bert_for_question_answering() -> failure::Fallible<()> {
let bert_model = BertForQuestionAnswering::new(&vs.root(), &config);
// Define input
let input = ["Looks like one thing is missing", "It\'s like comparing oranges to apples"];
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
let tokenized_input = tokenized_input.
iter().
map(|input| input.token_ids.clone()).
map(|mut input| {
let input = [
"Looks like one thing is missing",
"It\'s like comparing oranges to apples",
];
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let tokenized_input = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
}).
map(|input|
Tensor::of_slice(&(input))).
collect::<Vec<_>>();
})
.map(|input| Tensor::of_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let (start_scores, end_scores, all_hidden_states, all_attentions) = no_grad(|| {
bert_model
.forward_t(Some(input_tensor),
None,
None,
None,
None,
false)
});
let (start_scores, end_scores, all_hidden_states, all_attentions) =
no_grad(|| bert_model.forward_t(Some(input_tensor), None, None, None, None, false));
assert_eq!(start_scores.size(), &[2, 11]);
assert_eq!(end_scores.size(), &[2, 11]);
assert_eq!(config.num_hidden_layers as usize, all_hidden_states.unwrap().len());
assert_eq!(config.num_hidden_layers as usize, all_attentions.unwrap().len());
assert_eq!(
config.num_hidden_layers as usize,
all_hidden_states.unwrap().len()
);
assert_eq!(
config.num_hidden_layers as usize,
all_attentions.unwrap().len()
);
Ok(())
}
@ -302,7 +348,7 @@ fn bert_pre_trained_ner() -> failure::Fallible<()> {
// Define input
let input = [
"My name is Amy. I live in Paris.",
"Paris is a city in France."
"Paris is a city in France.",
];
// Run model

View File

@ -1,13 +1,17 @@
use tch::{Device, Tensor, nn, no_grad};
use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::{Tokenizer, TruncationStrategy};
use rust_tokenizers::bert_tokenizer::BertTokenizer;
use rust_tokenizers::preprocessing::vocab::base_vocab::Vocab;
use rust_bert::Config;
use rust_bert::distilbert::{DistilBertConfig, DistilBertModelMaskedLM, DistilBertForQuestionAnswering, DistilBertForTokenClassification, DistilBertModelResources, DistilBertConfigResources, DistilBertVocabResources};
use rust_bert::distilbert::{
DistilBertConfig, DistilBertConfigResources, DistilBertForQuestionAnswering,
DistilBertForTokenClassification, DistilBertModelMaskedLM, DistilBertModelResources,
DistilBertVocabResources,
};
use rust_bert::pipelines::question_answering::{QaInput, QuestionAnsweringModel};
use rust_bert::pipelines::sentiment::{SentimentModel, SentimentPolarity};
use rust_bert::pipelines::question_answering::{QuestionAnsweringModel, QaInput};
use rust_bert::resources::{Resource, RemoteResource, download_resource};
use rust_bert::resources::{download_resource, RemoteResource, Resource};
use rust_bert::Config;
use rust_tokenizers::bert_tokenizer::BertTokenizer;
use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::{Tokenizer, TruncationStrategy};
use rust_tokenizers::preprocessing::vocab::base_vocab::Vocab;
use std::collections::HashMap;
use tch::{nn, no_grad, Device, Tensor};
extern crate failure;
@ -36,13 +40,18 @@ fn distilbert_sentiment_classifier() -> failure::Fallible<()> {
Ok(())
}
#[test]
fn distilbert_masked_lm() -> failure::Fallible<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertConfigResources::DISTIL_BERT));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertVocabResources::DISTIL_BERT));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertModelResources::DISTIL_BERT));
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
DistilBertConfigResources::DISTIL_BERT,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
DistilBertVocabResources::DISTIL_BERT,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
DistilBertModelResources::DISTIL_BERT,
));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let weights_path = download_resource(&weights_resource)?;
@ -56,29 +65,35 @@ fn distilbert_masked_lm() -> failure::Fallible<()> {
vs.load(weights_path)?;
// Define input
let input = ["Looks like one thing is missing", "It\'s like comparing oranges to apples"];
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
let mut tokenized_input = tokenized_input.
iter().
map(|input| input.token_ids.clone()).
map(|mut input| {
let input = [
"Looks like one thing is missing",
"It\'s like comparing oranges to apples",
];
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let mut tokenized_input = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
}).
collect::<Vec<_>>();
})
.collect::<Vec<_>>();
// Masking the token [thing] of sentence 1 and [oranges] of sentence 2
tokenized_input[0][4] = 103;
tokenized_input[1][6] = 103;
let tokenized_input = tokenized_input.
iter().
map(|input|
Tensor::of_slice(&(input))).
collect::<Vec<_>>();
let tokenized_input = tokenized_input
.iter()
.map(|input| Tensor::of_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let (output, _, _) = no_grad(|| {
distil_bert_model
@ -100,10 +115,13 @@ fn distilbert_masked_lm() -> failure::Fallible<()> {
#[test]
fn distilbert_for_question_answering() -> failure::Fallible<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertConfigResources::DISTIL_BERT_SQUAD));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertVocabResources::DISTIL_BERT_SQUAD));
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
DistilBertConfigResources::DISTIL_BERT_SQUAD,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
DistilBertVocabResources::DISTIL_BERT_SQUAD,
));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
@ -117,19 +135,26 @@ fn distilbert_for_question_answering() -> failure::Fallible<()> {
let distil_bert_model = DistilBertForQuestionAnswering::new(&vs.root(), &config);
// Define input
let input = ["Looks like one thing is missing", "It\'s like comparing oranges to apples"];
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
let tokenized_input = tokenized_input.
iter().
map(|input| input.token_ids.clone()).
map(|mut input| {
let input = [
"Looks like one thing is missing",
"It\'s like comparing oranges to apples",
];
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let tokenized_input = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
}).
map(|input|
Tensor::of_slice(&(input))).
collect::<Vec<_>>();
})
.map(|input| Tensor::of_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
@ -149,10 +174,13 @@ fn distilbert_for_question_answering() -> failure::Fallible<()> {
#[test]
fn distilbert_for_token_classification() -> failure::Fallible<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertConfigResources::DISTIL_BERT));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertVocabResources::DISTIL_BERT));
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
DistilBertConfigResources::DISTIL_BERT,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
DistilBertVocabResources::DISTIL_BERT,
));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
@ -172,19 +200,26 @@ fn distilbert_for_token_classification() -> failure::Fallible<()> {
let distil_bert_model = DistilBertForTokenClassification::new(&vs.root(), &config);
// Define input
let input = ["Looks like one thing is missing", "It\'s like comparing oranges to apples"];
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
let tokenized_input = tokenized_input.
iter().
map(|input| input.token_ids.clone()).
map(|mut input| {
let input = [
"Looks like one thing is missing",
"It\'s like comparing oranges to apples",
];
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let tokenized_input = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
}).
map(|input|
Tensor::of_slice(&(input))).
collect::<Vec<_>>();
})
.map(|input| Tensor::of_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
@ -211,7 +246,7 @@ fn distilbert_question_answering() -> failure::Fallible<()> {
let context = String::from("Amy lives in Amsterdam");
let qa_input = QaInput { question, context };
let answers = qa_model.predict(&vec!(qa_input), 1, 32);
let answers = qa_model.predict(&vec![qa_input], 1, 32);
assert_eq!(answers.len(), 1 as usize);
assert_eq!(answers[0].len(), 1 as usize);
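// predict takes a slice of QaInput together with what are presumably a top-k
// answer count (1) and a batch size (32), and returns one list of ranked
// answers per input question.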

View File

@ -1,17 +1,28 @@
use tch::{Device, nn, Tensor};
use rust_tokenizers::{Gpt2Tokenizer, TruncationStrategy, Tokenizer};
use rust_bert::gpt2::{
GPT2LMHeadModel, Gpt2Config, Gpt2ConfigResources, Gpt2MergesResources, Gpt2ModelResources,
Gpt2VocabResources,
};
use rust_bert::pipelines::generation::{Cache, LMHeadModel};
use rust_bert::resources::{download_resource, RemoteResource, Resource};
use rust_bert::Config;
use rust_bert::gpt2::{Gpt2Config, GPT2LMHeadModel, Gpt2ConfigResources, Gpt2VocabResources, Gpt2MergesResources, Gpt2ModelResources};
use rust_bert::pipelines::generation::{LMHeadModel, Cache};
use rust_bert::resources::{Resource, download_resource, RemoteResource};
use rust_tokenizers::{Gpt2Tokenizer, Tokenizer, TruncationStrategy};
use tch::{nn, Device, Tensor};
#[test]
fn distilgpt2_lm_model() -> failure::Fallible<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::DISTIL_GPT2));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::DISTIL_GPT2));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::DISTIL_GPT2));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::DISTIL_GPT2));
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
Gpt2ConfigResources::DISTIL_GPT2,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
Gpt2VocabResources::DISTIL_GPT2,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
Gpt2MergesResources::DISTIL_GPT2,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
Gpt2ModelResources::DISTIL_GPT2,
));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let merges_path = download_resource(&merges_resource)?;
@ -20,29 +31,38 @@ fn distilgpt2_lm_model() -> failure::Fallible<()> {
// Set-up language model
let device = Device::Cpu;
let mut vs = nn::VarStore::new(device);
let tokenizer: Gpt2Tokenizer = Gpt2Tokenizer::from_file(vocab_path.to_str().unwrap(), merges_path.to_str().unwrap(), false);
let tokenizer: Gpt2Tokenizer = Gpt2Tokenizer::from_file(
vocab_path.to_str().unwrap(),
merges_path.to_str().unwrap(),
false,
);
let config = Gpt2Config::from_file(config_path);
let gpt2_model = GPT2LMHeadModel::new(&vs.root(), &config);
vs.load(weights_path)?;
// Define input
let input = ["One two three four five six seven eight nine ten eleven"];
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
let tokenized_input = tokenized_input.
iter().
map(|input| input.token_ids.clone()).
map(|mut input| {
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let tokenized_input = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
}).
map(|input|
Tensor::of_slice(&(input))).
collect::<Vec<_>>();
})
.map(|input| Tensor::of_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let (output, _, past, _, _) = gpt2_model.forward_t(
let (output, _, past, _, _) = gpt2_model
.forward_t(
&Some(input_tensor),
Cache::None,
&None,
@ -51,21 +71,28 @@ fn distilgpt2_lm_model() -> failure::Fallible<()> {
&None,
None,
&None,
false).unwrap();
false,
)
.unwrap();
let next_word_id = output.get(0).get(-1).argmax(-1, true).int64_value(&[0]);
let next_word = tokenizer.decode(vec!(next_word_id), true, true);
let next_word = tokenizer.decode(vec![next_word_id], true, true);
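// get(0) selects the first batch element, get(-1) the logits at the last
// position, and argmax picks the most likely next token id. The leading space
// in the decoded " twelve" below comes from GPT-2's byte-level BPE, which
// folds the preceding space into the token itself.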
assert_eq!(output.size(), vec!(1, 11, 50257));
match past {
Cache::GPT2Cache(past) => {
assert!(past.is_some());
assert_eq!(past.as_ref().unwrap().len(), config.n_layer as usize);
assert_eq!(past.as_ref().unwrap()[0].size(), vec!(2, 1, config.n_head, 11, 64));
assert_eq!(
past.as_ref().unwrap()[0].size(),
vec!(2, 1, config.n_head, 11, 64)
);
}
_ => panic!("Wrong cache returned for GPT2")
_ => panic!("Wrong cache returned for GPT2"),
}
assert!((output.double_value(&[0, output.size()[1] - 1, next_word_id]) - (-48.7065)).abs() < 1e-4);
assert!(
(output.double_value(&[0, output.size()[1] - 1, next_word_id]) - (-48.7065)).abs() < 1e-4
);
assert_eq!(next_word_id, 14104i64);
assert_eq!(next_word, String::from(" twelve"));

View File

@ -1,15 +1,24 @@
use rust_bert::resources::{Resource, download_resource, RemoteResource};
use rust_bert::electra::{ElectraConfigResources, ElectraVocabResources, ElectraModelResources, ElectraConfig, ElectraForMaskedLM, ElectraDiscriminator};
use tch::{Device, nn, Tensor, no_grad};
use rust_tokenizers::{BertTokenizer, TruncationStrategy, Tokenizer, Vocab};
use rust_bert::electra::{
ElectraConfig, ElectraConfigResources, ElectraDiscriminator, ElectraForMaskedLM,
ElectraModelResources, ElectraVocabResources,
};
use rust_bert::resources::{download_resource, RemoteResource, Resource};
use rust_bert::Config;
use rust_tokenizers::{BertTokenizer, Tokenizer, TruncationStrategy, Vocab};
use tch::{nn, no_grad, Device, Tensor};
#[test]
fn electra_masked_lm() -> failure::Fallible<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(ElectraConfigResources::BASE_GENERATOR));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(ElectraVocabResources::BASE_GENERATOR));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(ElectraModelResources::BASE_GENERATOR));
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
ElectraConfigResources::BASE_GENERATOR,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
ElectraVocabResources::BASE_GENERATOR,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
ElectraModelResources::BASE_GENERATOR,
));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let weights_path = download_resource(&weights_resource)?;
@ -25,33 +34,31 @@ fn electra_masked_lm() -> failure::Fallible<()> {
vs.load(weights_path)?;
// Define input
let input = ["Looks like one [MASK] is missing", "It was a very nice and [MASK] day"];
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
let tokenized_input = tokenized_input.
iter().
map(|input| input.token_ids.clone()).
map(|mut input| {
let input = [
"Looks like one [MASK] is missing",
"It was a very nice and [MASK] day",
];
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let tokenized_input = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
}).
map(|input|
Tensor::of_slice(&(input))).
collect::<Vec<_>>();
})
.map(|input| Tensor::of_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let (output,
all_hidden_states,
all_attentions) = no_grad(|| {
electra_model
.forward_t(Some(input_tensor),
None,
None,
None,
None,
false)
});
let (output, all_hidden_states, all_attentions) =
no_grad(|| electra_model.forward_t(Some(input_tensor), None, None, None, None, false));
// Decode output
let index_1 = output.get(0).get(4).argmax(0, false);
@ -60,8 +67,14 @@ fn electra_masked_lm() -> failure::Fallible<()> {
let word_2 = tokenizer.vocab().id_to_token(&index_2.int64_value(&[]));
assert_eq!(output.size(), &[2, 10, config.vocab_size]);
assert_eq!(config.num_hidden_layers as usize, all_hidden_states.unwrap().len());
assert_eq!(config.num_hidden_layers as usize, all_attentions.unwrap().len());
assert_eq!(
config.num_hidden_layers as usize,
all_hidden_states.unwrap().len()
);
assert_eq!(
config.num_hidden_layers as usize,
all_attentions.unwrap().len()
);
assert_eq!("thing", word_1); // Outputs "person" : "Looks like one [person] is missing"
assert_eq!("sunny", word_2); // Outputs "pear" : "It was a very nice and [sunny] day"
Ok(())
@ -70,9 +83,15 @@ fn electra_masked_lm() -> failure::Fallible<()> {
#[test]
fn electra_discriminator() -> failure::Fallible<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(ElectraConfigResources::BASE_DISCRIMINATOR));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(ElectraVocabResources::BASE_DISCRIMINATOR));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(ElectraModelResources::BASE_DISCRIMINATOR));
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
ElectraConfigResources::BASE_DISCRIMINATOR,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
ElectraVocabResources::BASE_DISCRIMINATOR,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
ElectraModelResources::BASE_DISCRIMINATOR,
));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let weights_path = download_resource(&weights_resource)?;
@ -87,33 +106,32 @@ fn electra_discriminator() -> failure::Fallible<()> {
// Define input
let input = ["One Two Three Ten Five Six Seven Eight"];
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
let encoded_input = tokenized_input.
iter().
map(|input| input.token_ids.clone()).
map(|mut input| {
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let encoded_input = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
}).
map(|input|
Tensor::of_slice(&(input))).
collect::<Vec<_>>();
})
.map(|input| Tensor::of_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(encoded_input.as_slice(), 0).to(device);
// Forward pass
let (output, _, _) = no_grad(|| {
electra_model
.forward_t(Some(input_tensor),
None,
None,
None,
None,
false)
});
let (output, _, _) =
no_grad(|| electra_model.forward_t(Some(input_tensor), None, None, None, None, false));
// Validate model predictions
let expected_probabilities = vec!(0.0101, 0.0030, 0.0010, 0.0018, 0.9489, 0.0067, 0.0026, 0.0017, 0.0311, 0.0101);
let expected_probabilities = vec![
0.0101, 0.0030, 0.0010, 0.0018, 0.9489, 0.0067, 0.0026, 0.0017, 0.0311, 0.0101,
];
let probabilities = output.iter::<f64>().unwrap().collect::<Vec<f64>>();
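// The discriminator emits one replaced-token probability per position,
// [CLS] and [SEP] included (hence 10 values for 8 words); the 0.9489 at
// index 4 flags "Ten" as the token that does not belong in the sequence.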
assert_eq!(output.size(), &[10]);

View File

@ -1,17 +1,26 @@
use tch::{Device, nn, Tensor};
use rust_tokenizers::{Gpt2Tokenizer, TruncationStrategy, Tokenizer};
use rust_bert::gpt2::{
GPT2LMHeadModel, Gpt2Config, Gpt2ConfigResources, Gpt2MergesResources, Gpt2ModelResources,
Gpt2VocabResources,
};
use rust_bert::pipelines::generation::{
Cache, GPT2Generator, GenerateConfig, LMHeadModel, LanguageGenerator,
};
use rust_bert::resources::{download_resource, RemoteResource, Resource};
use rust_bert::Config;
use rust_bert::pipelines::generation::{GPT2Generator, LanguageGenerator, GenerateConfig, LMHeadModel, Cache};
use rust_bert::gpt2::{Gpt2Config, GPT2LMHeadModel, Gpt2ConfigResources, Gpt2MergesResources, Gpt2VocabResources, Gpt2ModelResources};
use rust_bert::resources::{RemoteResource, Resource, download_resource};
use rust_tokenizers::{Gpt2Tokenizer, Tokenizer, TruncationStrategy};
use tch::{nn, Device, Tensor};
#[test]
fn gpt2_lm_model() -> failure::Fallible<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
let config_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
let vocab_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
let merges_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
let weights_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let merges_path = download_resource(&merges_resource)?;
@ -20,29 +29,38 @@ fn gpt2_lm_model() -> failure::Fallible<()> {
// Set-up language model
let device = Device::Cpu;
let mut vs = nn::VarStore::new(device);
let tokenizer: Gpt2Tokenizer = Gpt2Tokenizer::from_file(vocab_path.to_str().unwrap(), merges_path.to_str().unwrap(), false);
let tokenizer: Gpt2Tokenizer = Gpt2Tokenizer::from_file(
vocab_path.to_str().unwrap(),
merges_path.to_str().unwrap(),
false,
);
let config = Gpt2Config::from_file(config_path);
let gpt2_model = GPT2LMHeadModel::new(&vs.root(), &config);
vs.load(weights_path)?;
// Define input
let input = ["One two three four"];
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
let tokenized_input = tokenized_input.
iter().
map(|input| input.token_ids.clone()).
map(|mut input| {
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let tokenized_input = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
}).
map(|input|
Tensor::of_slice(&(input))).
collect::<Vec<_>>();
})
.map(|input| Tensor::of_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let (output, _, past, _, _) = gpt2_model.forward_t(
let (output, _, past, _, _) = gpt2_model
.forward_t(
&Some(input_tensor),
Cache::None,
&None,
@ -51,35 +69,45 @@ fn gpt2_lm_model() -> failure::Fallible<()> {
&None,
None,
&None,
false).unwrap();
false,
)
.unwrap();
let next_word_id = output.get(0).get(-1).argmax(-1, true).int64_value(&[0]);
let next_word = tokenizer.decode(vec!(next_word_id), true, true);
let next_word = tokenizer.decode(vec![next_word_id], true, true);
assert_eq!(output.size(), vec!(1, 4, 50257));
match past {
Cache::GPT2Cache(past) => {
assert!(past.is_some());
assert_eq!(past.as_ref().unwrap().len(), config.n_layer as usize);
assert_eq!(past.as_ref().unwrap()[0].size(), vec!(2, 1, config.n_head, 4, 64));
assert_eq!(
past.as_ref().unwrap()[0].size(),
vec!(2, 1, config.n_head, 4, 64)
);
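// Each layer's cached past stacks the key and value tensors:
// [2 (key/value), batch, n_head, sequence_length, head_dim], with
// head_dim 64 = n_embd / n_head for this configuration.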
}
_ => panic!("Wrong cache returned for GPT2")
_ => panic!("Wrong cache returned for GPT2"),
}
assert!((output.double_value(&[0, output.size()[1] - 1, next_word_id]) - (-69.4948)).abs() < 1e-4);
assert!(
(output.double_value(&[0, output.size()[1] - 1, next_word_id]) - (-69.4948)).abs() < 1e-4
);
assert_eq!(next_word_id, 1936i64);
assert_eq!(next_word, String::from(" five"));
Ok(())
}
#[test]
fn gpt2_generation_greedy() -> failure::Fallible<()> {
// Resources definition
let config_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
let model_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
let config_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
let vocab_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
let merges_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
let model_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
// Set-up generation model
let generate_config = GenerateConfig {
@ -97,7 +125,7 @@ fn gpt2_generation_greedy() -> failure::Fallible<()> {
let model = GPT2Generator::new(generate_config)?;
let input_context = "The cat";
let output = model.generate(Some(vec!(input_context)), None);
let output = model.generate(Some(vec![input_context]), None);
assert_eq!(output.len(), 1);
assert_eq!(output[0], "The cat was found in a field near the town of Keflavik, about 30 miles (48 kilometers) south-east of Moscow.\n\n\n");
@ -108,10 +136,14 @@ fn gpt2_generation_greedy() -> failure::Fallible<()> {
#[test]
fn gpt2_generation_beam_search() -> failure::Fallible<()> {
// Resources definition
let config_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
let model_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
let config_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
let vocab_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
let merges_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
let model_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
// Set-up generation model
let generate_config = GenerateConfig {
@ -129,12 +161,21 @@ fn gpt2_generation_beam_search() -> failure::Fallible<()> {
let model = GPT2Generator::new(generate_config)?;
let input_context = "The dog";
let output = model.generate(Some(vec!(input_context)), None);
let output = model.generate(Some(vec![input_context]), None);
assert_eq!(output.len(), 3);
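// Beam search returns several ranked hypotheses per prompt; the single prompt
// here yields three, and the two-prompt tests below accordingly yield six.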
assert_eq!(output[0], "The dog was found in the backyard of a home in the 6200 block of South Main Street.");
assert_eq!(output[1], "The dog was found in the backyard of a home in the 6500 block of South Main Street.");
assert_eq!(output[2], "The dog was found in the backyard of a home in the 6200 block of South Main Street,");
assert_eq!(
output[0],
"The dog was found in the backyard of a home in the 6200 block of South Main Street."
);
assert_eq!(
output[1],
"The dog was found in the backyard of a home in the 6500 block of South Main Street."
);
assert_eq!(
output[2],
"The dog was found in the backyard of a home in the 6200 block of South Main Street,"
);
Ok(())
}
@ -142,10 +183,14 @@ fn gpt2_generation_beam_search() -> failure::Fallible<()> {
#[test]
fn gpt2_generation_beam_search_multiple_prompts_without_padding() -> failure::Fallible<()> {
// Resources definition
let config_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
let model_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
let config_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
let vocab_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
let merges_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
let model_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
// Set-up generation model
let generate_config = GenerateConfig {
@ -164,15 +209,33 @@ fn gpt2_generation_beam_search_multiple_prompts_without_padding() -> failure::Fa
let input_context_1 = "The dog";
let input_context_2 = "The cat";
let output = model.generate(Some(vec!(input_context_1, input_context_2)), None);
let output = model.generate(Some(vec![input_context_1, input_context_2]), None);
assert_eq!(output.len(), 6);
assert_eq!(output[0], "The dog was found in the backyard of a home in the 6200 block of South Main Street.");
assert_eq!(output[1], "The dog was found in the backyard of a home in the 6500 block of South Main Street.");
assert_eq!(output[2], "The dog was found in the backyard of a home in the 6200 block of South Main Street,");
assert_eq!(output[3], "The cat-and-mouse game.\n\n\"I think it\'s going to be interesting to");
assert_eq!(output[4], "The cat-and-mouse game.\n\n\"I think it\'s going to be a very");
assert_eq!(output[5], "The cat-and-mouse game.\n\n\"I think it\'s going to be very interesting");
assert_eq!(
output[0],
"The dog was found in the backyard of a home in the 6200 block of South Main Street."
);
assert_eq!(
output[1],
"The dog was found in the backyard of a home in the 6500 block of South Main Street."
);
assert_eq!(
output[2],
"The dog was found in the backyard of a home in the 6200 block of South Main Street,"
);
assert_eq!(
output[3],
"The cat-and-mouse game.\n\n\"I think it\'s going to be interesting to"
);
assert_eq!(
output[4],
"The cat-and-mouse game.\n\n\"I think it\'s going to be a very"
);
assert_eq!(
output[5],
"The cat-and-mouse game.\n\n\"I think it\'s going to be very interesting"
);
Ok(())
}
@ -180,10 +243,14 @@ fn gpt2_generation_beam_search_multiple_prompts_without_padding() -> failure::Fa
#[test]
fn gpt2_generation_beam_search_multiple_prompts_with_padding() -> failure::Fallible<()> {
// Resources definition
let config_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
let model_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
let config_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
let vocab_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
let merges_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
let model_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
// Set-up text generation model
let generate_config = GenerateConfig {
@ -202,15 +269,33 @@ fn gpt2_generation_beam_search_multiple_prompts_with_padding() -> failure::Falli
let input_context_1 = "The dog";
let input_context_2 = "The cat was";
let output = model.generate(Some(vec!(input_context_1, input_context_2)), None);
let output = model.generate(Some(vec![input_context_1, input_context_2]), None);
assert_eq!(output.len(), 6);
assert_eq!(output[0], "The dog was found dead on the side of the road in the middle of the night.\n");
assert_eq!(output[1], "The dog was found dead on the side of the road in the middle of the night on Sunday");
assert_eq!(output[2], "The dog was found dead on the side of the road in the middle of the night on Saturday");
assert_eq!(output[3], "The cat was taken to a local hospital, where it was treated and released.\n\nPolice said");
assert_eq!(output[4], "The cat was taken to a local hospital, where it was treated and released.\n\n\"It");
assert_eq!(output[5], "The cat was taken to a local hospital, where it was treated and released.\n\n\"We");
assert_eq!(
output[0],
"The dog was found dead on the side of the road in the middle of the night.\n"
);
assert_eq!(
output[1],
"The dog was found dead on the side of the road in the middle of the night on Sunday"
);
assert_eq!(
output[2],
"The dog was found dead on the side of the road in the middle of the night on Saturday"
);
assert_eq!(
output[3],
"The cat was taken to a local hospital, where it was treated and released.\n\nPolice said"
);
assert_eq!(
output[4],
"The cat was taken to a local hospital, where it was treated and released.\n\n\"It"
);
assert_eq!(
output[5],
"The cat was taken to a local hospital, where it was treated and released.\n\n\"We"
);
Ok(())
}
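The six outputs above are grouped by prompt: indices 0-2 belong to the first prompt and 3-5 to the second. A minimal sketch of recovering the per-prompt groups, assuming num_return_sequences is 3 as the assertions imply:
let output = model.generate(Some(vec![input_context_1, input_context_2]), None);
for (prompt_idx, group) in output.chunks(3).enumerate() {
    // group holds the three beam results for this prompt
    println!("prompt {}: {:#?}", prompt_idx, group);
}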

View File

@ -1,10 +1,9 @@
use rust_bert::pipelines::translation::{TranslationConfig, Language, TranslationModel};
use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
use tch::Device;
#[test]
#[cfg_attr(not(feature = "all-tests"), ignore)]
fn test_translation() -> failure::Fallible<()> {
// Set-up translation model
let translation_config = TranslationConfig::new(Language::EnglishToFrench, Device::Cpu);
let model = TranslationModel::new(translation_config)?;
@ -15,7 +14,10 @@ fn test_translation() -> failure::Fallible<()> {
let output = model.translate(&[input_context_1, input_context_2]);
assert_eq!(output.len(), 2);
assert_eq!(output[0], " Le rapide renard brun saute sur le chien paresseux");
assert_eq!(
output[0],
" Le rapide renard brun saute sur le chien paresseux"
);
assert_eq!(output[1], " Le chien ne s'est pas réveillé.");
Ok(())

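The same pattern covers other directions; a minimal sketch, assuming a Language::EnglishToGerman variant exists alongside EnglishToFrench in this version:
let translation_config = TranslationConfig::new(Language::EnglishToGerman, Device::Cpu);
let model = TranslationModel::new(translation_config)?;
// translate() accepts a batch of inputs, as in the test above
let output = model.translate(&["The dog did not wake up."]);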
View File

@ -1,18 +1,31 @@
use tch::{Device, nn, Tensor};
use rust_tokenizers::{TruncationStrategy, Tokenizer, OpenAiGptTokenizer};
use rust_bert::Config;
use rust_bert::pipelines::generation::{OpenAIGenerator, LanguageGenerator, GenerateConfig, LMHeadModel, Cache};
use rust_bert::gpt2::Gpt2Config;
use rust_bert::openai_gpt::{OpenAIGPTLMHeadModel, OpenAiGptConfigResources, OpenAiGptVocabResources, OpenAiGptMergesResources, OpenAiGptModelResources};
use rust_bert::resources::{RemoteResource, Resource, download_resource};
use rust_bert::openai_gpt::{
OpenAIGPTLMHeadModel, OpenAiGptConfigResources, OpenAiGptMergesResources,
OpenAiGptModelResources, OpenAiGptVocabResources,
};
use rust_bert::pipelines::generation::{
Cache, GenerateConfig, LMHeadModel, LanguageGenerator, OpenAIGenerator,
};
use rust_bert::resources::{download_resource, RemoteResource, Resource};
use rust_bert::Config;
use rust_tokenizers::{OpenAiGptTokenizer, Tokenizer, TruncationStrategy};
use tch::{nn, Device, Tensor};
#[test]
fn openai_gpt_lm_model() -> failure::Fallible<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptConfigResources::GPT));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptVocabResources::GPT));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptMergesResources::GPT));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptModelResources::GPT));
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
OpenAiGptConfigResources::GPT,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
OpenAiGptVocabResources::GPT,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
OpenAiGptMergesResources::GPT,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
OpenAiGptModelResources::GPT,
));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let merges_path = download_resource(&merges_resource)?;
@ -21,29 +34,38 @@ fn openai_gpt_lm_model() -> failure::Fallible<()> {
// Set-up language model
let device = Device::Cpu;
let mut vs = nn::VarStore::new(device);
let tokenizer = OpenAiGptTokenizer::from_file(vocab_path.to_str().unwrap(), merges_path.to_str().unwrap(), true);
let tokenizer = OpenAiGptTokenizer::from_file(
vocab_path.to_str().unwrap(),
merges_path.to_str().unwrap(),
true,
);
let config = Gpt2Config::from_file(config_path);
let openai_gpt = OpenAIGPTLMHeadModel::new(&vs.root(), &config);
vs.load(weights_path)?;
// Define input
let input = ["Wondering what the next word will"];
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
let tokenized_input = tokenized_input.
iter().
map(|input| input.token_ids.clone()).
map(|mut input| {
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let tokenized_input = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
}).
map(|input|
Tensor::of_slice(&(input))).
collect::<Vec<_>>();
})
.map(|input| Tensor::of_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let (output, _, _, _, _) = openai_gpt.forward_t(
let (output, _, _, _, _) = openai_gpt
.forward_t(
&Some(input_tensor),
Cache::None,
&None,
@ -52,13 +74,17 @@ fn openai_gpt_lm_model() -> failure::Fallible<()> {
&None,
None,
&None,
false).unwrap();
false,
)
.unwrap();
let next_word_id = output.get(0).get(-1).argmax(-1, true).int64_value(&[0]);
let next_word = tokenizer.decode(vec!(next_word_id), true, true);
let next_word = tokenizer.decode(vec![next_word_id], true, true);
assert_eq!(output.size(), vec!(1, 6, 40478));
assert!((output.double_value(&[0, output.size()[1] - 1, next_word_id]) - (9.1056)).abs() < 1e-4);
assert!(
(output.double_value(&[0, output.size()[1] - 1, next_word_id]) - (9.1056)).abs() < 1e-4
);
assert_eq!(next_word_id, 580i64);
assert_eq!(next_word, String::from("be"));
@ -68,10 +94,18 @@ fn openai_gpt_lm_model() -> failure::Fallible<()> {
#[test]
fn openai_gpt_generation_greedy() -> failure::Fallible<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptConfigResources::GPT));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptVocabResources::GPT));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptMergesResources::GPT));
let model_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptModelResources::GPT));
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
OpenAiGptConfigResources::GPT,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
OpenAiGptVocabResources::GPT,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
OpenAiGptMergesResources::GPT,
));
let model_resource = Resource::Remote(RemoteResource::from_pretrained(
OpenAiGptModelResources::GPT,
));
// Set-up text generation model
let generate_config = GenerateConfig {
@ -90,7 +124,7 @@ fn openai_gpt_generation_greedy() -> failure::Fallible<()> {
let model = OpenAIGenerator::new(generate_config)?;
let input_context = "It was an intense machine dialogue. ";
let output = model.generate(Some(vec!(input_context)), None);
let output = model.generate(Some(vec![input_context]), None);
assert_eq!(output.len(), 1);
assert_eq!(output[0], "it was an intense machine dialogue. \n \" i\'m sorry, but we have to go now! the police are on their way and they\'re going after you - or at least that\'s what my");
@ -101,10 +135,18 @@ fn openai_gpt_generation_greedy() -> failure::Fallible<()> {
#[test]
fn openai_gpt_generation_beam_search() -> failure::Fallible<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptConfigResources::GPT));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptVocabResources::GPT));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptMergesResources::GPT));
let model_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptModelResources::GPT));
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
OpenAiGptConfigResources::GPT,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
OpenAiGptVocabResources::GPT,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
OpenAiGptMergesResources::GPT,
));
let model_resource = Resource::Remote(RemoteResource::from_pretrained(
OpenAiGptModelResources::GPT,
));
// Set-up text generation model
let generate_config = GenerateConfig {
@ -122,12 +164,21 @@ fn openai_gpt_generation_beam_search() -> failure::Fallible<()> {
let model = OpenAIGenerator::new(generate_config)?;
let input_context = "The dog is";
let output = model.generate(Some(vec!(input_context)), None);
let output = model.generate(Some(vec![input_context]), None);
assert_eq!(output.len(), 3);
assert_eq!(output[0], "the dog isn\'t going anywhere. i\'m going to take care of him. i \'ll be right");
assert_eq!(output[1], "the dog isn\'t going anywhere. i\'m going to take care of him. i \'ll be back");
assert_eq!(output[2], "the dog isn\'t going anywhere. i\'m going to take care of him. \" \n \" i");
assert_eq!(
output[0],
"the dog isn\'t going anywhere. i\'m going to take care of him. i \'ll be right"
);
assert_eq!(
output[1],
"the dog isn\'t going anywhere. i\'m going to take care of him. i \'ll be back"
);
assert_eq!(
output[2],
"the dog isn\'t going anywhere. i\'m going to take care of him. \" \n \" i"
);
Ok(())
}
@ -135,10 +186,18 @@ fn openai_gpt_generation_beam_search() -> failure::Fallible<()> {
#[test]
fn openai_gpt_generation_beam_search_multiple_prompts_without_padding() -> failure::Fallible<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptConfigResources::GPT));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptVocabResources::GPT));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptMergesResources::GPT));
let model_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptModelResources::GPT));
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
OpenAiGptConfigResources::GPT,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
OpenAiGptVocabResources::GPT,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
OpenAiGptMergesResources::GPT,
));
let model_resource = Resource::Remote(RemoteResource::from_pretrained(
OpenAiGptModelResources::GPT,
));
// Set-up text generation model
let generate_config = GenerateConfig {
@ -157,18 +216,36 @@ fn openai_gpt_generation_beam_search_multiple_prompts_without_padding() -> failu
let input_context_1 = "The dog is";
let input_context_2 = "The cat";
let output = model.generate(Some(vec!(input_context_1, input_context_2)), None);
let output = model.generate(Some(vec![input_context_1, input_context_2]), None);
assert_eq!(output.len(), 6);
// Unpadded sequence (generation for `The dog is`) is identical to the single-prompt generation in the previous test
assert_eq!(output[0], "the dog isn\'t going anywhere. i\'m going to take care of him. i \'ll be right");
assert_eq!(output[1], "the dog isn\'t going anywhere. i\'m going to take care of him. i \'ll be back");
assert_eq!(output[2], "the dog isn\'t going anywhere. i\'m going to take care of him. \" \n \" i");
assert_eq!(
output[0],
"the dog isn\'t going anywhere. i\'m going to take care of him. i \'ll be right"
);
assert_eq!(
output[1],
"the dog isn\'t going anywhere. i\'m going to take care of him. i \'ll be back"
);
assert_eq!(
output[2],
"the dog isn\'t going anywhere. i\'m going to take care of him. \" \n \" i"
);
assert_eq!(output[3], "the cat. \" \n \" i don\'t know what you\'re talking about. i don\'t");
assert_eq!(output[4], "the cat. \" \n \" i don\'t know what you\'re talking about. i\'m not");
assert_eq!(output[5], "the cat. \" \n \" i don\'t know what you\'re talking about. i do know");
assert_eq!(
output[3],
"the cat. \" \n \" i don\'t know what you\'re talking about. i don\'t"
);
assert_eq!(
output[4],
"the cat. \" \n \" i don\'t know what you\'re talking about. i\'m not"
);
assert_eq!(
output[5],
"the cat. \" \n \" i don\'t know what you\'re talking about. i do know"
);
Ok(())
}
@ -176,10 +253,18 @@ fn openai_gpt_generation_beam_search_multiple_prompts_without_padding() -> failu
#[test]
fn openai_gpt_generation_beam_search_multiple_prompts_with_padding() -> failure::Fallible<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptConfigResources::GPT));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptVocabResources::GPT));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptMergesResources::GPT));
let model_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptModelResources::GPT));
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
OpenAiGptConfigResources::GPT,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
OpenAiGptVocabResources::GPT,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
OpenAiGptMergesResources::GPT,
));
let model_resource = Resource::Remote(RemoteResource::from_pretrained(
OpenAiGptModelResources::GPT,
));
// Set-up text generation model
let generate_config = GenerateConfig {
@ -198,16 +283,34 @@ fn openai_gpt_generation_beam_search_multiple_prompts_with_padding() -> failure:
let input_context_1 = "The dog is";
let input_context_2 = "The cat was in";
let output = model.generate(Some(vec!(input_context_1, input_context_2)), None);
let output = model.generate(Some(vec![input_context_1, input_context_2]), None);
assert_eq!(output.len(), 6);
// Left padding impacts the generated output (see the padding sketch after this test)
assert_eq!(output[0], "the dog is a dog. \" \n \" i don\'t know what you\'re talking about.");
assert_eq!(output[1], "the dog is a dog. \" \n \" i don\'t know what you\'re talking about,");
assert_eq!(output[2], "the dog is a dog. \" \n \" i don\'t know what you\'re talking about!");
assert_eq!(output[3], "the cat was in the room with them. \n \" what\'s going on? \" i asked.");
assert_eq!(output[4], "the cat was in the room with them. \n \" what\'s going on? \" she asked.");
assert_eq!(output[5], "the cat was in the room with them. \n \" what\'s going on? why are you all");
assert_eq!(
output[0],
"the dog is a dog. \" \n \" i don\'t know what you\'re talking about."
);
assert_eq!(
output[1],
"the dog is a dog. \" \n \" i don\'t know what you\'re talking about,"
);
assert_eq!(
output[2],
"the dog is a dog. \" \n \" i don\'t know what you\'re talking about!"
);
assert_eq!(
output[3],
"the cat was in the room with them. \n \" what\'s going on? \" i asked."
);
assert_eq!(
output[4],
"the cat was in the room with them. \n \" what\'s going on? \" she asked."
);
assert_eq!(
output[5],
"the cat was in the room with them. \n \" what\'s going on? why are you all"
);
Ok(())
}
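A minimal sketch of the left padding referenced in the comment inside the test above (illustrative token ids; pad id 0 is an assumption, and the generation pipeline performs this step internally):
use tch::Tensor;

let (short, long) = (vec![464i64, 3290, 318], vec![464i64, 3797, 373, 287]);
let max_len = short.len().max(long.len());
let pad_left = |ids: Vec<i64>| {
    let mut padded = vec![0i64; max_len - ids.len()];
    padded.extend(ids); // prompt tokens end up right-aligned
    padded
};
// shape: [2, max_len], ready for a batched forward pass
let batch = Tensor::stack(
    &[
        Tensor::of_slice(&pad_left(short)),
        Tensor::of_slice(&pad_left(long)),
    ],
    0,
);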

View File

@ -1,18 +1,30 @@
use tch::{Device, nn, Tensor, no_grad};
use rust_tokenizers::{RobertaTokenizer, TruncationStrategy, Tokenizer, Vocab};
use rust_bert::Config;
use rust_bert::bert::BertConfig;
use rust_bert::roberta::{RobertaForMaskedLM, RobertaForSequenceClassification, RobertaForMultipleChoice, RobertaForTokenClassification, RobertaForQuestionAnswering, RobertaConfigResources, RobertaVocabResources, RobertaMergesResources, RobertaModelResources};
use rust_bert::resources::{download_resource, RemoteResource, Resource};
use rust_bert::roberta::{
RobertaConfigResources, RobertaForMaskedLM, RobertaForMultipleChoice,
RobertaForQuestionAnswering, RobertaForSequenceClassification, RobertaForTokenClassification,
RobertaMergesResources, RobertaModelResources, RobertaVocabResources,
};
use rust_bert::Config;
use rust_tokenizers::{RobertaTokenizer, Tokenizer, TruncationStrategy, Vocab};
use std::collections::HashMap;
use rust_bert::resources::{RemoteResource, Resource, download_resource};
use tch::{nn, no_grad, Device, Tensor};
#[test]
fn roberta_masked_lm() -> failure::Fallible<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaConfigResources::ROBERTA));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaVocabResources::ROBERTA));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaMergesResources::ROBERTA));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaModelResources::ROBERTA));
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaConfigResources::ROBERTA,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaVocabResources::ROBERTA,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaMergesResources::ROBERTA,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaModelResources::ROBERTA,
));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let merges_path = download_resource(&merges_resource)?;
@ -21,45 +33,57 @@ fn roberta_masked_lm() -> failure::Fallible<()> {
// Set-up masked LM model
let device = Device::Cpu;
let mut vs = nn::VarStore::new(device);
let tokenizer: RobertaTokenizer = RobertaTokenizer::from_file(vocab_path.to_str().unwrap(), merges_path.to_str().unwrap(), true);
let tokenizer: RobertaTokenizer = RobertaTokenizer::from_file(
vocab_path.to_str().unwrap(),
merges_path.to_str().unwrap(),
true,
);
let config = BertConfig::from_file(config_path);
let roberta_model = RobertaForMaskedLM::new(&vs.root(), &config);
vs.load(weights_path)?;
// Define input
let input = ["<pad> Looks like one thing is missing", "It\'s like comparing oranges to apples"];
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
let mut tokenized_input = tokenized_input.
iter().
map(|input| input.token_ids.clone()).
map(|mut input| {
let input = [
"<pad> Looks like one thing is missing",
"It\'s like comparing oranges to apples",
];
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let mut tokenized_input = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
}).
collect::<Vec<_>>();
})
.collect::<Vec<_>>();
// Masking the token [thing] of sentence 1 and [oranges] of sentence 2
tokenized_input[0][4] = 103;
tokenized_input[1][5] = 103;
let tokenized_input = tokenized_input.
iter().
map(|input|
Tensor::of_slice(&(input))).
collect::<Vec<_>>();
let tokenized_input = tokenized_input
.iter()
.map(|input| Tensor::of_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let (output, _, _) = no_grad(|| {
roberta_model
.forward_t(Some(input_tensor),
roberta_model.forward_t(
Some(input_tensor),
None,
None,
None,
None,
&None,
&None,
false)
false,
)
});
// Print masked tokens
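The decoding block below this comment is elided by the hunk; a minimal sketch of what it presumably does, using the masked positions 4 and 5 set above (the vocab().id_to_token call follows the pattern of this crate's examples):
let index_1 = output.get(0).get(4).argmax(0, false);
let index_2 = output.get(1).get(5).argmax(0, false);
let word_1 = tokenizer.vocab().id_to_token(&index_1.int64_value(&[]));
let word_2 = tokenizer.vocab().id_to_token(&index_2.int64_value(&[]));
println!("{} - {}", word_1, word_2);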
@ -77,9 +101,15 @@ fn roberta_masked_lm() -> failure::Fallible<()> {
#[test]
fn roberta_for_sequence_classification() -> failure::Fallible<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaConfigResources::ROBERTA));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaVocabResources::ROBERTA));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaMergesResources::ROBERTA));
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaConfigResources::ROBERTA,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaVocabResources::ROBERTA,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaMergesResources::ROBERTA,
));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let merges_path = download_resource(&merges_resource)?;
@ -87,7 +117,11 @@ fn roberta_for_sequence_classification() -> failure::Fallible<()> {
// Set-up model
let device = Device::Cpu;
let vs = nn::VarStore::new(device);
let tokenizer: RobertaTokenizer = RobertaTokenizer::from_file(vocab_path.to_str().unwrap(), merges_path.to_str().unwrap(), true);
let tokenizer: RobertaTokenizer = RobertaTokenizer::from_file(
vocab_path.to_str().unwrap(),
merges_path.to_str().unwrap(),
true,
);
let mut config = BertConfig::from_file(config_path);
let mut dummy_label_mapping = HashMap::new();
dummy_label_mapping.insert(0, String::from("Positive"));
@ -98,37 +132,42 @@ fn roberta_for_sequence_classification() -> failure::Fallible<()> {
config.output_hidden_states = Some(true);
let roberta_model = RobertaForSequenceClassification::new(&vs.root(), &config);
// Define input
let input = ["Looks like one thing is missing", "It\'s like comparing oranges to apples"];
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
let tokenized_input = tokenized_input.
iter().
map(|input| input.token_ids.clone()).
map(|mut input| {
let input = [
"Looks like one thing is missing",
"It\'s like comparing oranges to apples",
];
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let tokenized_input = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
}).
map(|input|
Tensor::of_slice(&(input))).
collect::<Vec<_>>();
})
.map(|input| Tensor::of_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let (output, all_hidden_states, all_attentions) = no_grad(|| {
roberta_model
.forward_t(Some(input_tensor),
None,
None,
None,
None,
false)
});
let (output, all_hidden_states, all_attentions) =
no_grad(|| roberta_model.forward_t(Some(input_tensor), None, None, None, None, false));
assert_eq!(output.size(), &[2, 3]);
assert_eq!(config.num_hidden_layers as usize, all_hidden_states.unwrap().len());
assert_eq!(config.num_hidden_layers as usize, all_attentions.unwrap().len());
assert_eq!(
config.num_hidden_layers as usize,
all_hidden_states.unwrap().len()
);
assert_eq!(
config.num_hidden_layers as usize,
all_attentions.unwrap().len()
);
Ok(())
}
@ -136,9 +175,15 @@ fn roberta_for_sequence_classification() -> failure::Fallible<()> {
#[test]
fn roberta_for_multiple_choice() -> failure::Fallible<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaConfigResources::ROBERTA));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaVocabResources::ROBERTA));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaMergesResources::ROBERTA));
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaConfigResources::ROBERTA,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaVocabResources::ROBERTA,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaMergesResources::ROBERTA,
));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let merges_path = download_resource(&merges_resource)?;
@ -146,42 +191,54 @@ fn roberta_for_multiple_choice() -> failure::Fallible<()> {
// Set-up model
let device = Device::Cpu;
let vs = nn::VarStore::new(device);
let tokenizer: RobertaTokenizer = RobertaTokenizer::from_file(vocab_path.to_str().unwrap(), merges_path.to_str().unwrap(), true);
let tokenizer: RobertaTokenizer = RobertaTokenizer::from_file(
vocab_path.to_str().unwrap(),
merges_path.to_str().unwrap(),
true,
);
let mut config = BertConfig::from_file(config_path);
config.output_attentions = Some(true);
config.output_hidden_states = Some(true);
let roberta_model = RobertaForMultipleChoice::new(&vs.root(), &config);
// Define input
let input = ["Looks like one thing is missing", "It\'s like comparing oranges to apples"];
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
let tokenized_input = tokenized_input.
iter().
map(|input| input.token_ids.clone()).
map(|mut input| {
let input = [
"Looks like one thing is missing",
"It\'s like comparing oranges to apples",
];
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let tokenized_input = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
}).
map(|input|
Tensor::of_slice(&(input))).
collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device).unsqueeze(0);
})
.map(|input| Tensor::of_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0)
.to(device)
.unsqueeze(0);
// Forward pass
let (output, all_hidden_states, all_attentions) = no_grad(|| {
roberta_model
.forward_t(input_tensor,
None,
None,
None,
false)
});
let (output, all_hidden_states, all_attentions) =
no_grad(|| roberta_model.forward_t(input_tensor, None, None, None, false));
assert_eq!(output.size(), &[1, 2]);
assert_eq!(config.num_hidden_layers as usize, all_hidden_states.unwrap().len());
assert_eq!(config.num_hidden_layers as usize, all_attentions.unwrap().len());
assert_eq!(
config.num_hidden_layers as usize,
all_hidden_states.unwrap().len()
);
assert_eq!(
config.num_hidden_layers as usize,
all_attentions.unwrap().len()
);
Ok(())
}
@ -189,9 +246,15 @@ fn roberta_for_multiple_choice() -> failure::Fallible<()> {
#[test]
fn roberta_for_token_classification() -> failure::Fallible<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaConfigResources::ROBERTA));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaVocabResources::ROBERTA));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaMergesResources::ROBERTA));
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaConfigResources::ROBERTA,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaVocabResources::ROBERTA,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaMergesResources::ROBERTA,
));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let merges_path = download_resource(&merges_resource)?;
@ -199,7 +262,11 @@ fn roberta_for_token_classification() -> failure::Fallible<()> {
// Set-up model
let device = Device::Cpu;
let vs = nn::VarStore::new(device);
let tokenizer: RobertaTokenizer = RobertaTokenizer::from_file(vocab_path.to_str().unwrap(), merges_path.to_str().unwrap(), true);
let tokenizer: RobertaTokenizer = RobertaTokenizer::from_file(
vocab_path.to_str().unwrap(),
merges_path.to_str().unwrap(),
true,
);
let mut config = BertConfig::from_file(config_path);
let mut dummy_label_mapping = HashMap::new();
dummy_label_mapping.insert(0, String::from("O"));
@ -212,46 +279,57 @@ fn roberta_for_token_classification() -> failure::Fallible<()> {
let roberta_model = RobertaForTokenClassification::new(&vs.root(), &config);
// Define input
let input = ["Looks like one thing is missing", "It\'s like comparing oranges to apples"];
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
let tokenized_input = tokenized_input.
iter().
map(|input| input.token_ids.clone()).
map(|mut input| {
let input = [
"Looks like one thing is missing",
"It\'s like comparing oranges to apples",
];
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let tokenized_input = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
}).
map(|input|
Tensor::of_slice(&(input))).
collect::<Vec<_>>();
})
.map(|input| Tensor::of_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let (output, all_hidden_states, all_attentions) = no_grad(|| {
roberta_model
.forward_t(Some(input_tensor),
None,
None,
None,
None,
false)
});
let (output, all_hidden_states, all_attentions) =
no_grad(|| roberta_model.forward_t(Some(input_tensor), None, None, None, None, false));
assert_eq!(output.size(), &[2, 9, 4]);
assert_eq!(config.num_hidden_layers as usize, all_hidden_states.unwrap().len());
assert_eq!(config.num_hidden_layers as usize, all_attentions.unwrap().len());
assert_eq!(
config.num_hidden_layers as usize,
all_hidden_states.unwrap().len()
);
assert_eq!(
config.num_hidden_layers as usize,
all_attentions.unwrap().len()
);
Ok(())
}
#[test]
fn roberta_for_question_answering() -> failure::Fallible<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaConfigResources::ROBERTA));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaVocabResources::ROBERTA));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaMergesResources::ROBERTA));
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaConfigResources::ROBERTA,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaVocabResources::ROBERTA,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaMergesResources::ROBERTA,
));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let merges_path = download_resource(&merges_resource)?;
@ -259,7 +337,11 @@ fn roberta_for_question_answering() -> failure::Fallible<()> {
// Set-up model
let device = Device::Cpu;
let vs = nn::VarStore::new(device);
let tokenizer: RobertaTokenizer = RobertaTokenizer::from_file(vocab_path.to_str().unwrap(), merges_path.to_str().unwrap(), true);
let tokenizer: RobertaTokenizer = RobertaTokenizer::from_file(
vocab_path.to_str().unwrap(),
merges_path.to_str().unwrap(),
true,
);
let mut config = BertConfig::from_file(config_path);
let mut dummy_label_mapping = HashMap::new();
dummy_label_mapping.insert(0, String::from("Positive"));
@ -270,38 +352,43 @@ fn roberta_for_question_answering() -> failure::Fallible<()> {
config.output_hidden_states = Some(true);
let roberta_model = RobertaForQuestionAnswering::new(&vs.root(), &config);
// Define input
let input = ["Looks like one thing is missing", "It\'s like comparing oranges to apples"];
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
let tokenized_input = tokenized_input.
iter().
map(|input| input.token_ids.clone()).
map(|mut input| {
let input = [
"Looks like one thing is missing",
"It\'s like comparing oranges to apples",
];
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let tokenized_input = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
}).
map(|input|
Tensor::of_slice(&(input))).
collect::<Vec<_>>();
})
.map(|input| Tensor::of_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let (start_scores, end_scores, all_hidden_states, all_attentions) = no_grad(|| {
roberta_model
.forward_t(Some(input_tensor),
None,
None,
None,
None,
false)
});
let (start_scores, end_scores, all_hidden_states, all_attentions) =
no_grad(|| roberta_model.forward_t(Some(input_tensor), None, None, None, None, false));
assert_eq!(start_scores.size(), &[2, 9]);
assert_eq!(end_scores.size(), &[2, 9]);
assert_eq!(config.num_hidden_layers as usize, all_hidden_states.unwrap().len());
assert_eq!(config.num_hidden_layers as usize, all_attentions.unwrap().len());
assert_eq!(
config.num_hidden_layers as usize,
all_hidden_states.unwrap().len()
);
assert_eq!(
config.num_hidden_layers as usize,
all_attentions.unwrap().len()
);
Ok(())
}
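A minimal sketch of turning the start and end scores above into an answer span for the first batch element (greedy argmax per side; a real QA pipeline would also enforce span-length limits):
let start = start_scores.get(0).argmax(-1, false).int64_value(&[]);
let end = end_scores.get(0).argmax(-1, false).int64_value(&[]);
if start <= end {
    // token ids between start and end form the candidate answer span
    let answer_ids: Vec<i64> = (start..=end).collect();
    // decode answer_ids with the tokenizer to recover the answer text
}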