Code formatted using rustfmt

This commit is contained in:
Guillaume B 2020-06-23 16:54:46 +02:00
parent 0624a5368c
commit 47e36c4e8c
86 changed files with 9498 additions and 5324 deletions

View File

@ -13,58 +13,67 @@
extern crate failure;
use tch::{Device, nn, Tensor, no_grad};
use rust_tokenizers::{AlbertTokenizer, TruncationStrategy, Tokenizer, Vocab};
use rust_bert::albert::{
AlbertConfig, AlbertConfigResources, AlbertForMaskedLM, AlbertModelResources,
AlbertVocabResources,
};
use rust_bert::resources::{download_resource, RemoteResource, Resource};
use rust_bert::Config;
use rust_bert::resources::{Resource, download_resource, RemoteResource};
use rust_bert::albert::{AlbertConfig, AlbertForMaskedLM, AlbertConfigResources, AlbertVocabResources, AlbertModelResources};
use rust_tokenizers::{AlbertTokenizer, Tokenizer, TruncationStrategy, Vocab};
use tch::{nn, no_grad, Device, Tensor};
fn main() -> failure::Fallible<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(AlbertConfigResources::ALBERT_BASE_V2));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(AlbertVocabResources::ALBERT_BASE_V2));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(AlbertModelResources::ALBERT_BASE_V2));
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
AlbertConfigResources::ALBERT_BASE_V2,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
AlbertVocabResources::ALBERT_BASE_V2,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
AlbertModelResources::ALBERT_BASE_V2,
));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let weights_path = download_resource(&weights_resource)?;
// Set-up masked LM model
// Set-up masked LM model
let device = Device::Cpu;
let mut vs = nn::VarStore::new(device);
let tokenizer: AlbertTokenizer = AlbertTokenizer::from_file(vocab_path.to_str().unwrap(), true, false);
let tokenizer: AlbertTokenizer =
AlbertTokenizer::from_file(vocab_path.to_str().unwrap(), true, false);
let config = AlbertConfig::from_file(config_path);
let albert_model = AlbertForMaskedLM::new(&vs.root(), &config);
vs.load(weights_path)?;
// Define input
let input = ["Looks like one [MASK] is missing", "It was a very nice and [MASK] day"];
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
let tokenized_input = tokenized_input.
iter().
map(|input| input.token_ids.clone()).
map(|mut input| {
// Define input
let input = [
"Looks like one [MASK] is missing",
"It was a very nice and [MASK] day",
];
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let tokenized_input = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
}).
map(|input|
Tensor::of_slice(&(input))).
collect::<Vec<_>>();
})
.map(|input| Tensor::of_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let (output, _, _) = no_grad(|| {
albert_model
.forward_t(Some(input_tensor),
None,
None,
None,
None,
false)
});
// Forward pass
let (output, _, _) =
no_grad(|| albert_model.forward_t(Some(input_tensor), None, None, None, None, false));
println!("{:?}", output.double_value(&[0, 0, 0]));
// Print masked tokens
// Print masked tokens
let index_1 = output.get(0).get(4).argmax(0, false);
let index_2 = output.get(1).get(7).argmax(0, false);
let word_1 = tokenizer.vocab().id_to_token(&index_1.int64_value(&[]));
@ -74,4 +83,4 @@ fn main() -> failure::Fallible<()> {
println!("{} - {}", &index_2.int64_value(&[]), word_2); // Outputs "_enjoyable" : "It was a very nice and [enjoyable] day"
Ok(())
}
}

View File

@ -12,33 +12,43 @@
extern crate failure;
use tch::{Device, nn, Tensor, no_grad};
use rust_tokenizers::{RobertaTokenizer, TruncationStrategy, Tokenizer};
use rust_bert::bart::{BartConfig, BartConfigResources, BartVocabResources, BartMergesResources, BartModelResources, BartModel};
use rust_bert::bart::{
BartConfig, BartConfigResources, BartMergesResources, BartModel, BartModelResources,
BartVocabResources,
};
use rust_bert::resources::{download_resource, RemoteResource, Resource};
use rust_bert::Config;
use rust_bert::resources::{Resource, download_resource, RemoteResource};
use rust_tokenizers::{RobertaTokenizer, Tokenizer, TruncationStrategy};
use tch::{nn, no_grad, Device, Tensor};
fn main() -> failure::Fallible<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(BartConfigResources::BART));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(BartVocabResources::BART));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(BartMergesResources::BART));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(BartModelResources::BART));
let config_resource =
Resource::Remote(RemoteResource::from_pretrained(BartConfigResources::BART));
let vocab_resource =
Resource::Remote(RemoteResource::from_pretrained(BartVocabResources::BART));
let merges_resource =
Resource::Remote(RemoteResource::from_pretrained(BartMergesResources::BART));
let weights_resource =
Resource::Remote(RemoteResource::from_pretrained(BartModelResources::BART));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let merges_path = download_resource(&merges_resource)?;
let weights_path = download_resource(&weights_resource)?;
// Set-up masked LM model
// Set-up masked LM model
let device = Device::cuda_if_available();
let mut vs = nn::VarStore::new(device);
let tokenizer = RobertaTokenizer::from_file(vocab_path.to_str().unwrap(), merges_path.to_str().unwrap(), false);
let tokenizer = RobertaTokenizer::from_file(
vocab_path.to_str().unwrap(),
merges_path.to_str().unwrap(),
false,
);
let config = BartConfig::from_file(config_path);
let bart_model = BartModel::new(&vs.root(), &config, false);
vs.load(weights_path)?;
// Define input
// Define input
let input = ["In findings published Tuesday in Cornell University's arXiv by a team of scientists \
from the University of Montreal and a separate report published Wednesday in Nature Astronomy by a team \
from University College London (UCL), the presence of water vapour was confirmed in the atmosphere of K2-18b, \
@ -61,36 +71,32 @@ on K2-18b lasts 33 Earth days. According to The Guardian, astronomers were optim
telescope scheduled for launch in 2021 and the European Space Agency's 2028 ARIEL program, could reveal more \
about exoplanets like K2-18b."];
// Credits: WikiNews, CC BY 2.5 license (https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b)
// Credits: WikiNews, CC BY 2.5 license (https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b)
let tokenized_input = tokenizer.encode_list(input.to_vec(), 1024, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
let tokenized_input = tokenized_input.
iter().
map(|input| input.token_ids.clone()).
map(|mut input| {
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 1024, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let tokenized_input = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
}).
map(|input|
Tensor::of_slice(&(input))).
collect::<Vec<_>>();
})
.map(|input| Tensor::of_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let (decoder_output, encoder_output, _, _, _, _, _) = no_grad(|| {
bart_model
.forward_t(Some(&input_tensor),
None,
None,
None,
None,
None,
false)
});
// Forward pass
let (decoder_output, encoder_output, _, _, _, _, _) =
no_grad(|| bart_model.forward_t(Some(&input_tensor), None, None, None, None, None, false));
// Print masked tokens
// Print masked tokens
println!("{:?}", encoder_output);
println!("{:?}", decoder_output);
Ok(())
}
}

View File

@ -12,23 +12,27 @@
extern crate failure;
use tch::{Device, nn, Tensor, no_grad};
use rust_tokenizers::{BertTokenizer, TruncationStrategy, Tokenizer, Vocab};
use rust_bert::bert::{
BertConfig, BertConfigResources, BertForMaskedLM, BertModelResources, BertVocabResources,
};
use rust_bert::resources::{download_resource, RemoteResource, Resource};
use rust_bert::Config;
use rust_bert::bert::{BertConfig, BertForMaskedLM, BertConfigResources, BertVocabResources, BertModelResources};
use rust_bert::resources::{Resource, download_resource, RemoteResource};
use rust_tokenizers::{BertTokenizer, Tokenizer, TruncationStrategy, Vocab};
use tch::{nn, no_grad, Device, Tensor};
fn main() -> failure::Fallible<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT));
let config_resource =
Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT));
let vocab_resource =
Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
let weights_resource =
Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let weights_path = download_resource(&weights_resource)?;
// Set-up masked LM model
// Set-up masked LM model
let device = Device::Cpu;
let mut vs = nn::VarStore::new(device);
let tokenizer: BertTokenizer = BertTokenizer::from_file(vocab_path.to_str().unwrap(), true);
@ -36,43 +40,51 @@ fn main() -> failure::Fallible<()> {
let bert_model = BertForMaskedLM::new(&vs.root(), &config);
vs.load(weights_path)?;
// Define input
let input = ["Looks like one [MASK] is missing", "It was a very nice and [MASK] day"];
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
let tokenized_input = tokenized_input.
iter().
map(|input| input.token_ids.clone()).
map(|mut input| {
// Define input
let input = [
"Looks like one [MASK] is missing",
"It was a very nice and [MASK] day",
];
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let tokenized_input = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
}).
map(|input|
Tensor::of_slice(&(input))).
collect::<Vec<_>>();
})
.map(|input| Tensor::of_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
// Forward pass
let (output, _, _) = no_grad(|| {
bert_model
.forward_t(Some(input_tensor),
None,
None,
None,
None,
&None,
&None,
false)
bert_model.forward_t(
Some(input_tensor),
None,
None,
None,
None,
&None,
&None,
false,
)
});
// Print masked tokens
// Print masked tokens
let index_1 = output.get(0).get(4).argmax(0, false);
let index_2 = output.get(1).get(7).argmax(0, false);
let word_1 = tokenizer.vocab().id_to_token(&index_1.int64_value(&[]));
let word_2 = tokenizer.vocab().id_to_token(&index_2.int64_value(&[]));
println!("{}", word_1); // Outputs "person" : "Looks like one [person] is missing"
println!("{}", word_2);// Outputs "pear" : "It was a very nice and [pleasant] day"
println!("{}", word_2); // Outputs "pear" : "It was a very nice and [pleasant] day"
Ok(())
}
}

View File

@ -11,25 +11,33 @@
// limitations under the License.
extern crate failure;
use tch::{Device, Tensor, nn, no_grad};
use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::{Tokenizer, TruncationStrategy};
use rust_tokenizers::bert_tokenizer::BertTokenizer;
use rust_tokenizers::preprocessing::vocab::base_vocab::Vocab;
use rust_bert::distilbert::{
DistilBertConfig, DistilBertConfigResources, DistilBertModelMaskedLM, DistilBertModelResources,
DistilBertVocabResources,
};
use rust_bert::resources::{download_resource, RemoteResource, Resource};
use rust_bert::Config;
use rust_bert::distilbert::{DistilBertConfig, DistilBertModelMaskedLM, DistilBertConfigResources, DistilBertVocabResources, DistilBertModelResources};
use rust_bert::resources::{Resource, download_resource, RemoteResource};
use rust_tokenizers::bert_tokenizer::BertTokenizer;
use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::{Tokenizer, TruncationStrategy};
use rust_tokenizers::preprocessing::vocab::base_vocab::Vocab;
use tch::{nn, no_grad, Device, Tensor};
fn main() -> failure::Fallible<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertConfigResources::DISTIL_BERT));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertVocabResources::DISTIL_BERT));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertModelResources::DISTIL_BERT));
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
DistilBertConfigResources::DISTIL_BERT,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
DistilBertVocabResources::DISTIL_BERT,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
DistilBertModelResources::DISTIL_BERT,
));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let weights_path = download_resource(&weights_resource)?;
// Set-up masked LM model
// Set-up masked LM model
let device = Device::Cpu;
let mut vs = nn::VarStore::new(device);
let tokenizer: BertTokenizer = BertTokenizer::from_file(vocab_path.to_str().unwrap(), true);
@ -37,45 +45,51 @@ fn main() -> failure::Fallible<()> {
let distil_bert_model = DistilBertModelMaskedLM::new(&vs.root(), &config);
vs.load(weights_path)?;
// Define input
let input = ["Looks like one thing is missing", "It\'s like comparing oranges to apples"];
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
let mut tokenized_input = tokenized_input.
iter().
map(|input| input.token_ids.clone()).
map(|mut input| {
// Define input
let input = [
"Looks like one thing is missing",
"It\'s like comparing oranges to apples",
];
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let mut tokenized_input = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
}).
collect::<Vec<_>>();
})
.collect::<Vec<_>>();
// Masking the token [thing] of sentence 1 and [oranges] of sentence 2
// Masking the token [thing] of sentence 1 and [oranges] of sentence 2
tokenized_input[0][4] = 103;
tokenized_input[1][6] = 103;
let tokenized_input = tokenized_input.
iter().
map(|input|
Tensor::of_slice(&(input))).
collect::<Vec<_>>();
let tokenized_input = tokenized_input
.iter()
.map(|input| Tensor::of_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
// Forward pass
let (output, _, _) = no_grad(|| {
distil_bert_model
.forward_t(Some(input_tensor), None, None, false)
.unwrap()
});
// Print masked tokens
// Print masked tokens
let index_1 = output.get(0).get(4).argmax(0, false);
let index_2 = output.get(1).get(6).argmax(0, false);
let word_1 = tokenizer.vocab().id_to_token(&index_1.int64_value(&[]));
let word_2 = tokenizer.vocab().id_to_token(&index_2.int64_value(&[]));
println!("{}", word_1); // Outputs "person" : "Looks like one [person] is missing"
println!("{}", word_2);// Outputs "pear" : "It\'s like comparing [pear] to apples"
println!("{}", word_2); // Outputs "pear" : "It\'s like comparing [pear] to apples"
Ok(())
}

View File

@ -1,26 +1,44 @@
extern crate failure;
use rust_bert::gpt2::{Gpt2ConfigResources, Gpt2VocabResources, Gpt2MergesResources, Gpt2ModelResources};
use rust_bert::distilbert::{DistilBertModelResources, DistilBertConfigResources, DistilBertVocabResources};
use rust_bert::openai_gpt::{OpenAiGptConfigResources, OpenAiGptVocabResources, OpenAiGptMergesResources, OpenAiGptModelResources};
use rust_bert::roberta::{RobertaConfigResources, RobertaVocabResources, RobertaMergesResources, RobertaModelResources};
use rust_bert::bert::{BertConfigResources, BertVocabResources, BertModelResources};
use rust_bert::bart::{BartConfigResources, BartVocabResources, BartMergesResources, BartModelResources};
use rust_bert::resources::{Resource, download_resource, RemoteResource};
use rust_bert::electra::{ElectraConfigResources, ElectraVocabResources, ElectraModelResources};
use rust_bert::albert::{AlbertConfigResources, AlbertVocabResources, AlbertModelResources};
use rust_bert::albert::{AlbertConfigResources, AlbertModelResources, AlbertVocabResources};
use rust_bert::bart::{
BartConfigResources, BartMergesResources, BartModelResources, BartVocabResources,
};
use rust_bert::bert::{BertConfigResources, BertModelResources, BertVocabResources};
use rust_bert::distilbert::{
DistilBertConfigResources, DistilBertModelResources, DistilBertVocabResources,
};
use rust_bert::electra::{ElectraConfigResources, ElectraModelResources, ElectraVocabResources};
use rust_bert::gpt2::{
Gpt2ConfigResources, Gpt2MergesResources, Gpt2ModelResources, Gpt2VocabResources,
};
use rust_bert::openai_gpt::{
OpenAiGptConfigResources, OpenAiGptMergesResources, OpenAiGptModelResources,
OpenAiGptVocabResources,
};
use rust_bert::resources::{download_resource, RemoteResource, Resource};
use rust_bert::roberta::{
RobertaConfigResources, RobertaMergesResources, RobertaModelResources, RobertaVocabResources,
};
/// This example downloads and caches all dependencies used in model tests. This allows for safe
/// multi threaded testing (two test using the same resource would otherwise download the file to
/// the same location).
fn download_distil_gpt2() -> failure::Fallible<()> {
// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
let config_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::DISTIL_GPT2));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::DISTIL_GPT2));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::DISTIL_GPT2));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::DISTIL_GPT2));
// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
Gpt2ConfigResources::DISTIL_GPT2,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
Gpt2VocabResources::DISTIL_GPT2,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
Gpt2MergesResources::DISTIL_GPT2,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
Gpt2ModelResources::DISTIL_GPT2,
));
let _ = download_resource(&config_resource)?;
let _ = download_resource(&vocab_resource)?;
let _ = download_resource(&merges_resource)?;
@ -29,10 +47,16 @@ fn download_distil_gpt2() -> failure::Fallible<()> {
}
fn download_distilbert_sst2() -> failure::Fallible<()> {
// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertModelResources::DISTIL_BERT_SST2));
let config_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertConfigResources::DISTIL_BERT_SST2));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertVocabResources::DISTIL_BERT_SST2));
// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
DistilBertModelResources::DISTIL_BERT_SST2,
));
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
DistilBertConfigResources::DISTIL_BERT_SST2,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
DistilBertVocabResources::DISTIL_BERT_SST2,
));
let _ = download_resource(&config_resource)?;
let _ = download_resource(&vocab_resource)?;
let _ = download_resource(&weights_resource)?;
@ -40,10 +64,16 @@ fn download_distilbert_sst2() -> failure::Fallible<()> {
}
fn download_distilbert_qa() -> failure::Fallible<()> {
// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertModelResources::DISTIL_BERT_SQUAD));
let config_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertConfigResources::DISTIL_BERT_SQUAD));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertVocabResources::DISTIL_BERT_SQUAD));
// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
DistilBertModelResources::DISTIL_BERT_SQUAD,
));
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
DistilBertConfigResources::DISTIL_BERT_SQUAD,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
DistilBertVocabResources::DISTIL_BERT_SQUAD,
));
let _ = download_resource(&config_resource)?;
let _ = download_resource(&vocab_resource)?;
let _ = download_resource(&weights_resource)?;
@ -51,10 +81,16 @@ fn download_distilbert_qa() -> failure::Fallible<()> {
}
fn download_distilbert() -> failure::Fallible<()> {
// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertModelResources::DISTIL_BERT));
let config_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertConfigResources::DISTIL_BERT));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertVocabResources::DISTIL_BERT));
// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
DistilBertModelResources::DISTIL_BERT,
));
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
DistilBertConfigResources::DISTIL_BERT,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
DistilBertVocabResources::DISTIL_BERT,
));
let _ = download_resource(&config_resource)?;
let _ = download_resource(&vocab_resource)?;
let _ = download_resource(&weights_resource)?;
@ -62,11 +98,15 @@ fn download_distilbert() -> failure::Fallible<()> {
}
fn download_gpt2() -> failure::Fallible<()> {
// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2. Modified with conversion to C-array format.
let config_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2. Modified with conversion to C-array format.
let config_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
let vocab_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
let merges_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
let weights_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
let _ = download_resource(&config_resource)?;
let _ = download_resource(&vocab_resource)?;
let _ = download_resource(&merges_resource)?;
@ -75,11 +115,19 @@ fn download_gpt2() -> failure::Fallible<()> {
}
fn download_gpt() -> failure::Fallible<()> {
// Shared under MIT license by the OpenAI team at https://github.com/openai/finetune-transformer-lm. Modified with conversion to C-array format.
let config_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptConfigResources::GPT));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptVocabResources::GPT));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptMergesResources::GPT));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptModelResources::GPT));
// Shared under MIT license by the OpenAI team at https://github.com/openai/finetune-transformer-lm. Modified with conversion to C-array format.
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
OpenAiGptConfigResources::GPT,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
OpenAiGptVocabResources::GPT,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
OpenAiGptMergesResources::GPT,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
OpenAiGptModelResources::GPT,
));
let _ = download_resource(&config_resource)?;
let _ = download_resource(&vocab_resource)?;
let _ = download_resource(&merges_resource)?;
@ -88,11 +136,19 @@ fn download_gpt() -> failure::Fallible<()> {
}
fn download_roberta() -> failure::Fallible<()> {
// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
let config_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaConfigResources::ROBERTA));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaVocabResources::ROBERTA));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaMergesResources::ROBERTA));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaModelResources::ROBERTA));
// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaConfigResources::ROBERTA,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaVocabResources::ROBERTA,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaMergesResources::ROBERTA,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaModelResources::ROBERTA,
));
let _ = download_resource(&config_resource)?;
let _ = download_resource(&vocab_resource)?;
let _ = download_resource(&merges_resource)?;
@ -101,10 +157,13 @@ fn download_roberta() -> failure::Fallible<()> {
}
fn download_bert() -> failure::Fallible<()> {
// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/bert. Modified with conversion to C-array format.
let config_resource = Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT));
// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/bert. Modified with conversion to C-array format.
let config_resource =
Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT));
let vocab_resource =
Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
let weights_resource =
Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT));
let _ = download_resource(&config_resource)?;
let _ = download_resource(&vocab_resource)?;
let _ = download_resource(&weights_resource)?;
@ -112,10 +171,16 @@ fn download_bert() -> failure::Fallible<()> {
}
fn download_bert_ner() -> failure::Fallible<()> {
// Shared under MIT license by the MDZ Digital Library team at the Bavarian State Library at https://github.com/dbmdz/berts. Modified with conversion to C-array format.
let config_resource = Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT_NER));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT_NER));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT_NER));
// Shared under MIT license by the MDZ Digital Library team at the Bavarian State Library at https://github.com/dbmdz/berts. Modified with conversion to C-array format.
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
BertConfigResources::BERT_NER,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
BertVocabResources::BERT_NER,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
BertModelResources::BERT_NER,
));
let _ = download_resource(&config_resource)?;
let _ = download_resource(&vocab_resource)?;
let _ = download_resource(&weights_resource)?;
@ -123,11 +188,15 @@ fn download_bert_ner() -> failure::Fallible<()> {
}
fn download_bart() -> failure::Fallible<()> {
// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
let config_resource = Resource::Remote(RemoteResource::from_pretrained(BartConfigResources::BART));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(BartVocabResources::BART));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(BartMergesResources::BART));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(BartModelResources::BART));
// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
let config_resource =
Resource::Remote(RemoteResource::from_pretrained(BartConfigResources::BART));
let vocab_resource =
Resource::Remote(RemoteResource::from_pretrained(BartVocabResources::BART));
let merges_resource =
Resource::Remote(RemoteResource::from_pretrained(BartMergesResources::BART));
let weights_resource =
Resource::Remote(RemoteResource::from_pretrained(BartModelResources::BART));
let _ = download_resource(&config_resource)?;
let _ = download_resource(&vocab_resource)?;
let _ = download_resource(&merges_resource)?;
@ -136,11 +205,19 @@ fn download_bart() -> failure::Fallible<()> {
}
fn download_bart_cnn() -> failure::Fallible<()> {
// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
let config_resource = Resource::Remote(RemoteResource::from_pretrained(BartConfigResources::BART_CNN));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(BartVocabResources::BART_CNN));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(BartMergesResources::BART_CNN));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(BartModelResources::BART_CNN));
// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
BartConfigResources::BART_CNN,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
BartVocabResources::BART_CNN,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
BartMergesResources::BART_CNN,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
BartModelResources::BART_CNN,
));
let _ = download_resource(&config_resource)?;
let _ = download_resource(&vocab_resource)?;
let _ = download_resource(&merges_resource)?;
@ -149,10 +226,16 @@ fn download_bart_cnn() -> failure::Fallible<()> {
}
fn download_electra_generator() -> failure::Fallible<()> {
// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/electra. Modified with conversion to C-array format.
let config_resource = Resource::Remote(RemoteResource::from_pretrained(ElectraConfigResources::BASE_GENERATOR));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(ElectraVocabResources::BASE_GENERATOR));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(ElectraModelResources::BASE_GENERATOR));
// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/electra. Modified with conversion to C-array format.
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
ElectraConfigResources::BASE_GENERATOR,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
ElectraVocabResources::BASE_GENERATOR,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
ElectraModelResources::BASE_GENERATOR,
));
let _ = download_resource(&config_resource)?;
let _ = download_resource(&vocab_resource)?;
let _ = download_resource(&weights_resource)?;
@ -160,10 +243,16 @@ fn download_electra_generator() -> failure::Fallible<()> {
}
fn download_electra_discriminator() -> failure::Fallible<()> {
// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/electra. Modified with conversion to C-array format.
let config_resource = Resource::Remote(RemoteResource::from_pretrained(ElectraConfigResources::BASE_DISCRIMINATOR));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(ElectraVocabResources::BASE_DISCRIMINATOR));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(ElectraModelResources::BASE_DISCRIMINATOR));
// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/electra. Modified with conversion to C-array format.
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
ElectraConfigResources::BASE_DISCRIMINATOR,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
ElectraVocabResources::BASE_DISCRIMINATOR,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
ElectraModelResources::BASE_DISCRIMINATOR,
));
let _ = download_resource(&config_resource)?;
let _ = download_resource(&vocab_resource)?;
let _ = download_resource(&weights_resource)?;
@ -171,10 +260,16 @@ fn download_electra_discriminator() -> failure::Fallible<()> {
}
fn download_albert_base_v2() -> failure::Fallible<()> {
// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/ALBERT. Modified with conversion to C-array format.
let config_resource = Resource::Remote(RemoteResource::from_pretrained(AlbertConfigResources::ALBERT_BASE_V2));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(AlbertVocabResources::ALBERT_BASE_V2));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(AlbertModelResources::ALBERT_BASE_V2));
// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/ALBERT. Modified with conversion to C-array format.
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
AlbertConfigResources::ALBERT_BASE_V2,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
AlbertVocabResources::ALBERT_BASE_V2,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
AlbertModelResources::ALBERT_BASE_V2,
));
let _ = download_resource(&config_resource)?;
let _ = download_resource(&vocab_resource)?;
let _ = download_resource(&weights_resource)?;
@ -198,4 +293,4 @@ fn main() -> failure::Fallible<()> {
let _ = download_albert_base_v2();
Ok(())
}
}

View File

@ -12,23 +12,31 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use rust_bert::resources::{Resource, download_resource, RemoteResource};
use rust_bert::electra::{ElectraConfig, ElectraDiscriminator, ElectraConfigResources, ElectraVocabResources, ElectraModelResources};
use rust_bert::electra::{
ElectraConfig, ElectraConfigResources, ElectraDiscriminator, ElectraModelResources,
ElectraVocabResources,
};
use rust_bert::resources::{download_resource, RemoteResource, Resource};
use rust_bert::Config;
use rust_tokenizers::{BertTokenizer, Tokenizer, TruncationStrategy};
use tch::{Tensor, Device, nn, no_grad};
use tch::{nn, no_grad, Device, Tensor};
fn main() -> failure::Fallible<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(ElectraConfigResources::BASE_DISCRIMINATOR));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(ElectraVocabResources::BASE_DISCRIMINATOR));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(ElectraModelResources::BASE_DISCRIMINATOR));
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
ElectraConfigResources::BASE_DISCRIMINATOR,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
ElectraVocabResources::BASE_DISCRIMINATOR,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
ElectraModelResources::BASE_DISCRIMINATOR,
));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let weights_path = download_resource(&weights_resource)?;
// Set-up masked LM model
// Set-up masked LM model
let device = Device::Cpu;
let mut vs = nn::VarStore::new(device);
let tokenizer: BertTokenizer = BertTokenizer::from_file(vocab_path.to_str().unwrap(), true);
@ -36,45 +44,45 @@ fn main() -> failure::Fallible<()> {
let electra_model = ElectraDiscriminator::new(&vs.root(), &config);
vs.load(weights_path)?;
// Define input
// Define input
let input = ["One Two Three Ten Five Six Seven Eight"];
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
let encoded_input = tokenized_input.
iter().
map(|input| input.token_ids.clone()).
map(|mut input| {
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let encoded_input = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
}).
map(|input|
Tensor::of_slice(&(input))).
collect::<Vec<_>>();
})
.map(|input| Tensor::of_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(encoded_input.as_slice(), 0).to(device);
// Forward pass
let (output, _, _) = no_grad(|| {
electra_model
.forward_t(Some(input_tensor),
None,
None,
None,
None,
false)
});
let (output, _, _) =
no_grad(|| electra_model.forward_t(Some(input_tensor), None, None, None, None, false));
// Print model predictions
// Print model predictions
for (position, token) in tokenized_input[0].token_ids.iter().enumerate() {
let probability = output.double_value(&[position as i64]);
let generated = if probability > 0.5 { "generated" } else { "original" };
println!("{:?}: {} ({:.1}%)",
tokenizer.decode([*token].to_vec(),
false,
false),
generated,
100f64 * probability)
let generated = if probability > 0.5 {
"generated"
} else {
"original"
};
println!(
"{:?}: {} ({:.1}%)",
tokenizer.decode([*token].to_vec(), false, false),
generated,
100f64 * probability
)
}
Ok(())
}

View File

@ -12,23 +12,31 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use rust_bert::resources::{Resource, download_resource, RemoteResource};
use rust_bert::electra::{ElectraConfig, ElectraForMaskedLM, ElectraModelResources, ElectraConfigResources, ElectraVocabResources};
use rust_bert::electra::{
ElectraConfig, ElectraConfigResources, ElectraForMaskedLM, ElectraModelResources,
ElectraVocabResources,
};
use rust_bert::resources::{download_resource, RemoteResource, Resource};
use rust_bert::Config;
use rust_tokenizers::{BertTokenizer, Tokenizer, TruncationStrategy, Vocab};
use tch::{Tensor, Device, nn, no_grad};
use tch::{nn, no_grad, Device, Tensor};
fn main() -> failure::Fallible<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(ElectraConfigResources::BASE_GENERATOR));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(ElectraVocabResources::BASE_GENERATOR));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(ElectraModelResources::BASE_GENERATOR));
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
ElectraConfigResources::BASE_GENERATOR,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
ElectraVocabResources::BASE_GENERATOR,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
ElectraModelResources::BASE_GENERATOR,
));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let weights_path = download_resource(&weights_resource)?;
// Set-up masked LM model
// Set-up masked LM model
let device = Device::Cpu;
let mut vs = nn::VarStore::new(device);
let tokenizer: BertTokenizer = BertTokenizer::from_file(vocab_path.to_str().unwrap(), true);
@ -36,41 +44,41 @@ fn main() -> failure::Fallible<()> {
let electra_model = ElectraForMaskedLM::new(&vs.root(), &config);
vs.load(weights_path)?;
// Define input
let input = ["Looks like one [MASK] is missing", "It was a very nice and [MASK] day"];
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
let tokenized_input = tokenized_input.
iter().
map(|input| input.token_ids.clone()).
map(|mut input| {
// Define input
let input = [
"Looks like one [MASK] is missing",
"It was a very nice and [MASK] day",
];
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let tokenized_input = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
}).
map(|input|
Tensor::of_slice(&(input))).
collect::<Vec<_>>();
})
.map(|input| Tensor::of_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let (output, _, _) = no_grad(|| {
electra_model
.forward_t(Some(input_tensor),
None,
None,
None,
None,
false)
});
let (output, _, _) =
no_grad(|| electra_model.forward_t(Some(input_tensor), None, None, None, None, false));
// Print masked tokens
// Print masked tokens
let index_1 = output.get(0).get(4).argmax(0, false);
let index_2 = output.get(1).get(7).argmax(0, false);
let word_1 = tokenizer.vocab().id_to_token(&index_1.int64_value(&[]));
let word_2 = tokenizer.vocab().id_to_token(&index_2.int64_value(&[]));
println!("{}", word_1); // Outputs "thing" : "Looks like one [thing] is missing"
println!("{}", word_2);// Outputs "sunny" : "It was a very nice and [sunny] day"
println!("{}", word_2); // Outputs "sunny" : "It was a very nice and [sunny] day"
Ok(())
}

View File

@ -12,12 +12,10 @@
extern crate failure;
use rust_bert::pipelines::generation::{GPT2Generator, LanguageGenerator, GenerateConfig};
use rust_bert::pipelines::generation::{GPT2Generator, GenerateConfig, LanguageGenerator};
fn main() -> failure::Fallible<()> {
// Set-up masked LM model
// Set-up masked LM model
let generate_config = GenerateConfig {
max_length: 30,
do_sample: true,
@ -30,10 +28,10 @@ fn main() -> failure::Fallible<()> {
let input_context = "The dog";
let second_input_context = "The cat was";
let output = model.generate(Some(vec!(input_context, second_input_context)), None);
let output = model.generate(Some(vec![input_context, second_input_context]), None);
for sentence in output {
println!("{:?}", sentence);
}
Ok(())
}
}

View File

@ -12,65 +12,82 @@
extern crate failure;
use tch::{Device, nn, Tensor};
use rust_tokenizers::{TruncationStrategy, Tokenizer, Gpt2Tokenizer};
use rust_bert::gpt2::{Gpt2Config, GPT2LMHeadModel, Gpt2ConfigResources, Gpt2VocabResources, Gpt2MergesResources, Gpt2ModelResources};
use rust_bert::pipelines::generation::{LMHeadModel, Cache};
use rust_bert::resources::{Resource, download_resource, RemoteResource};
use rust_bert::gpt2::{
GPT2LMHeadModel, Gpt2Config, Gpt2ConfigResources, Gpt2MergesResources, Gpt2ModelResources,
Gpt2VocabResources,
};
use rust_bert::pipelines::generation::{Cache, LMHeadModel};
use rust_bert::resources::{download_resource, RemoteResource, Resource};
use rust_bert::Config;
use rust_tokenizers::{Gpt2Tokenizer, Tokenizer, TruncationStrategy};
use tch::{nn, Device, Tensor};
fn main() -> failure::Fallible<()> {
// Resources set-up
let config_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
let config_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
let vocab_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
let merges_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
let weights_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let merges_path = download_resource(&merges_resource)?;
let weights_path = download_resource(&weights_resource)?;
// Set-up masked LM model
// Set-up masked LM model
let device = Device::Cpu;
let mut vs = nn::VarStore::new(device);
let tokenizer: Gpt2Tokenizer = Gpt2Tokenizer::from_file(vocab_path.to_str().unwrap(), merges_path.to_str().unwrap(), false);
let tokenizer: Gpt2Tokenizer = Gpt2Tokenizer::from_file(
vocab_path.to_str().unwrap(),
merges_path.to_str().unwrap(),
false,
);
let config = Gpt2Config::from_file(config_path);
let gpt2_model = GPT2LMHeadModel::new(&vs.root(), &config);
vs.load(weights_path)?;
// Define input
// Define input
let input = ["One two three four five six seven eight nine ten eleven"];
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
let tokenized_input = tokenized_input.
iter().
map(|input| input.token_ids.clone()).
map(|mut input| {
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let tokenized_input = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
}).
map(|input|
Tensor::of_slice(&(input))).
collect::<Vec<_>>();
})
.map(|input| Tensor::of_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let (output, _, _, _, _) = gpt2_model.forward_t(
&Some(input_tensor),
Cache::None,
&None,
&None,
&None,
&None,
None,
&None,
false).unwrap();
// Forward pass
let (output, _, _, _, _) = gpt2_model
.forward_t(
&Some(input_tensor),
Cache::None,
&None,
&None,
&None,
&None,
None,
&None,
false,
)
.unwrap();
let next_word_id = output.get(0).get(-1).argmax(-1, true).int64_value(&[0]);
let next_word = tokenizer.decode(vec!(next_word_id), true, true);
let next_word = tokenizer.decode(vec![next_word_id], true, true);
println!("Provided input: {}", input[0]);
println!("Next word: {}", next_word);
Ok(())
}
}

View File

@ -15,20 +15,20 @@ extern crate failure;
use rust_bert::pipelines::ner::NERModel;
fn main() -> failure::Fallible<()> {
// Set-up model
// Set-up model
let ner_model = NERModel::new(Default::default())?;
// Define input
// Define input
let input = [
"My name is Amélie. I live in Москва.",
"Chongqing is a city in China."
"Chongqing is a city in China.",
];
// Run model
// Run model
let output = ner_model.predict(&input);
for entity in output {
println!("{:?}", entity);
}
Ok(())
}
}

View File

@ -12,66 +12,87 @@
extern crate failure;
use tch::{Device, nn, Tensor};
use rust_tokenizers::{TruncationStrategy, Tokenizer, OpenAiGptTokenizer};
use rust_bert::gpt2::Gpt2Config;
use rust_bert::openai_gpt::{OpenAIGPTLMHeadModel, OpenAiGptConfigResources, OpenAiGptVocabResources, OpenAiGptMergesResources, OpenAiGptModelResources};
use rust_bert::pipelines::generation::{LMHeadModel, Cache};
use rust_bert::resources::{Resource, download_resource, RemoteResource};
use rust_bert::openai_gpt::{
OpenAIGPTLMHeadModel, OpenAiGptConfigResources, OpenAiGptMergesResources,
OpenAiGptModelResources, OpenAiGptVocabResources,
};
use rust_bert::pipelines::generation::{Cache, LMHeadModel};
use rust_bert::resources::{download_resource, RemoteResource, Resource};
use rust_bert::Config;
use rust_tokenizers::{OpenAiGptTokenizer, Tokenizer, TruncationStrategy};
use tch::{nn, Device, Tensor};
fn main() -> failure::Fallible<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptConfigResources::GPT));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptVocabResources::GPT));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptMergesResources::GPT));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptModelResources::GPT));
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
OpenAiGptConfigResources::GPT,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
OpenAiGptVocabResources::GPT,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
OpenAiGptMergesResources::GPT,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
OpenAiGptModelResources::GPT,
));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let merges_path = download_resource(&merges_resource)?;
let weights_path = download_resource(&weights_resource)?;
// Set-up masked LM model
// Set-up masked LM model
let device = Device::Cpu;
let mut vs = nn::VarStore::new(device);
let tokenizer = OpenAiGptTokenizer::from_file(vocab_path.to_str().unwrap(), merges_path.to_str().unwrap(), true);
let tokenizer = OpenAiGptTokenizer::from_file(
vocab_path.to_str().unwrap(),
merges_path.to_str().unwrap(),
true,
);
let config = Gpt2Config::from_file(config_path);
let openai_gpt = OpenAIGPTLMHeadModel::new(&vs.root(), &config);
vs.load(weights_path)?;
// Define input
// Define input
let input = ["Wondering what the next word will"];
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
let tokenized_input = tokenized_input.
iter().
map(|input| input.token_ids.clone()).
map(|mut input| {
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let tokenized_input = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
}).
map(|input|
Tensor::of_slice(&(input))).
collect::<Vec<_>>();
})
.map(|input| Tensor::of_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let (output, _, _, _, _) = openai_gpt.forward_t(
&Some(input_tensor),
Cache::None,
&None,
&None,
&None,
&None,
None,
&None,
false).unwrap();
// Forward pass
let (output, _, _, _, _) = openai_gpt
.forward_t(
&Some(input_tensor),
Cache::None,
&None,
&None,
&None,
&None,
None,
&None,
false,
)
.unwrap();
let next_word_id = output.get(0).get(-1).argmax(-1, true).int64_value(&[0]);
let next_word = tokenizer.decode(vec!(next_word_id), true, true);
let next_word = tokenizer.decode(vec![next_word_id], true, true);
println!("Provided input: {}", input[0]);
println!("Next word: {}", next_word);
Ok(())
}
}

View File

@ -12,23 +12,28 @@
extern crate failure;
use rust_bert::pipelines::question_answering::{QuestionAnsweringModel, QaInput};
use rust_bert::pipelines::question_answering::{QaInput, QuestionAnsweringModel};
fn main() -> failure::Fallible<()> {
// Set-up Question Answering model
// Set-up Question Answering model
let qa_model = QuestionAnsweringModel::new(Default::default())?;
// Define input
// Define input
let question_1 = String::from("Where does Amy live ?");
let context_1 = String::from("Amy lives in Amsterdam");
let question_2 = String::from("Where does Eric live");
let context_2 = String::from("While Amy lives in Amsterdam, Eric is in The Hague.");
let qa_input_1 = QaInput { question: question_1, context: context_1 };
let qa_input_2 = QaInput { question: question_2, context: context_2 };
let qa_input_1 = QaInput {
question: question_1,
context: context_1,
};
let qa_input_2 = QaInput {
question: question_2,
context: context_2,
};
// Get answer
let answers = qa_model.predict(&vec!(qa_input_1, qa_input_2), 1, 32);
// Get answer
let answers = qa_model.predict(&vec![qa_input_1, qa_input_2], 1, 32);
println!("{:?}", answers);
Ok(())
}
}

View File

@ -12,77 +12,99 @@
extern crate failure;
use tch::{Device, nn, Tensor, no_grad};
use rust_tokenizers::{TruncationStrategy, Tokenizer, Vocab, RobertaTokenizer};
use rust_bert::Config;
use rust_bert::bert::BertConfig;
use rust_bert::roberta::{RobertaForMaskedLM, RobertaVocabResources, RobertaConfigResources, RobertaMergesResources, RobertaModelResources};
use rust_bert::resources::{Resource, download_resource, RemoteResource};
use rust_bert::resources::{download_resource, RemoteResource, Resource};
use rust_bert::roberta::{
RobertaConfigResources, RobertaForMaskedLM, RobertaMergesResources, RobertaModelResources,
RobertaVocabResources,
};
use rust_bert::Config;
use rust_tokenizers::{RobertaTokenizer, Tokenizer, TruncationStrategy, Vocab};
use tch::{nn, no_grad, Device, Tensor};
fn main() -> failure::Fallible<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaConfigResources::ROBERTA));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaVocabResources::ROBERTA));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaMergesResources::ROBERTA));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaModelResources::ROBERTA));
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaConfigResources::ROBERTA,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaVocabResources::ROBERTA,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaMergesResources::ROBERTA,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaModelResources::ROBERTA,
));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let merges_path = download_resource(&merges_resource)?;
let weights_path = download_resource(&weights_resource)?;
// Set-up masked LM model
// Set-up masked LM model
let device = Device::Cpu;
let mut vs = nn::VarStore::new(device);
let tokenizer: RobertaTokenizer = RobertaTokenizer::from_file(vocab_path.to_str().unwrap(), merges_path.to_str().unwrap(), true);
let tokenizer: RobertaTokenizer = RobertaTokenizer::from_file(
vocab_path.to_str().unwrap(),
merges_path.to_str().unwrap(),
true,
);
let config = BertConfig::from_file(config_path);
let bert_model = RobertaForMaskedLM::new(&vs.root(), &config);
vs.load(weights_path)?;
// Define input
let input = ["<pad> Looks like one thing is missing", "It\'s like comparing oranges to apples"];
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
let mut tokenized_input = tokenized_input.
iter().
map(|input| input.token_ids.clone()).
map(|mut input| {
// Define input
let input = [
"<pad> Looks like one thing is missing",
"It\'s like comparing oranges to apples",
];
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let mut tokenized_input = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
}).
collect::<Vec<_>>();
})
.collect::<Vec<_>>();
// Masking the token [thing] of sentence 1 and [oranges] of sentence 2
// Masking the token [thing] of sentence 1 and [oranges] of sentence 2
tokenized_input[0][4] = 103;
tokenized_input[1][5] = 103;
let tokenized_input = tokenized_input.
iter().
map(|input|
Tensor::of_slice(&(input))).
collect::<Vec<_>>();
let tokenized_input = tokenized_input
.iter()
.map(|input| Tensor::of_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
// Forward pass
let (output, _, _) = no_grad(|| {
bert_model
.forward_t(Some(input_tensor),
None,
None,
None,
None,
&None,
&None,
false)
bert_model.forward_t(
Some(input_tensor),
None,
None,
None,
None,
&None,
&None,
false,
)
});
// Print masked tokens
// Print masked tokens
let index_1 = output.get(0).get(4).argmax(0, false);
let index_2 = output.get(1).get(5).argmax(0, false);
let word_1 = tokenizer.vocab().id_to_token(&index_1.int64_value(&[]));
let word_2 = tokenizer.vocab().id_to_token(&index_2.int64_value(&[]));
println!("{}", word_1); // Outputs "some" : "Looks like [some] thing is missing"
println!("{}", word_2);// Outputs "apple" : "It\'s like comparing [apple] to apples"
println!("{}", word_2); // Outputs "apple" : "It\'s like comparing [apple] to apples"
Ok(())
}
}

View File

@ -14,23 +14,22 @@ extern crate failure;
use rust_bert::pipelines::sentiment::SentimentModel;
fn main() -> failure::Fallible<()> {
// Set-up classifier
// Set-up classifier
let sentiment_classifier = SentimentModel::new(Default::default())?;
// Define input
// Define input
let input = [
"Probably my all-time favorite movie, a story of selflessness, sacrifice and dedication to a noble cause, but it's not preachy or boring.",
"This film tried to be too many things all at once: stinging political satire, Hollywood blockbuster, sappy romantic comedy, family values promo...",
"If you like original gut wrenching laughter you will like this movie. If you are young or old then you will love this movie, hell even my mom liked it.",
];
// Run model
// Run model
let output = sentiment_classifier.predict(&input);
for sentiment in output {
println!("{:?}", sentiment);
}
Ok(())
}
}

View File

@ -15,21 +15,21 @@ extern crate failure;
use rust_bert::pipelines::sequence_classification::SequenceClassificationModel;
fn main() -> failure::Fallible<()> {
// Set-up model
// Set-up model
let sequence_classification_model = SequenceClassificationModel::new(Default::default())?;
// Define input
// Define input
let input = [
"Probably my all-time favorite movie, a story of selflessness, sacrifice and dedication to a noble cause, but it's not preachy or boring.",
"This film tried to be too many things all at once: stinging political satire, Hollywood blockbuster, sappy romantic comedy, family values promo...",
"If you like original gut wrenching laughter you will like this movie. If you are young or old then you will love this movie, hell even my mom liked it.",
];
// Run model
// Run model
let output = sequence_classification_model.predict(&input);
for label in output {
println!("{:?}", label);
}
Ok(())
}
}

View File

@ -15,21 +15,21 @@ extern crate failure;
use rust_bert::pipelines::sequence_classification::SequenceClassificationModel;
fn main() -> failure::Fallible<()> {
// Set-up model
// Set-up model
let sequence_classification_model = SequenceClassificationModel::new(Default::default())?;
// Define input
// Define input
let input = [
"Probably my all-time favorite movie, a story of selflessness, sacrifice and dedication to a noble cause, but it's not preachy or boring.",
"This is a neutral sentence.",
"If you like original gut wrenching laughter you will like this movie. If you are young or old then you will love this movie, hell even my mom liked it.",
];
// Run model
// Run model
let output = sequence_classification_model.predict_multilabel(&input, 0.05);
for label in output {
println!("{:?}", label);
}
Ok(())
}
}

View File

@ -12,24 +12,23 @@
extern crate failure;
use std::path::PathBuf;
use rust_bert::pipelines::question_answering::{squad_processor, QuestionAnsweringModel};
use std::env;
use rust_bert::pipelines::question_answering::{QuestionAnsweringModel, squad_processor};
use std::path::PathBuf;
fn main() -> failure::Fallible<()> {
// Set-up Question Answering model
// Set-up Question Answering model
let qa_model = QuestionAnsweringModel::new(Default::default())?;
// Define input
// Define input
let mut squad_path = PathBuf::from(env::var("squad_dataset")
.expect("Please set the \"squad_dataset\" environment variable pointing to the SQuAD dataset folder"));
squad_path.push("dev-v2.0.json");
let qa_inputs = squad_processor(squad_path);
// Get answer
// Get answer
let answers = qa_model.predict(&qa_inputs, 1, 64);
println!("Sample answer: {:?}", answers.first().unwrap());
println!("{}", answers.len());
Ok(())
}
}

View File

@ -10,35 +10,42 @@
// See the License for the specific language governing permissions and
// limitations under the License.
extern crate failure;
extern crate dirs;
extern crate failure;
use std::path::PathBuf;
use rust_bert::pipelines::sentiment::{SentimentModel, ss2_processor};
use rust_bert::pipelines::sentiment::{ss2_processor, SentimentModel};
use std::env;
use std::path::PathBuf;
fn main() -> failure::Fallible<()> {
// Set-up classifier
// Set-up classifier
let sentiment_classifier = SentimentModel::new(Default::default())?;
// Define input
// Define input
let mut sst2_path = PathBuf::from(env::var("SST2_PATH")
.expect("Please set the \"squad_dataset\" environment variable pointing to the SQuAD dataset folder"));
sst2_path.push("train.tsv");
let inputs = ss2_processor(sst2_path).unwrap();
// Run model
// Run model
let batch_size = 64;
let mut output = vec!();
let mut output = vec![];
for batch in inputs.chunks(batch_size) {
output.push(sentiment_classifier.predict(batch.iter().map(|v| v.as_str()).collect::<Vec<&str>>().as_slice()));
output.push(
sentiment_classifier.predict(
batch
.iter()
.map(|v| v.as_str())
.collect::<Vec<&str>>()
.as_slice(),
),
);
}
let mut flat_outputs = vec!();
let mut flat_outputs = vec![];
for batch_output in output.iter_mut() {
flat_outputs.append(batch_output);
}
println!("{:?}", flat_outputs.len());
Ok(())
}
}

View File

@ -14,7 +14,6 @@ extern crate failure;
use rust_bert::pipelines::summarization::SummarizationModel;
fn main() -> failure::Fallible<()> {
let summarization_model = SummarizationModel::new(Default::default())?;
@ -40,11 +39,11 @@ on K2-18b lasts 33 Earth days. According to The Guardian, astronomers were optim
telescope scheduled for launch in 2021 and the European Space Agency's 2028 ARIEL program, could reveal more \
about exoplanets like K2-18b."];
// Credits: WikiNews, CC BY 2.5 license (https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b)
// Credits: WikiNews, CC BY 2.5 license (https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b)
let _output = summarization_model.summarize(&input);
for sentence in _output {
println!("{:?}", sentence);
};
}
Ok(())
}
}

View File

@ -10,28 +10,36 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use rust_bert::pipelines::token_classification::{TokenClassificationModel, TokenClassificationConfig, LabelAggregationOption};
use rust_bert::resources::{Resource, RemoteResource};
use rust_bert::bert::{BertModelResources, BertVocabResources, BertConfigResources};
use rust_bert::bert::{BertConfigResources, BertModelResources, BertVocabResources};
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::token_classification::{
LabelAggregationOption, TokenClassificationConfig, TokenClassificationModel,
};
use rust_bert::resources::{RemoteResource, Resource};
fn main() -> failure::Fallible<()> {
// Load a configuration
let config = TokenClassificationConfig::new(ModelType::Bert,
Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT_NER)),
Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT_NER)),
Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT_NER)),
None, //merges resource only relevant with ModelType::Roberta
false, //lowercase
LabelAggregationOption::Mode,
// Load a configuration
let config = TokenClassificationConfig::new(
ModelType::Bert,
Resource::Remote(RemoteResource::from_pretrained(
BertModelResources::BERT_NER,
)),
Resource::Remote(RemoteResource::from_pretrained(
BertConfigResources::BERT_NER,
)),
Resource::Remote(RemoteResource::from_pretrained(
BertVocabResources::BERT_NER,
)),
None, //merges resource only relevant with ModelType::Roberta
false, //lowercase
LabelAggregationOption::Mode,
);
// Create the model
// Create the model
let token_classification_model = TokenClassificationModel::new(config)?;
let input = [
"My name is Amélie. I live in Москва.",
"Chongqing is a city in China."
"Chongqing is a city in China.",
];
let token_outputs = token_classification_model.predict(&input, true, false); //ignore_first_label = true (only returns the NER parts, ignoring first label O)
@ -40,4 +48,4 @@ fn main() -> failure::Fallible<()> {
}
Ok(())
}
}

View File

@ -13,12 +13,12 @@
extern crate failure;
use rust_bert::pipelines::translation::{TranslationConfig, TranslationModel, Language};
use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
use tch::Device;
fn main() -> failure::Fallible<()> {
let translation_config = TranslationConfig::new(Language::EnglishToGerman, Device::cuda_if_available());
let translation_config =
TranslationConfig::new(Language::EnglishToGerman, Device::cuda_if_available());
let model = TranslationModel::new(translation_config)?;
let input_context_1 = "The quick brown fox jumps over the lazy dog";
@ -30,4 +30,4 @@ fn main() -> failure::Fallible<()> {
println!("{}", sentence);
}
Ok(())
}
}

1
rustfmt.toml Normal file
View File

@ -0,0 +1 @@
format_code_in_doc_comments = true

View File

@ -11,16 +11,15 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::collections::HashMap;
use crate::Config;
use serde::{Deserialize, Serialize};
use crate::albert::embeddings::AlbertEmbeddings;
use crate::albert::encoder::AlbertTransformer;
use tch::{nn, Tensor, Kind};
use crate::common::activations::{_tanh, _gelu_new, _gelu, _relu, _mish};
use tch::nn::Module;
use crate::common::activations::{_gelu, _gelu_new, _mish, _relu, _tanh};
use crate::common::dropout::Dropout;
use crate::Config;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use tch::nn::Module;
use tch::{nn, Kind, Tensor};
/// # ALBERT Pretrained model weight files
pub struct AlbertModelResources;
@ -33,20 +32,28 @@ pub struct AlbertVocabResources;
impl AlbertModelResources {
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/ALBERT. Modified with conversion to C-array format.
pub const ALBERT_BASE_V2: (&'static str, &'static str) = ("albert-base-v2/model.ot", "https://cdn.huggingface.co/albert-base-v2/rust_model.ot");
pub const ALBERT_BASE_V2: (&'static str, &'static str) = (
"albert-base-v2/model.ot",
"https://cdn.huggingface.co/albert-base-v2/rust_model.ot",
);
}
impl AlbertConfigResources {
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/ALBERT. Modified with conversion to C-array format.
pub const ALBERT_BASE_V2: (&'static str, &'static str) = ("albert-base-v2/config.json", "https://cdn.huggingface.co/albert-base-v2-config.json");
pub const ALBERT_BASE_V2: (&'static str, &'static str) = (
"albert-base-v2/config.json",
"https://cdn.huggingface.co/albert-base-v2-config.json",
);
}
impl AlbertVocabResources {
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/ALBERT. Modified with conversion to C-array format.
pub const ALBERT_BASE_V2: (&'static str, &'static str) = ("albert-base-v2/spiece.model", "https://cdn.huggingface.co/albert-base-v2-spiece.model");
pub const ALBERT_BASE_V2: (&'static str, &'static str) = (
"albert-base-v2/spiece.model",
"https://cdn.huggingface.co/albert-base-v2-spiece.model",
);
}
#[allow(non_camel_case_types)]
#[derive(Clone, Debug, Serialize, Deserialize)]
/// # Activation function used in the attention layer and masked language model head
@ -61,7 +68,6 @@ pub enum Activation {
mish,
}
#[derive(Debug, Serialize, Deserialize)]
/// # ALBERT model configuration
/// Defines the ALBERT model architecture (e.g. number of layers, hidden layer size, label mapping...)
@ -123,10 +129,10 @@ impl AlbertModel {
/// # Example
///
/// ```no_run
/// use tch::{nn, Device};
/// use rust_bert::albert::{AlbertConfig, AlbertModel};
/// use rust_bert::Config;
/// use std::path::Path;
/// use rust_bert::albert::{AlbertConfig, AlbertModel};
/// use tch::{nn, Device};
///
/// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu;
@ -134,14 +140,23 @@ impl AlbertModel {
/// let config = AlbertConfig::from_file(config_path);
/// let albert: AlbertModel = AlbertModel::new(&(&p.root() / "albert"), &config);
/// ```
///
pub fn new(p: &nn::Path, config: &AlbertConfig) -> AlbertModel {
let embeddings = AlbertEmbeddings::new(&(p / "embeddings"), config);
let encoder = AlbertTransformer::new(&(p / "encoder"), config);
let pooler = nn::linear(&(p / "pooler"), config.hidden_size, config.hidden_size, Default::default());
let pooler = nn::linear(
&(p / "pooler"),
config.hidden_size,
config.hidden_size,
Default::default(),
);
let pooler_activation = Box::new(_tanh);
AlbertModel { embeddings, encoder, pooler, pooler_activation }
AlbertModel {
embeddings,
encoder,
pooler,
pooler_activation,
}
}
/// Forward pass through the model
@ -165,75 +180,103 @@ impl AlbertModel {
/// # Example
///
/// ```no_run
///# use tch::{nn, Device, Tensor, no_grad};
///# use rust_bert::Config;
///# use std::path::Path;
///# use tch::kind::Kind::Int64;
/// # use tch::{nn, Device, Tensor, no_grad};
/// # use rust_bert::Config;
/// # use std::path::Path;
/// # use tch::kind::Kind::Int64;
/// use rust_bert::albert::{AlbertConfig, AlbertModel};
///# let config_path = Path::new("path/to/config.json");
///# let device = Device::Cpu;
///# let vs = nn::VarStore::new(device);
///# let config = AlbertConfig::from_file(config_path);
///# let albert_model: AlbertModel = AlbertModel::new(&vs.root(), &config);
/// let (batch_size, sequence_length) = (64, 128);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true);
///
/// let (output, pooled_output, all_hidden_states, all_attentions) = no_grad(|| {
/// albert_model
/// .forward_t(Some(input_tensor),
/// Some(mask),
/// Some(token_type_ids),
/// Some(position_ids),
/// None,
/// false).unwrap()
/// });
/// # let config_path = Path::new("path/to/config.json");
/// # let device = Device::Cpu;
/// # let vs = nn::VarStore::new(device);
/// # let config = AlbertConfig::from_file(config_path);
/// # let albert_model: AlbertModel = AlbertModel::new(&vs.root(), &config);
/// let (batch_size, sequence_length) = (64, 128);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
/// .expand(&[batch_size, sequence_length], true);
///
/// let (output, pooled_output, all_hidden_states, all_attentions) = no_grad(|| {
/// albert_model
/// .forward_t(
/// Some(input_tensor),
/// Some(mask),
/// Some(token_type_ids),
/// Some(position_ids),
/// None,
/// false,
/// )
/// .unwrap()
/// });
/// ```
///
pub fn forward_t(&self,
input_ids: Option<Tensor>,
mask: Option<Tensor>,
token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool)
-> Result<(Tensor, Tensor, Option<Vec<Tensor>>, Option<Vec<Vec<Tensor>>>), &'static str> {
pub fn forward_t(
&self,
input_ids: Option<Tensor>,
mask: Option<Tensor>,
token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool,
) -> Result<
(
Tensor,
Tensor,
Option<Vec<Tensor>>,
Option<Vec<Vec<Tensor>>>,
),
&'static str,
> {
let (input_shape, device) = match &input_ids {
Some(input_value) => match &input_embeds {
Some(_) => { return Err("Only one of input ids or input embeddings may be set"); }
None => (input_value.size(), input_value.device())
}
Some(_) => {
return Err("Only one of input ids or input embeddings may be set");
}
None => (input_value.size(), input_value.device()),
},
None => match &input_embeds {
Some(embeds) => (vec!(embeds.size()[0], embeds.size()[1]), embeds.device()),
None => { return Err("At least one of input ids or input embeddings must be set"); }
}
Some(embeds) => (vec![embeds.size()[0], embeds.size()[1]], embeds.device()),
None => {
return Err("At least one of input ids or input embeddings must be set");
}
},
};
let mask = match mask {
Some(value) => value,
None => Tensor::ones(&input_shape, (Kind::Int64, device))
None => Tensor::ones(&input_shape, (Kind::Int64, device)),
};
let extended_attention_mask = mask.unsqueeze(1).unsqueeze(2);
let extended_attention_mask: Tensor = (extended_attention_mask.ones_like() - extended_attention_mask) * -10000.0;
let extended_attention_mask: Tensor =
(extended_attention_mask.ones_like() - extended_attention_mask) * -10000.0;
let embedding_output = match self.embeddings.forward_t(input_ids, token_type_ids, position_ids, input_embeds, train) {
let embedding_output = match self.embeddings.forward_t(
input_ids,
token_type_ids,
position_ids,
input_embeds,
train,
) {
Ok(value) => value,
Err(e) => { return Err(e); }
Err(e) => {
return Err(e);
}
};
let (hidden_state, all_hidden_states, all_attentions) =
self.encoder.forward_t(&embedding_output,
Some(extended_attention_mask),
train);
self.encoder
.forward_t(&embedding_output, Some(extended_attention_mask), train);
let pooled_output = self.pooler.forward(&hidden_state.select(1, 0));
let pooled_output = (self.pooler_activation)(&pooled_output);
Ok((hidden_state, pooled_output, all_hidden_states, all_attentions))
Ok((
hidden_state,
pooled_output,
all_hidden_states,
all_attentions,
))
}
}
@ -248,21 +291,43 @@ impl AlbertMLMHead {
pub fn new(p: &nn::Path, config: &AlbertConfig) -> AlbertMLMHead {
let layer_norm_eps = match config.layer_norm_eps {
Some(value) => value,
None => 1e-12
None => 1e-12,
};
let layer_norm_config = nn::LayerNormConfig { eps: layer_norm_eps, ..Default::default() };
let layer_norm = nn::layer_norm(&(p / "LayerNorm"), vec![config.embedding_size], layer_norm_config);
let dense = nn::linear(&(p / "dense"), config.hidden_size, config.embedding_size, Default::default());
let decoder = nn::linear(&(p / "decoder"), config.embedding_size, config.vocab_size, Default::default());
let layer_norm_config = nn::LayerNormConfig {
eps: layer_norm_eps,
..Default::default()
};
let layer_norm = nn::layer_norm(
&(p / "LayerNorm"),
vec![config.embedding_size],
layer_norm_config,
);
let dense = nn::linear(
&(p / "dense"),
config.hidden_size,
config.embedding_size,
Default::default(),
);
let decoder = nn::linear(
&(p / "decoder"),
config.embedding_size,
config.vocab_size,
Default::default(),
);
let activation = Box::new(match &config.hidden_act {
Activation::gelu_new => _gelu_new,
Activation::gelu => _gelu,
Activation::relu => _relu,
Activation::mish => _mish
Activation::mish => _mish,
});
AlbertMLMHead { layer_norm, dense, decoder, activation }
AlbertMLMHead {
layer_norm,
dense,
decoder,
activation,
}
}
pub fn forward(&self, hidden_states: &Tensor) -> Tensor {
@ -292,10 +357,10 @@ impl AlbertForMaskedLM {
/// # Example
///
/// ```no_run
/// use tch::{nn, Device};
/// use rust_bert::albert::{AlbertConfig, AlbertForMaskedLM};
/// use rust_bert::Config;
/// use std::path::Path;
/// use rust_bert::albert::{AlbertConfig, AlbertForMaskedLM};
/// use tch::{nn, Device};
///
/// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu;
@ -303,12 +368,14 @@ impl AlbertForMaskedLM {
/// let config = AlbertConfig::from_file(config_path);
/// let albert: AlbertForMaskedLM = AlbertForMaskedLM::new(&p.root(), &config);
/// ```
///
pub fn new(p: &nn::Path, config: &AlbertConfig) -> AlbertForMaskedLM {
let albert = AlbertModel::new(&(p / "albert"), config);
let predictions = AlbertMLMHead::new(&(p / "predictions"), config);
AlbertForMaskedLM { albert, predictions }
AlbertForMaskedLM {
albert,
predictions,
}
}
/// Forward pass through the model
@ -331,42 +398,54 @@ impl AlbertForMaskedLM {
/// # Example
///
/// ```no_run
///# use tch::{nn, Device, Tensor, no_grad};
///# use rust_bert::Config;
///# use std::path::Path;
///# use tch::kind::Kind::Int64;
/// # use tch::{nn, Device, Tensor, no_grad};
/// # use rust_bert::Config;
/// # use std::path::Path;
/// # use tch::kind::Kind::Int64;
/// use rust_bert::albert::{AlbertConfig, AlbertForMaskedLM};
///# let config_path = Path::new("path/to/config.json");
///# let device = Device::Cpu;
///# let vs = nn::VarStore::new(device);
///# let config = AlbertConfig::from_file(config_path);
///# let albert_model: AlbertForMaskedLM = AlbertForMaskedLM::new(&vs.root(), &config);
/// let (batch_size, sequence_length) = (64, 128);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true);
///
/// let (output, all_hidden_states, all_attentions) = no_grad(|| {
/// albert_model
/// .forward_t(Some(input_tensor),
/// Some(mask),
/// Some(token_type_ids),
/// Some(position_ids),
/// None,
/// false)
/// });
/// # let config_path = Path::new("path/to/config.json");
/// # let device = Device::Cpu;
/// # let vs = nn::VarStore::new(device);
/// # let config = AlbertConfig::from_file(config_path);
/// # let albert_model: AlbertForMaskedLM = AlbertForMaskedLM::new(&vs.root(), &config);
/// let (batch_size, sequence_length) = (64, 128);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
/// .expand(&[batch_size, sequence_length], true);
///
/// let (output, all_hidden_states, all_attentions) = no_grad(|| {
/// albert_model.forward_t(
/// Some(input_tensor),
/// Some(mask),
/// Some(token_type_ids),
/// Some(position_ids),
/// None,
/// false,
/// )
/// });
/// ```
///
pub fn forward_t(&self,
input_ids: Option<Tensor>,
mask: Option<Tensor>,
token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Vec<Tensor>>>) {
let (hidden_state, _, all_hidden_states, all_attentions) = self.albert.forward_t(input_ids, mask, token_type_ids, position_ids, input_embeds, train).unwrap();
pub fn forward_t(
&self,
input_ids: Option<Tensor>,
mask: Option<Tensor>,
token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool,
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Vec<Tensor>>>) {
let (hidden_state, _, all_hidden_states, all_attentions) = self
.albert
.forward_t(
input_ids,
mask,
token_type_ids,
position_ids,
input_embeds,
train,
)
.unwrap();
let prediction_scores = self.predictions.forward(&hidden_state);
(prediction_scores, all_hidden_states, all_attentions)
}
@ -395,29 +474,42 @@ impl AlbertForSequenceClassification {
/// # Example
///
/// ```no_run
/// use tch::{nn, Device};
/// use rust_bert::albert::{AlbertConfig, AlbertForSequenceClassification};
/// use rust_bert::Config;
/// use std::path::Path;
/// use rust_bert::albert::{AlbertConfig, AlbertForSequenceClassification};
/// use tch::{nn, Device};
///
/// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu;
/// let p = nn::VarStore::new(device);
/// let config = AlbertConfig::from_file(config_path);
/// let albert: AlbertForSequenceClassification = AlbertForSequenceClassification::new(&p.root(), &config);
/// let albert: AlbertForSequenceClassification =
/// AlbertForSequenceClassification::new(&p.root(), &config);
/// ```
///
pub fn new(p: &nn::Path, config: &AlbertConfig) -> AlbertForSequenceClassification {
let albert = AlbertModel::new(&(p / "albert"), config);
let classifier_dropout_prob = match config.classifier_dropout_prob {
Some(value) => value,
None => 0.1
None => 0.1,
};
let dropout = Dropout::new(classifier_dropout_prob);
let num_labels = config.id2label.as_ref().expect("num_labels not provided in configuration").len() as i64;
let classifier = nn::linear(&(p / "classifier"), config.hidden_size, num_labels, Default::default());
let num_labels = config
.id2label
.as_ref()
.expect("num_labels not provided in configuration")
.len() as i64;
let classifier = nn::linear(
&(p / "classifier"),
config.hidden_size,
num_labels,
Default::default(),
);
AlbertForSequenceClassification { albert, dropout, classifier }
AlbertForSequenceClassification {
albert,
dropout,
classifier,
}
}
/// Forward pass through the model
@ -440,16 +532,16 @@ impl AlbertForSequenceClassification {
/// # Example
///
/// ```no_run
///# use tch::{nn, Device, Tensor, no_grad};
///# use rust_bert::Config;
///# use std::path::Path;
///# use tch::kind::Kind::Int64;
/// # use tch::{nn, Device, Tensor, no_grad};
/// # use rust_bert::Config;
/// # use std::path::Path;
/// # use tch::kind::Kind::Int64;
/// use rust_bert::albert::{AlbertConfig, AlbertForSequenceClassification};
///# let config_path = Path::new("path/to/config.json");
///# let device = Device::Cpu;
///# let vs = nn::VarStore::new(device);
///# let config = AlbertConfig::from_file(config_path);
///# let albert_model: AlbertForSequenceClassification = AlbertForSequenceClassification::new(&vs.root(), &config);
/// # let config_path = Path::new("path/to/config.json");
/// # let device = Device::Cpu;
/// # let vs = nn::VarStore::new(device);
/// # let config = AlbertConfig::from_file(config_path);
/// # let albert_model: AlbertForSequenceClassification = AlbertForSequenceClassification::new(&vs.root(), &config);
/// let (batch_size, sequence_length) = (64, 128);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
@ -465,18 +557,30 @@ impl AlbertForSequenceClassification {
/// None,
/// false)
/// });
///
/// ```
///
pub fn forward_t(&self,
input_ids: Option<Tensor>,
mask: Option<Tensor>,
token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Vec<Tensor>>>) {
let (_, pooled_output, all_hidden_states, all_attentions) = self.albert.forward_t(input_ids, mask, token_type_ids, position_ids, input_embeds, train).unwrap();
let logits = pooled_output.apply_t(&self.dropout, train).apply(&self.classifier);
pub fn forward_t(
&self,
input_ids: Option<Tensor>,
mask: Option<Tensor>,
token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool,
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Vec<Tensor>>>) {
let (_, pooled_output, all_hidden_states, all_attentions) = self
.albert
.forward_t(
input_ids,
mask,
token_type_ids,
position_ids,
input_embeds,
train,
)
.unwrap();
let logits = pooled_output
.apply_t(&self.dropout, train)
.apply(&self.classifier);
(logits, all_hidden_states, all_attentions)
}
}
@ -505,25 +609,38 @@ impl AlbertForTokenClassification {
/// # Example
///
/// ```no_run
/// use tch::{nn, Device};
/// use rust_bert::albert::{AlbertConfig, AlbertForTokenClassification};
/// use rust_bert::Config;
/// use std::path::Path;
/// use rust_bert::albert::{AlbertConfig, AlbertForTokenClassification};
/// use tch::{nn, Device};
///
/// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu;
/// let p = nn::VarStore::new(device);
/// let config = AlbertConfig::from_file(config_path);
/// let albert: AlbertForTokenClassification = AlbertForTokenClassification::new(&p.root(), &config);
/// let albert: AlbertForTokenClassification =
/// AlbertForTokenClassification::new(&p.root(), &config);
/// ```
///
pub fn new(p: &nn::Path, config: &AlbertConfig) -> AlbertForTokenClassification {
let albert = AlbertModel::new(&(p / "albert"), config);
let dropout = Dropout::new(config.hidden_dropout_prob);
let num_labels = config.id2label.as_ref().expect("num_labels not provided in configuration").len() as i64;
let classifier = nn::linear(&(p / "classifier"), config.hidden_size, num_labels, Default::default());
let num_labels = config
.id2label
.as_ref()
.expect("num_labels not provided in configuration")
.len() as i64;
let classifier = nn::linear(
&(p / "classifier"),
config.hidden_size,
num_labels,
Default::default(),
);
AlbertForTokenClassification { albert, dropout, classifier }
AlbertForTokenClassification {
albert,
dropout,
classifier,
}
}
/// Forward pass through the model
@ -546,16 +663,16 @@ impl AlbertForTokenClassification {
/// # Example
///
/// ```no_run
///# use tch::{nn, Device, Tensor, no_grad};
///# use rust_bert::Config;
///# use std::path::Path;
///# use tch::kind::Kind::Int64;
/// # use tch::{nn, Device, Tensor, no_grad};
/// # use rust_bert::Config;
/// # use std::path::Path;
/// # use tch::kind::Kind::Int64;
/// use rust_bert::albert::{AlbertConfig, AlbertForTokenClassification};
///# let config_path = Path::new("path/to/config.json");
///# let device = Device::Cpu;
///# let vs = nn::VarStore::new(device);
///# let config = AlbertConfig::from_file(config_path);
///# let albert_model: AlbertForTokenClassification = AlbertForTokenClassification::new(&vs.root(), &config);
/// # let config_path = Path::new("path/to/config.json");
/// # let device = Device::Cpu;
/// # let vs = nn::VarStore::new(device);
/// # let config = AlbertConfig::from_file(config_path);
/// # let albert_model: AlbertForTokenClassification = AlbertForTokenClassification::new(&vs.root(), &config);
/// let (batch_size, sequence_length) = (64, 128);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
@ -571,18 +688,30 @@ impl AlbertForTokenClassification {
/// None,
/// false)
/// });
///
/// ```
///
pub fn forward_t(&self,
input_ids: Option<Tensor>,
mask: Option<Tensor>,
token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Vec<Tensor>>>) {
let (sequence_output, _, all_hidden_states, all_attentions) = self.albert.forward_t(input_ids, mask, token_type_ids, position_ids, input_embeds, train).unwrap();
let logits = sequence_output.apply_t(&self.dropout, train).apply(&self.classifier);
pub fn forward_t(
&self,
input_ids: Option<Tensor>,
mask: Option<Tensor>,
token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool,
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Vec<Tensor>>>) {
let (sequence_output, _, all_hidden_states, all_attentions) = self
.albert
.forward_t(
input_ids,
mask,
token_type_ids,
position_ids,
input_embeds,
train,
)
.unwrap();
let logits = sequence_output
.apply_t(&self.dropout, train)
.apply(&self.classifier);
(logits, all_hidden_states, all_attentions)
}
}
@ -610,10 +739,10 @@ impl AlbertForQuestionAnswering {
/// # Example
///
/// ```no_run
/// use tch::{nn, Device};
/// use rust_bert::albert::{AlbertConfig, AlbertForQuestionAnswering};
/// use rust_bert::Config;
/// use std::path::Path;
/// use rust_bert::albert::{AlbertConfig, AlbertForQuestionAnswering};
/// use tch::{nn, Device};
///
/// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu;
@ -621,11 +750,15 @@ impl AlbertForQuestionAnswering {
/// let config = AlbertConfig::from_file(config_path);
/// let albert: AlbertForQuestionAnswering = AlbertForQuestionAnswering::new(&p.root(), &config);
/// ```
///
pub fn new(p: &nn::Path, config: &AlbertConfig) -> AlbertForQuestionAnswering {
let albert = AlbertModel::new(&(p / "albert"), config);
let num_labels = 2;
let qa_outputs = nn::linear(&(p / "qa_outputs"), config.hidden_size, num_labels, Default::default());
let qa_outputs = nn::linear(
&(p / "qa_outputs"),
config.hidden_size,
num_labels,
Default::default(),
);
AlbertForQuestionAnswering { albert, qa_outputs }
}
@ -651,16 +784,16 @@ impl AlbertForQuestionAnswering {
/// # Example
///
/// ```no_run
///# use tch::{nn, Device, Tensor, no_grad};
///# use rust_bert::Config;
///# use std::path::Path;
///# use tch::kind::Kind::Int64;
/// # use tch::{nn, Device, Tensor, no_grad};
/// # use rust_bert::Config;
/// # use std::path::Path;
/// # use tch::kind::Kind::Int64;
/// use rust_bert::albert::{AlbertConfig, AlbertForQuestionAnswering};
///# let config_path = Path::new("path/to/config.json");
///# let device = Device::Cpu;
///# let vs = nn::VarStore::new(device);
///# let config = AlbertConfig::from_file(config_path);
///# let albert_model: AlbertForQuestionAnswering = AlbertForQuestionAnswering::new(&vs.root(), &config);
/// # let config_path = Path::new("path/to/config.json");
/// # let device = Device::Cpu;
/// # let vs = nn::VarStore::new(device);
/// # let config = AlbertConfig::from_file(config_path);
/// # let albert_model: AlbertForQuestionAnswering = AlbertForQuestionAnswering::new(&vs.root(), &config);
/// let (batch_size, sequence_length) = (64, 128);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
@ -676,17 +809,32 @@ impl AlbertForQuestionAnswering {
/// None,
/// false)
/// });
///
/// ```
///
pub fn forward_t(&self,
input_ids: Option<Tensor>,
mask: Option<Tensor>,
token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool) -> (Tensor, Tensor, Option<Vec<Tensor>>, Option<Vec<Vec<Tensor>>>) {
let (sequence_output, _, all_hidden_states, all_attentions) = self.albert.forward_t(input_ids, mask, token_type_ids, position_ids, input_embeds, train).unwrap();
pub fn forward_t(
&self,
input_ids: Option<Tensor>,
mask: Option<Tensor>,
token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool,
) -> (
Tensor,
Tensor,
Option<Vec<Tensor>>,
Option<Vec<Vec<Tensor>>>,
) {
let (sequence_output, _, all_hidden_states, all_attentions) = self
.albert
.forward_t(
input_ids,
mask,
token_type_ids,
position_ids,
input_embeds,
train,
)
.unwrap();
let logits = sequence_output.apply(&self.qa_outputs).split(1, -1);
let (start_logits, end_logits) = (&logits[0], &logits[1]);
let start_logits = start_logits.squeeze1(-1);
@ -721,10 +869,10 @@ impl AlbertForMultipleChoice {
/// # Example
///
/// ```no_run
/// use tch::{nn, Device};
/// use rust_bert::albert::{AlbertConfig, AlbertForMultipleChoice};
/// use rust_bert::Config;
/// use std::path::Path;
/// use rust_bert::albert::{AlbertConfig, AlbertForMultipleChoice};
/// use tch::{nn, Device};
///
/// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu;
@ -732,14 +880,22 @@ impl AlbertForMultipleChoice {
/// let config = AlbertConfig::from_file(config_path);
/// let albert: AlbertForMultipleChoice = AlbertForMultipleChoice::new(&p.root(), &config);
/// ```
///
pub fn new(p: &nn::Path, config: &AlbertConfig) -> AlbertForMultipleChoice {
let albert = AlbertModel::new(&(p / "albert"), config);
let dropout = Dropout::new(config.hidden_dropout_prob);
let num_labels = 1;
let classifier = nn::linear(&(p / "classifier"), config.hidden_size, num_labels, Default::default());
let classifier = nn::linear(
&(p / "classifier"),
config.hidden_size,
num_labels,
Default::default(),
);
AlbertForMultipleChoice { albert, dropout, classifier }
AlbertForMultipleChoice {
albert,
dropout,
classifier,
}
}
/// Forward pass through the model
@ -762,16 +918,16 @@ impl AlbertForMultipleChoice {
/// # Example
///
/// ```no_run
///# use tch::{nn, Device, Tensor, no_grad};
///# use rust_bert::Config;
///# use std::path::Path;
///# use tch::kind::Kind::Int64;
/// # use tch::{nn, Device, Tensor, no_grad};
/// # use rust_bert::Config;
/// # use std::path::Path;
/// # use tch::kind::Kind::Int64;
/// use rust_bert::albert::{AlbertConfig, AlbertForMultipleChoice};
///# let config_path = Path::new("path/to/config.json");
///# let device = Device::Cpu;
///# let vs = nn::VarStore::new(device);
///# let config = AlbertConfig::from_file(config_path);
///# let albert_model: AlbertForMultipleChoice = AlbertForMultipleChoice::new(&vs.root(), &config);
/// # let config_path = Path::new("path/to/config.json");
/// # let device = Device::Cpu;
/// # let vs = nn::VarStore::new(device);
/// # let config = AlbertConfig::from_file(config_path);
/// # let albert_model: AlbertForMultipleChoice = AlbertForMultipleChoice::new(&vs.root(), &config);
/// let (batch_size, sequence_length) = (64, 128);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
@ -787,44 +943,68 @@ impl AlbertForMultipleChoice {
/// None,
/// false).unwrap()
/// });
///
/// ```
///
pub fn forward_t(&self,
input_ids: Option<Tensor>,
mask: Option<Tensor>,
token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool) -> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Vec<Tensor>>>), &'static str> {
pub fn forward_t(
&self,
input_ids: Option<Tensor>,
mask: Option<Tensor>,
token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool,
) -> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Vec<Tensor>>>), &'static str> {
let (input_ids, input_embeds, num_choices) = match &input_ids {
Some(input_value) => match &input_embeds {
Some(_) => { return Err("Only one of input ids or input embeddings may be set"); }
None => (Some(input_value.view((-1, *input_value.size().last().unwrap()))), None, input_value.size()[1])
}
Some(_) => {
return Err("Only one of input ids or input embeddings may be set");
}
None => (
Some(input_value.view((-1, *input_value.size().last().unwrap()))),
None,
input_value.size()[1],
),
},
None => match &input_embeds {
Some(embeds) => (None, Some(embeds.view((-1, embeds.size()[1], embeds.size()[2]))), embeds.size()[1]),
None => { return Err("At least one of input ids or input embeddings must be set"); }
}
Some(embeds) => (
None,
Some(embeds.view((-1, embeds.size()[1], embeds.size()[2]))),
embeds.size()[1],
),
None => {
return Err("At least one of input ids or input embeddings must be set");
}
},
};
let mask = match mask {
Some(value) => Some(value.view((-1, *value.size().last().unwrap()))),
None => None
None => None,
};
let token_type_ids = match token_type_ids {
Some(value) => Some(value.view((-1, *value.size().last().unwrap()))),
None => None
None => None,
};
let position_ids = match position_ids {
Some(value) => Some(value.view((-1, *value.size().last().unwrap()))),
None => None
None => None,
};
let (_, pooled_output, all_hidden_states, all_attentions) = self.albert.forward_t(input_ids, mask, token_type_ids, position_ids, input_embeds, train).unwrap();
let logits = pooled_output.apply_t(&self.dropout, train).apply(&self.classifier).view((-1, num_choices));
let (_, pooled_output, all_hidden_states, all_attentions) = self
.albert
.forward_t(
input_ids,
mask,
token_type_ids,
position_ids,
input_embeds,
train,
)
.unwrap();
let logits = pooled_output
.apply_t(&self.dropout, train)
.apply(&self.classifier)
.view((-1, num_choices));
Ok((logits, all_hidden_states, all_attentions))
}
}
}

View File

@ -11,10 +11,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use crate::common::dropout::Dropout;
use tch::{nn, Tensor};
use crate::albert::AlbertConfig;
use crate::common::dropout::Dropout;
use tch::kind::Kind::Float;
use tch::{nn, Tensor};
#[derive(Debug)]
pub struct AlbertSelfAttention {
@ -32,24 +32,55 @@ pub struct AlbertSelfAttention {
impl AlbertSelfAttention {
pub fn new(p: nn::Path, config: &AlbertConfig) -> AlbertSelfAttention {
assert_eq!(config.hidden_size % config.num_attention_heads, 0, "Hidden size not a multiple of the number of attention heads");
assert_eq!(
config.hidden_size % config.num_attention_heads,
0,
"Hidden size not a multiple of the number of attention heads"
);
let query = nn::linear(&p / "query", config.hidden_size, config.hidden_size, Default::default());
let key = nn::linear(&p / "key", config.hidden_size, config.hidden_size, Default::default());
let value = nn::linear(&p / "value", config.hidden_size, config.hidden_size, Default::default());
let dense = nn::linear(&p / "dense", config.hidden_size, config.hidden_size, Default::default());
let query = nn::linear(
&p / "query",
config.hidden_size,
config.hidden_size,
Default::default(),
);
let key = nn::linear(
&p / "key",
config.hidden_size,
config.hidden_size,
Default::default(),
);
let value = nn::linear(
&p / "value",
config.hidden_size,
config.hidden_size,
Default::default(),
);
let dense = nn::linear(
&p / "dense",
config.hidden_size,
config.hidden_size,
Default::default(),
);
let dropout = Dropout::new(config.attention_probs_dropout_prob);
let attention_head_size = config.hidden_size / config.num_attention_heads;
let output_attentions = match config.output_attentions {
Some(value) => value,
None => false
None => false,
};
let layer_norm_eps = match config.layer_norm_eps {
Some(value) => value,
None => 1e-12
None => 1e-12,
};
let layer_norm_config = nn::LayerNormConfig { eps: layer_norm_eps, ..Default::default() };
let layer_norm = nn::layer_norm(&p / "LayerNorm", vec![config.hidden_size], layer_norm_config);
let layer_norm_config = nn::LayerNormConfig {
eps: layer_norm_eps,
..Default::default()
};
let layer_norm = nn::layer_norm(
&p / "LayerNorm",
vec![config.hidden_size],
layer_norm_config,
);
AlbertSelfAttention {
num_attention_heads: config.num_attention_heads,
@ -66,19 +97,23 @@ impl AlbertSelfAttention {
}
fn split_heads(&self, x: Tensor, bs: i64, dim_per_head: i64) -> Tensor {
x.view((bs, -1, self.num_attention_heads, dim_per_head)).transpose(1, 2)
x.view((bs, -1, self.num_attention_heads, dim_per_head))
.transpose(1, 2)
}
pub fn forward_t(&self,
input_ids: &Tensor,
mask: &Option<Tensor>,
train: bool) -> (Tensor, Option<Tensor>) {
pub fn forward_t(
&self,
input_ids: &Tensor,
mask: &Option<Tensor>,
train: bool,
) -> (Tensor, Option<Tensor>) {
let bs = *input_ids.size().first().unwrap();
let key_layer = self.split_heads(input_ids.apply(&self.key), bs, self.attention_head_size);
let value_layer = self.split_heads(input_ids.apply(&self.value), bs, self.attention_head_size);
let query_layer = self.split_heads(input_ids.apply(&self.query), bs, self.attention_head_size);
let value_layer =
self.split_heads(input_ids.apply(&self.value), bs, self.attention_head_size);
let query_layer =
self.split_heads(input_ids.apply(&self.query), bs, self.attention_head_size);
let query_layer: Tensor = query_layer / (self.attention_head_size as f64).sqrt();
@ -91,9 +126,11 @@ impl AlbertSelfAttention {
let weights = scores.softmax(-1, Float).apply_t(&self.dropout, train);
let context = weights.matmul(&value_layer).transpose(1, 2).contiguous();
let w = self.dense.ws
.transpose(0, 1)
.view((self.num_attention_heads, self.attention_head_size, self.hidden_size));
let w = self.dense.ws.transpose(0, 1).view((
self.num_attention_heads,
self.attention_head_size,
self.hidden_size,
));
let context: Tensor = Tensor::einsum("bfnd,ndh->bfh", &[context, w]) + &self.dense.bs;
let context = (input_ids + context.apply_t(&self.dropout, train)).apply(&self.layer_norm);
@ -104,4 +141,4 @@ impl AlbertSelfAttention {
(context, Some(weights))
}
}
}
}

View File

@ -11,10 +11,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use tch::{nn, Tensor, Kind};
use crate::common::dropout::Dropout;
use crate::albert::AlbertConfig;
use tch::nn::{EmbeddingConfig, embedding};
use crate::common::dropout::Dropout;
use tch::nn::{embedding, EmbeddingConfig};
use tch::{nn, Kind, Tensor};
/// # Embeddings implementation for Albert model
#[derive(Debug)]
@ -34,49 +34,77 @@ impl AlbertEmbeddings {
..Default::default()
};
let word_embeddings: nn::Embedding = embedding(p / "word_embeddings",
config.vocab_size,
config.embedding_size,
embedding_config);
let word_embeddings: nn::Embedding = embedding(
p / "word_embeddings",
config.vocab_size,
config.embedding_size,
embedding_config,
);
let position_embeddings: nn::Embedding = embedding(p / "position_embeddings",
config.max_position_embeddings,
config.embedding_size,
Default::default());
let position_embeddings: nn::Embedding = embedding(
p / "position_embeddings",
config.max_position_embeddings,
config.embedding_size,
Default::default(),
);
let token_type_embeddings: nn::Embedding = embedding(p / "token_type_embeddings",
config.type_vocab_size,
config.embedding_size,
Default::default());
let token_type_embeddings: nn::Embedding = embedding(
p / "token_type_embeddings",
config.type_vocab_size,
config.embedding_size,
Default::default(),
);
let layer_norm_eps = match config.layer_norm_eps {
Some(value) => value,
None => 1e-12
None => 1e-12,
};
let layer_norm_config = nn::LayerNormConfig { eps: layer_norm_eps, ..Default::default() };
let layer_norm: nn::LayerNorm = nn::layer_norm(p / "LayerNorm", vec![config.embedding_size], layer_norm_config);
let layer_norm_config = nn::LayerNormConfig {
eps: layer_norm_eps,
..Default::default()
};
let layer_norm: nn::LayerNorm = nn::layer_norm(
p / "LayerNorm",
vec![config.embedding_size],
layer_norm_config,
);
let dropout: Dropout = Dropout::new(config.hidden_dropout_prob);
AlbertEmbeddings { word_embeddings, position_embeddings, token_type_embeddings, layer_norm, dropout}
AlbertEmbeddings {
word_embeddings,
position_embeddings,
token_type_embeddings,
layer_norm,
dropout,
}
}
pub fn forward_t(&self,
input_ids: Option<Tensor>,
token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool) -> Result<Tensor, &'static str> {
pub fn forward_t(
&self,
input_ids: Option<Tensor>,
token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool,
) -> Result<Tensor, &'static str> {
let (input_embeddings, input_shape) = match input_ids {
Some(input_value) => match input_embeds {
Some(_) => { return Err("Only one of input ids or input embeddings may be set"); }
None => (input_value.apply_t(&self.word_embeddings, train), input_value.size())
}
Some(_) => {
return Err("Only one of input ids or input embeddings may be set");
}
None => (
input_value.apply_t(&self.word_embeddings, train),
input_value.size(),
),
},
None => match input_embeds {
Some(embeds) => {
let size = vec!(embeds.size()[0], embeds.size()[1]);
let size = vec![embeds.size()[0], embeds.size()[1]];
(embeds, size)
},
None => { return Err("Only one of input ids or input embeddings may be set"); }
}
}
None => {
return Err("Only one of input ids or input embeddings may be set");
}
},
};
let seq_length = input_embeddings.as_ref().size()[1].to_owned();
@ -84,19 +112,22 @@ impl AlbertEmbeddings {
let position_ids = match position_ids {
Some(value) => value,
None => Tensor::arange(seq_length, (Kind::Int64, input_embeddings.device()))
.unsqueeze(0).
expand(&input_shape, true)
.unsqueeze(0)
.expand(&input_shape, true),
};
let token_type_ids = match token_type_ids {
Some(value) => value,
None => Tensor::zeros(&input_shape, (Kind::Int64, input_embeddings.device()))
None => Tensor::zeros(&input_shape, (Kind::Int64, input_embeddings.device())),
};
let position_embeddings = position_ids.apply(&self.position_embeddings);
let token_type_embeddings = token_type_ids.apply(&self.token_type_embeddings);
let input_embeddings: Tensor = input_embeddings + position_embeddings + token_type_embeddings;
Ok(input_embeddings.apply(&self.layer_norm).apply_t(&self.dropout, train))
let input_embeddings: Tensor =
input_embeddings + position_embeddings + token_type_embeddings;
Ok(input_embeddings
.apply(&self.layer_norm)
.apply_t(&self.dropout, train))
}
}
}

View File

@ -11,12 +11,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use crate::albert::attention::AlbertSelfAttention;
use tch::{nn, Tensor};
use crate::albert::AlbertConfig;
use crate::albert::albert::Activation;
use crate::common::activations::{_gelu_new, _gelu, _relu, _mish};
use crate::albert::attention::AlbertSelfAttention;
use crate::albert::AlbertConfig;
use crate::common::activations::{_gelu, _gelu_new, _mish, _relu};
use std::borrow::BorrowMut;
use tch::{nn, Tensor};
pub struct AlbertLayer {
attention: AlbertSelfAttention,
@ -32,29 +32,55 @@ impl AlbertLayer {
let layer_norm_eps = match config.layer_norm_eps {
Some(value) => value,
None => 1e-12
None => 1e-12,
};
let layer_norm_config = nn::LayerNormConfig { eps: layer_norm_eps, ..Default::default() };
let full_layer_layer_norm = nn::layer_norm(&(p / "full_layer_layer_norm"), vec![config.hidden_size], layer_norm_config);
let layer_norm_config = nn::LayerNormConfig {
eps: layer_norm_eps,
..Default::default()
};
let full_layer_layer_norm = nn::layer_norm(
&(p / "full_layer_layer_norm"),
vec![config.hidden_size],
layer_norm_config,
);
let ffn = nn::linear(&(p / "ffn"), config.hidden_size, config.intermediate_size, Default::default());
let ffn_output = nn::linear(&(p / "ffn_output"), config.intermediate_size, config.hidden_size, Default::default());
let ffn = nn::linear(
&(p / "ffn"),
config.hidden_size,
config.intermediate_size,
Default::default(),
);
let ffn_output = nn::linear(
&(p / "ffn_output"),
config.intermediate_size,
config.hidden_size,
Default::default(),
);
let activation = Box::new(match &config.hidden_act {
Activation::gelu_new => _gelu_new,
Activation::gelu => _gelu,
Activation::relu => _relu,
Activation::mish => _mish
Activation::mish => _mish,
});
AlbertLayer { attention, full_layer_layer_norm, ffn, ffn_output, activation }
AlbertLayer {
attention,
full_layer_layer_norm,
ffn,
ffn_output,
activation,
}
}
pub fn forward_t(&self,
hidden_states: &Tensor,
mask: &Option<Tensor>,
train: bool) -> (Tensor, Option<Tensor>) {
let (attention_output, attention_weights) = self.attention.forward_t(hidden_states, mask, train);
pub fn forward_t(
&self,
hidden_states: &Tensor,
mask: &Option<Tensor>,
train: bool,
) -> (Tensor, Option<Tensor>) {
let (attention_output, attention_weights) =
self.attention.forward_t(hidden_states, mask, train);
let ffn_output = attention_output.apply(&self.ffn);
let ffn_output: Tensor = (self.activation)(&ffn_output);
let ffn_output = ffn_output.apply(&self.ffn_output);
@ -76,29 +102,42 @@ impl AlbertLayerGroup {
let output_attentions = match config.output_attentions {
Some(value) => value,
None => false
None => false,
};
let output_hidden_states = match config.output_hidden_states {
Some(value) => value,
None => false
None => false,
};
let mut layers: Vec<AlbertLayer> = vec!();
let mut layers: Vec<AlbertLayer> = vec![];
for layer_index in 0..config.inner_group_num {
layers.push(AlbertLayer::new(&(p / layer_index), config));
};
}
AlbertLayerGroup { output_hidden_states, output_attentions, layers }
AlbertLayerGroup {
output_hidden_states,
output_attentions,
layers,
}
}
pub fn forward_t(&self,
hidden_states: &Tensor,
mask: &Option<Tensor>,
train: bool)
-> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states { Some(vec!()) } else { None };
let mut all_attentions: Option<Vec<Tensor>> = if self.output_attentions { Some(vec!()) } else { None };
pub fn forward_t(
&self,
hidden_states: &Tensor,
mask: &Option<Tensor>,
train: bool,
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states {
Some(vec![])
} else {
None
};
let mut all_attentions: Option<Vec<Tensor>> = if self.output_attentions {
Some(vec![])
} else {
None
};
let mut hidden_state = hidden_states.copy();
let mut attention_weights: Option<Tensor>;
@ -117,9 +156,9 @@ impl AlbertLayerGroup {
attentions.push(attention_weights.as_ref().unwrap().copy());
};
}
None => break
None => break,
};
};
}
(hidden_state, all_hidden_states, all_attentions)
}
@ -140,20 +179,25 @@ impl AlbertTransformer {
let output_attentions = match config.output_attentions {
Some(value) => value,
None => false
None => false,
};
let output_hidden_states = match config.output_hidden_states {
Some(value) => value,
None => false
None => false,
};
let embedding_hidden_mapping_in = nn::linear(&(p / "embedding_hidden_mapping_in"), config.embedding_size, config.hidden_size, Default::default());
let embedding_hidden_mapping_in = nn::linear(
&(p / "embedding_hidden_mapping_in"),
config.embedding_size,
config.hidden_size,
Default::default(),
);
let mut layers: Vec<AlbertLayerGroup> = vec!();
let mut layers: Vec<AlbertLayerGroup> = vec![];
for layer_index in 0..config.inner_group_num {
layers.push(AlbertLayerGroup::new(&(p_layers / layer_index), config));
};
}
AlbertTransformer {
output_hidden_states,
@ -165,16 +209,24 @@ impl AlbertTransformer {
}
}
pub fn forward_t(&self,
hidden_states: &Tensor,
mask: Option<Tensor>,
train: bool)
-> (Tensor, Option<Vec<Tensor>>, Option<Vec<Vec<Tensor>>>) {
pub fn forward_t(
&self,
hidden_states: &Tensor,
mask: Option<Tensor>,
train: bool,
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Vec<Tensor>>>) {
let mut hidden_state = hidden_states.apply(&self.embedding_hidden_mapping_in);
let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states { Some(vec!()) } else { None };
let mut all_attentions: Option<Vec<Vec<Tensor>>> = if self.output_attentions { Some(vec!()) } else { None };
let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states {
Some(vec![])
} else {
None
};
let mut all_attentions: Option<Vec<Vec<Tensor>>> = if self.output_attentions {
Some(vec![])
} else {
None
};
for i in 0..self.num_hidden_layers {
let group_idx = i / (self.num_hidden_layers / self.num_hidden_groups);
@ -190,9 +242,8 @@ impl AlbertTransformer {
if let Some(attentions) = all_attentions.borrow_mut() {
attentions.push(attention_weights.unwrap());
};
};
}
(hidden_state, all_hidden_states, all_attentions)
}
}

View File

@ -20,37 +20,46 @@
//! Pretrained models are available and can be downloaded using RemoteResources.
//!
//! ```no_run
//!# fn main() -> failure::Fallible<()> {
//!#
//! # fn main() -> failure::Fallible<()> {
//! #
//! use rust_tokenizers::AlbertTokenizer;
//! use tch::{nn, Device};
//!# use std::path::PathBuf;
//! use rust_bert::albert::{AlbertForMaskedLM, AlbertConfig};
//! # use std::path::PathBuf;
//! use rust_bert::albert::{AlbertConfig, AlbertForMaskedLM};
//! use rust_bert::resources::{download_resource, LocalResource, Resource};
//! use rust_bert::Config;
//! use rust_bert::resources::{Resource, download_resource, LocalResource};
//!
//! let config_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/config.json")});
//! let vocab_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/vocab.txt")});
//! let weights_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/model.ot")});
//! let config_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/config.json"),
//! });
//! let vocab_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/vocab.txt"),
//! });
//! let weights_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/model.ot"),
//! });
//! let config_path = download_resource(&config_resource)?;
//! let vocab_path = download_resource(&vocab_resource)?;
//! let weights_path = download_resource(&weights_resource)?;
//! let device = Device::cuda_if_available();
//! let mut vs = nn::VarStore::new(device);
//! let tokenizer: AlbertTokenizer = AlbertTokenizer::from_file(vocab_path.to_str().unwrap(), true, true);
//! let tokenizer: AlbertTokenizer =
//! AlbertTokenizer::from_file(vocab_path.to_str().unwrap(), true, true);
//! let config = AlbertConfig::from_file(config_path);
//! let bert_model = AlbertForMaskedLM::new(&vs.root(), &config);
//! vs.load(weights_path)?;
//!
//!# Ok(())
//!# }
//! # Ok(())
//! # }
//! ```
mod encoder;
mod albert;
mod attention;
mod embeddings;
mod albert;
mod encoder;
pub use albert::{AlbertConfig, AlbertModelResources, AlbertConfigResources, AlbertVocabResources, AlbertModel, AlbertForMaskedLM, AlbertForSequenceClassification, AlbertForTokenClassification, AlbertForQuestionAnswering, AlbertForMultipleChoice};
pub use albert::{
AlbertConfig, AlbertConfigResources, AlbertForMaskedLM, AlbertForMultipleChoice,
AlbertForQuestionAnswering, AlbertForSequenceClassification, AlbertForTokenClassification,
AlbertModel, AlbertModelResources, AlbertVocabResources,
};

View File

@ -12,8 +12,8 @@
// limitations under the License.
use crate::common::dropout::Dropout;
use tch::{nn, Tensor};
use tch::kind::Kind::Float;
use tch::{nn, Tensor};
#[derive(Debug)]
/// # Cache for BART attention layers
@ -31,7 +31,7 @@ impl Clone for LayerState {
fn clone(&self) -> Self {
let prev_key_padding_mask = match &self.prev_key_padding_mask {
Some(key_padding_mask) => Some(key_padding_mask.copy()),
None => None
None => None,
};
LayerState {
prev_key: self.prev_key.copy(),
@ -46,12 +46,16 @@ impl LayerState {
self.prev_key = self.prev_key.index_select(0, new_indices);
self.prev_value = self.prev_value.index_select(0, new_indices);
if self.prev_key_padding_mask.is_some() {
self.prev_key_padding_mask = Some(self.prev_key_padding_mask.as_ref().unwrap().index_select(0, new_indices));
self.prev_key_padding_mask = Some(
self.prev_key_padding_mask
.as_ref()
.unwrap()
.index_select(0, new_indices),
);
}
}
}
#[derive(Debug)]
pub struct SelfAttention {
num_heads: i64,
@ -68,8 +72,15 @@ pub struct SelfAttention {
}
impl SelfAttention {
pub fn new(p: nn::Path, embed_dim: i64, num_heads: i64, dropout: f64,
encoder_decoder_attention: bool, store_cache: bool, output_attentions: bool) -> SelfAttention {
pub fn new(
p: nn::Path,
embed_dim: i64,
num_heads: i64,
dropout: f64,
encoder_decoder_attention: bool,
store_cache: bool,
output_attentions: bool,
) -> SelfAttention {
let k_proj = nn::linear(&p / "k_proj", embed_dim, embed_dim, Default::default());
let v_proj = nn::linear(&p / "v_proj", embed_dim, embed_dim, Default::default());
let q_proj = nn::linear(&p / "q_proj", embed_dim, embed_dim, Default::default());
@ -95,58 +106,90 @@ impl SelfAttention {
}
fn flatten(&self, x: Tensor, dim_0: i64, bs: i64) -> Tensor {
x.contiguous().view((dim_0, bs * self.num_heads, self.head_dim)).transpose(0, 1)
x.contiguous()
.view((dim_0, bs * self.num_heads, self.head_dim))
.transpose(0, 1)
}
pub fn forward_t(&self, query: &Tensor,
key: Option<&Tensor>,
key_padding_mask: Option<&Tensor>,
attention_mask: Option<&Tensor>,
mut layer_state: Option<LayerState>,
train: bool) -> (Tensor, Option<Tensor>, Option<LayerState>) {
pub fn forward_t(
&self,
query: &Tensor,
key: Option<&Tensor>,
key_padding_mask: Option<&Tensor>,
attention_mask: Option<&Tensor>,
mut layer_state: Option<LayerState>,
train: bool,
) -> (Tensor, Option<Tensor>, Option<LayerState>) {
let query_size = query.size();
let (target_sequence_length, bs) = (query_size[0], query_size[1]);
let q: Tensor = self.flatten(query.as_ref().apply(&self.q_proj) * self.scaling, target_sequence_length, bs);
let q: Tensor = self.flatten(
query.as_ref().apply(&self.q_proj) * self.scaling,
target_sequence_length,
bs,
);
let key = match &layer_state {
Some(_) => { if self.encoder_decoder_attention { None } else { key } }
None => key
Some(_) => {
if self.encoder_decoder_attention {
None
} else {
key
}
}
None => key,
};
let (k, v) = if self.encoder_decoder_attention {
match key {
Some(key) => {
(Some(self.flatten(key.apply(&self.k_proj), -1, bs)),
Some(self.flatten(key.apply(&self.v_proj), -1, bs))
)
}
None => (None, None)
Some(key) => (
Some(self.flatten(key.apply(&self.k_proj), -1, bs)),
Some(self.flatten(key.apply(&self.v_proj), -1, bs)),
),
None => (None, None),
}
} else {
(Some(self.flatten(query.apply(&self.k_proj), -1, bs)),
Some(self.flatten(query.apply(&self.v_proj), -1, bs))
(
Some(self.flatten(query.apply(&self.k_proj), -1, bs)),
Some(self.flatten(query.apply(&self.v_proj), -1, bs)),
)
};
let (k, v, key_padding_mask) = self.use_saved_state(&layer_state, k, v, key_padding_mask, bs);
let (k, v, key_padding_mask) =
self.use_saved_state(&layer_state, k, v, key_padding_mask, bs);
let source_sequence_length = k.size()[1];
let attention_weights = q.bmm(&k.transpose(1, 2));
let attention_weights = match attention_mask {
Some(mask) => {
let attention_weights = attention_weights.view((bs, self.num_heads, target_sequence_length, source_sequence_length)) + mask;
attention_weights.view((bs * self.num_heads, target_sequence_length, source_sequence_length))
let attention_weights = attention_weights.view((
bs,
self.num_heads,
target_sequence_length,
source_sequence_length,
)) + mask;
attention_weights.view((
bs * self.num_heads,
target_sequence_length,
source_sequence_length,
))
}
None => attention_weights
None => attention_weights,
};
let attention_weights = match key_padding_mask.as_ref() {
Some(mask) => {
attention_weights
.view((bs, self.num_heads, target_sequence_length, source_sequence_length))
.masked_fill(&mask.unsqueeze(1).unsqueeze(2), std::f64::NEG_INFINITY)
.view((bs * self.num_heads, target_sequence_length, source_sequence_length))
}
None => attention_weights
Some(mask) => attention_weights
.view((
bs,
self.num_heads,
target_sequence_length,
source_sequence_length,
))
.masked_fill(&mask.unsqueeze(1).unsqueeze(2), std::f64::NEG_INFINITY)
.view((
bs * self.num_heads,
target_sequence_length,
source_sequence_length,
)),
None => attention_weights,
};
let attention_weights = attention_weights.softmax(-1, Float);
@ -159,16 +202,25 @@ impl SelfAttention {
.apply(&self.out_proj);
let attention_weights = if self.output_attentions {
Some(attention_weights.view((bs, self.num_heads, target_sequence_length, source_sequence_length)))
} else { None };
Some(attention_weights.view((
bs,
self.num_heads,
target_sequence_length,
source_sequence_length,
)))
} else {
None
};
if self.store_cache {
if layer_state.is_some() {
layer_state.as_mut().unwrap().prev_key = k.view((bs, self.num_heads, -1, self.head_dim));
layer_state.as_mut().unwrap().prev_value = v.view((bs, self.num_heads, -1, self.head_dim));
layer_state.as_mut().unwrap().prev_key =
k.view((bs, self.num_heads, -1, self.head_dim));
layer_state.as_mut().unwrap().prev_value =
v.view((bs, self.num_heads, -1, self.head_dim));
layer_state.as_mut().unwrap().prev_key_padding_mask = match key_padding_mask {
Some(tensor) => Some(tensor),
None => None
None => None,
};
} else {
layer_state = Some(LayerState {
@ -176,7 +228,7 @@ impl SelfAttention {
prev_value: v.view((bs, self.num_heads, -1, self.head_dim)),
prev_key_padding_mask: match key_padding_mask {
Some(tensor) => Some(tensor),
None => None
None => None,
},
})
};
@ -185,17 +237,23 @@ impl SelfAttention {
(output, attention_weights, layer_state)
}
fn use_saved_state(&self,
layer_state: &Option<LayerState>,
k: Option<Tensor>,
v: Option<Tensor>,
key_padding_mask: Option<&Tensor>,
bs: i64)
-> (Tensor, Tensor, Option<Tensor>) {
fn use_saved_state(
&self,
layer_state: &Option<LayerState>,
k: Option<Tensor>,
v: Option<Tensor>,
key_padding_mask: Option<&Tensor>,
bs: i64,
) -> (Tensor, Tensor, Option<Tensor>) {
match &layer_state {
Some(prev_state) => {
let prev_key = prev_state.prev_key.view((bs * self.num_heads, -1, self.head_dim));
let prev_value = prev_state.prev_value.view((bs * self.num_heads, -1, self.head_dim));
let prev_key = prev_state
.prev_key
.view((bs * self.num_heads, -1, self.head_dim));
let prev_value =
prev_state
.prev_value
.view((bs * self.num_heads, -1, self.head_dim));
let k = if self.encoder_decoder_attention {
prev_key
} else {
@ -207,39 +265,54 @@ impl SelfAttention {
Tensor::cat(&[prev_value, v.unwrap()], 1)
};
let key_padding_mask = self.use_saved_key_padding_mask(key_padding_mask,
&prev_state.prev_key_padding_mask,
bs,
k.size()[1]);
let key_padding_mask = self.use_saved_key_padding_mask(
key_padding_mask,
&prev_state.prev_key_padding_mask,
bs,
k.size()[1],
);
(k, v, key_padding_mask)
}
None => {
let key_padding_mask = match key_padding_mask {
Some(value) => Some(value.copy()),
None => None
None => None,
};
(k.unwrap(), v.unwrap(), key_padding_mask)
}
}
}
fn use_saved_key_padding_mask(&self, key_padding_mask: Option<&Tensor>, prev_key_padding_mask: &Option<Tensor>,
bs: i64, sequence_length: i64) -> Option<Tensor> {
fn use_saved_key_padding_mask(
&self,
key_padding_mask: Option<&Tensor>,
prev_key_padding_mask: &Option<Tensor>,
bs: i64,
sequence_length: i64,
) -> Option<Tensor> {
if prev_key_padding_mask.is_some() {
if self.encoder_decoder_attention {
Some(prev_key_padding_mask.as_ref().unwrap().copy())
} else {
Some(Tensor::cat(&[prev_key_padding_mask.as_ref().unwrap(), key_padding_mask.as_ref().unwrap()], 1))
Some(Tensor::cat(
&[
prev_key_padding_mask.as_ref().unwrap(),
key_padding_mask.as_ref().unwrap(),
],
1,
))
}
} else {
match key_padding_mask {
Some(key_padding_mask) => {
let filler = Tensor::zeros(&[bs, sequence_length - key_padding_mask.size()[1]],
(key_padding_mask.kind(), key_padding_mask.device()));
let filler = Tensor::zeros(
&[bs, sequence_length - key_padding_mask.size()[1]],
(key_padding_mask.kind(), key_padding_mask.device()),
);
Some(Tensor::cat(&[filler, key_padding_mask.copy()], 1))
}
None => None
None => None,
}
}
}
}
}

View File

@ -11,18 +11,18 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use crate::Config;
use tch::{Tensor, nn};
use tch::kind::Kind::{Int64, Float};
use crate::bart::encoder::BartEncoder;
use crate::bart::decoder::BartDecoder;
use tch::nn::{embedding, EmbeddingConfig};
use crate::bart::attention::LayerState;
use std::borrow::BorrowMut;
use crate::bart::decoder::BartDecoder;
use crate::bart::encoder::BartEncoder;
use crate::common::dropout::Dropout;
use crate::pipelines::generation::{Cache, LMHeadModel};
use crate::Config;
use serde::{Deserialize, Serialize};
use std::borrow::BorrowMut;
use std::collections::HashMap;
use tch::kind::Kind::{Float, Int64};
use tch::nn::{embedding, EmbeddingConfig};
use tch::{nn, Tensor};
/// # BART Pretrained model weight files
pub struct BartModelResources;
@ -38,38 +38,74 @@ pub struct BartMergesResources;
impl BartModelResources {
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
pub const BART: (&'static str, &'static str) = ("bart/model.ot", "https://cdn.huggingface.co/facebook/bart-large/rust_model.ot");
pub const BART: (&'static str, &'static str) = (
"bart/model.ot",
"https://cdn.huggingface.co/facebook/bart-large/rust_model.ot",
);
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
pub const BART_CNN: (&'static str, &'static str) = ("bart-cnn/model.ot", "https://cdn.huggingface.co/facebook/bart-large-cnn/rust_model.ot");
pub const BART_CNN: (&'static str, &'static str) = (
"bart-cnn/model.ot",
"https://cdn.huggingface.co/facebook/bart-large-cnn/rust_model.ot",
);
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
pub const BART_XSUM: (&'static str, &'static str) = ("bart-xsum/model.ot", "https://cdn.huggingface.co/facebook/bart-large-xsum/rust_model.ot");
pub const BART_XSUM: (&'static str, &'static str) = (
"bart-xsum/model.ot",
"https://cdn.huggingface.co/facebook/bart-large-xsum/rust_model.ot",
);
}
impl BartConfigResources {
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
pub const BART: (&'static str, &'static str) = ("bart/config.json", "https://cdn.huggingface.co/facebook/bart-large/config.json");
pub const BART: (&'static str, &'static str) = (
"bart/config.json",
"https://cdn.huggingface.co/facebook/bart-large/config.json",
);
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
pub const BART_CNN: (&'static str, &'static str) = ("bart-cnn/config.json", "https://cdn.huggingface.co/facebook/bart-large-cnn/config.json");
pub const BART_CNN: (&'static str, &'static str) = (
"bart-cnn/config.json",
"https://cdn.huggingface.co/facebook/bart-large-cnn/config.json",
);
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
pub const BART_XSUM: (&'static str, &'static str) = ("bart-xsum/config.json", "https://cdn.huggingface.co/facebook/bart-large-xsum/config.json");
pub const BART_XSUM: (&'static str, &'static str) = (
"bart-xsum/config.json",
"https://cdn.huggingface.co/facebook/bart-large-xsum/config.json",
);
}
impl BartVocabResources {
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
pub const BART: (&'static str, &'static str) = ("bart/vocab.txt", "https://cdn.huggingface.co/roberta-large-vocab.json");
pub const BART: (&'static str, &'static str) = (
"bart/vocab.txt",
"https://cdn.huggingface.co/roberta-large-vocab.json",
);
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
pub const BART_CNN: (&'static str, &'static str) = ("bart-cnn/vocab.txt", "https://cdn.huggingface.co/roberta-large-vocab.json");
pub const BART_CNN: (&'static str, &'static str) = (
"bart-cnn/vocab.txt",
"https://cdn.huggingface.co/roberta-large-vocab.json",
);
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
pub const BART_XSUM: (&'static str, &'static str) = ("bart-xsum/vocab.txt", "https://cdn.huggingface.co/roberta-large-vocab.json");
pub const BART_XSUM: (&'static str, &'static str) = (
"bart-xsum/vocab.txt",
"https://cdn.huggingface.co/roberta-large-vocab.json",
);
}
impl BartMergesResources {
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
pub const BART: (&'static str, &'static str) = ("bart/merges.txt", "https://cdn.huggingface.co/roberta-large-merges.txt");
pub const BART: (&'static str, &'static str) = (
"bart/merges.txt",
"https://cdn.huggingface.co/roberta-large-merges.txt",
);
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
pub const BART_CNN: (&'static str, &'static str) = ("bart-cnn/merges.txt", "https://cdn.huggingface.co/roberta-large-merges.txt");
pub const BART_CNN: (&'static str, &'static str) = (
"bart-cnn/merges.txt",
"https://cdn.huggingface.co/roberta-large-merges.txt",
);
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
pub const BART_XSUM: (&'static str, &'static str) = ("bart-xsum/merges.txt", "https://cdn.huggingface.co/roberta-large-merges.txt");
pub const BART_XSUM: (&'static str, &'static str) = (
"bart-xsum/merges.txt",
"https://cdn.huggingface.co/roberta-large-merges.txt",
);
}
#[allow(non_camel_case_types)]
@ -130,14 +166,15 @@ pub struct BartConfig {
impl Config<BartConfig> for BartConfig {}
fn _prepare_bart_decoder_inputs(pad_token_id: i64,
input_ids: &Tensor,
decoder_input_ids: Option<&Tensor>,
decoder_padding_mask: Option<&Tensor>)
-> (Tensor, Option<Tensor>, Option<Tensor>) {
fn _prepare_bart_decoder_inputs(
pad_token_id: i64,
input_ids: &Tensor,
decoder_input_ids: Option<&Tensor>,
decoder_padding_mask: Option<&Tensor>,
) -> (Tensor, Option<Tensor>, Option<Tensor>) {
let decoder_input_ids = match decoder_input_ids {
Some(value) => value.copy(),
None => _shift_tokens_right(input_ids, pad_token_id)
None => _shift_tokens_right(input_ids, pad_token_id),
};
let decoder_padding_mask = match decoder_padding_mask {
@ -159,7 +196,6 @@ fn _prepare_bart_decoder_inputs(pad_token_id: i64,
(decoder_input_ids, decoder_padding_mask, Some(causal_mask))
}
fn _shift_tokens_right(input_ids: &Tensor, pad_token_id: i64) -> Tensor {
let index_eos: Tensor = input_ids.ne(pad_token_id).sum1(&[-1], true, Int64) - 1;
let output = input_ids.empty_like().to_kind(Int64);
@ -200,10 +236,10 @@ impl BartModel {
/// # Example
///
/// ```no_run
/// use tch::{nn, Device};
/// use rust_bert::bart::{BartConfig, BartModel};
/// use rust_bert::Config;
/// use std::path::Path;
/// use rust_bert::bart::{BartConfig, BartModel};
/// use tch::{nn, Device};
///
/// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu;
@ -212,22 +248,32 @@ impl BartModel {
/// let generation_mode = true;
/// let bart: BartModel = BartModel::new(&(&p.root() / "bart"), &config, generation_mode);
/// ```
///
pub fn new(p: &nn::Path, config: &BartConfig, generation_mode: bool) -> BartModel {
let pad_token_id = match config.pad_token_id {
Some(value) => value,
None => 1
None => 1,
};
let embedding_config = EmbeddingConfig { padding_idx: pad_token_id, ..Default::default() };
let embeddings: nn::Embedding = embedding(p / "shared",
config.vocab_size,
config.d_model,
embedding_config);
let embedding_config = EmbeddingConfig {
padding_idx: pad_token_id,
..Default::default()
};
let embeddings: nn::Embedding = embedding(
p / "shared",
config.vocab_size,
config.d_model,
embedding_config,
);
let encoder = BartEncoder::new(p / "encoder", config);
let decoder = BartDecoder::new(p / "decoder", config, generation_mode);
BartModel { encoder, decoder, generation_mode, pad_token_id, embeddings }
BartModel {
encoder,
decoder,
generation_mode,
pad_token_id,
embeddings,
}
}
/// Forward pass through the model
@ -257,82 +303,116 @@ impl BartModel {
/// # Example
///
/// ```no_run
///# use tch::{nn, Device, Tensor, no_grad};
///# use rust_bert::Config;
///# use std::path::Path;
///# use tch::kind::Kind::{Int64, Double};
/// # use tch::{nn, Device, Tensor, no_grad};
/// # use rust_bert::Config;
/// # use std::path::Path;
/// # use tch::kind::Kind::{Int64, Double};
/// use rust_bert::bart::{BartConfig, BartModel};
///# let config_path = Path::new("path/to/config.json");
///# let vocab_path = Path::new("path/to/vocab.txt");
///# let device = Device::Cpu;
///# let vs = nn::VarStore::new(device);
///# let config = BartConfig::from_file(config_path);
///# let bart_model: BartModel = BartModel::new(&vs.root(), &config, false);
/// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56);
/// let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device));
/// let target_tensor = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device));
/// let encoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
/// let decoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
///
/// let (decoder_output, encoder_hidden_states, decoder_cache,
/// all_encoder_hidden_states, all_encoder_attentions,
/// all_decoder_hidden_states, all_decoder_attentions) = no_grad(|| {
/// bart_model
/// .forward_t(Some(&input_tensor),
/// Some(&encoder_attention_mask),
/// Some(&target_tensor),
/// None,
/// Some(&decoder_attention_mask),
/// None,
/// false)
/// });
/// # let config_path = Path::new("path/to/config.json");
/// # let vocab_path = Path::new("path/to/vocab.txt");
/// # let device = Device::Cpu;
/// # let vs = nn::VarStore::new(device);
/// # let config = BartConfig::from_file(config_path);
/// # let bart_model: BartModel = BartModel::new(&vs.root(), &config, false);
/// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56);
/// let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device));
/// let target_tensor = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device));
/// let encoder_attention_mask =
/// Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
/// let decoder_attention_mask =
/// Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
///
/// let (
/// decoder_output,
/// encoder_hidden_states,
/// decoder_cache,
/// all_encoder_hidden_states,
/// all_encoder_attentions,
/// all_decoder_hidden_states,
/// all_decoder_attentions,
/// ) = no_grad(|| {
/// bart_model.forward_t(
/// Some(&input_tensor),
/// Some(&encoder_attention_mask),
/// Some(&target_tensor),
/// None,
/// Some(&decoder_attention_mask),
/// None,
/// false,
/// )
/// });
/// ```
///
pub fn forward_t(&self,
input_ids: Option<&Tensor>,
attention_mask: Option<&Tensor>,
decoder_input_ids: Option<&Tensor>,
encoder_outputs: Option<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>)>,
decoder_attention_mask: Option<&Tensor>,
layer_states: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
train: bool) ->
(Tensor, Tensor, Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
Option<Vec<Tensor>>, Option<Vec<Tensor>>,
Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
pub fn forward_t(
&self,
input_ids: Option<&Tensor>,
attention_mask: Option<&Tensor>,
decoder_input_ids: Option<&Tensor>,
encoder_outputs: Option<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>)>,
decoder_attention_mask: Option<&Tensor>,
layer_states: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
train: bool,
) -> (
Tensor,
Tensor,
Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
) {
let (decoder_input_ids, decoder_padding_mask, causal_mask) = if self.generation_mode {
(decoder_input_ids.unwrap().copy(), None, None)
} else {
assert!(input_ids.is_some(), "input_ids must be provided when not in generation mode");
_prepare_bart_decoder_inputs(self.pad_token_id, input_ids.unwrap(), decoder_input_ids, decoder_attention_mask)
assert!(
input_ids.is_some(),
"input_ids must be provided when not in generation mode"
);
_prepare_bart_decoder_inputs(
self.pad_token_id,
input_ids.unwrap(),
decoder_input_ids,
decoder_attention_mask,
)
};
let (encoder_hidden_states,
all_encoder_hidden_states,
all_encoder_attentions) = match encoder_outputs {
Some(value) => value,
None => {
assert!(input_ids.is_some(), "input_ids must be provided when encoder output is not pre-computed");
self.encoder.forward_t(input_ids.unwrap(), attention_mask, &self.embeddings, train)
}
};
let (encoder_hidden_states, all_encoder_hidden_states, all_encoder_attentions) =
match encoder_outputs {
Some(value) => value,
None => {
assert!(
input_ids.is_some(),
"input_ids must be provided when encoder output is not pre-computed"
);
self.encoder.forward_t(
input_ids.unwrap(),
attention_mask,
&self.embeddings,
train,
)
}
};
let (decoder_outputs,
decoder_cache,
let (decoder_outputs, decoder_cache, all_decoder_hidden_states, all_decoder_attentions) =
self.decoder.forward_t(
&decoder_input_ids,
&encoder_hidden_states,
attention_mask,
decoder_padding_mask.as_ref(),
causal_mask.as_ref(),
&self.embeddings,
layer_states,
train,
);
(
decoder_outputs,
encoder_hidden_states,
decoder_cache.1,
all_decoder_hidden_states,
all_decoder_attentions) = self.decoder.forward_t(&decoder_input_ids,
&encoder_hidden_states,
attention_mask,
decoder_padding_mask.as_ref(),
causal_mask.as_ref(),
&self.embeddings,
layer_states,
train);
(decoder_outputs, encoder_hidden_states, decoder_cache.1,
all_decoder_hidden_states, all_decoder_attentions,
all_encoder_hidden_states, all_encoder_attentions)
all_decoder_attentions,
all_encoder_hidden_states,
all_encoder_attentions,
)
}
}
/// # BART Model for conditional generation
@ -356,20 +436,24 @@ impl BartForConditionalGeneration {
/// # Example
///
/// ```no_run
/// use tch::{nn, Device};
/// use rust_bert::bart::{BartConfig, BartForConditionalGeneration};
/// use rust_bert::Config;
/// use std::path::Path;
/// use rust_bert::bart::{BartConfig, BartForConditionalGeneration};
/// use tch::{nn, Device};
///
/// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu;
/// let p = nn::VarStore::new(device);
/// let config = BartConfig::from_file(config_path);
/// let generation_mode = true;
/// let bart: BartForConditionalGeneration = BartForConditionalGeneration::new(&(&p.root() / "bart"), &config, generation_mode);
/// let bart: BartForConditionalGeneration =
/// BartForConditionalGeneration::new(&(&p.root() / "bart"), &config, generation_mode);
/// ```
///
pub fn new(p: &nn::Path, config: &BartConfig, generation_mode: bool) -> BartForConditionalGeneration {
pub fn new(
p: &nn::Path,
config: &BartConfig,
generation_mode: bool,
) -> BartForConditionalGeneration {
let base_model = BartModel::new(&(p / "model"), config, generation_mode);
BartForConditionalGeneration { base_model }
}
@ -398,17 +482,17 @@ impl BartForConditionalGeneration {
/// # Example
///
/// ```no_run
///# use tch::{nn, Device, Tensor, no_grad};
///# use rust_bert::Config;
///# use std::path::Path;
///# use tch::kind::Kind::{Int64, Double};
/// # use tch::{nn, Device, Tensor, no_grad};
/// # use rust_bert::Config;
/// # use std::path::Path;
/// # use tch::kind::Kind::{Int64, Double};
/// use rust_bert::bart::{BartConfig, BartForConditionalGeneration};
///# let config_path = Path::new("path/to/config.json");
///# let vocab_path = Path::new("path/to/vocab.txt");
///# let device = Device::Cpu;
///# let vs = nn::VarStore::new(device);
///# let config = BartConfig::from_file(config_path);
///# let bart_model: BartForConditionalGeneration = BartForConditionalGeneration::new(&vs.root(), &config, false);
/// # let config_path = Path::new("path/to/config.json");
/// # let vocab_path = Path::new("path/to/vocab.txt");
/// # let device = Device::Cpu;
/// # let vs = nn::VarStore::new(device);
/// # let config = BartConfig::from_file(config_path);
/// # let bart_model: BartForConditionalGeneration = BartForConditionalGeneration::new(&vs.root(), &config, false);
/// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56);
/// let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device));
/// let target_tensor = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device));
@ -427,37 +511,64 @@ impl BartForConditionalGeneration {
/// None,
/// false)
/// });
///
/// ```
///
pub fn forward_t(&self,
input_ids: Option<&Tensor>,
attention_mask: Option<&Tensor>,
encoder_outputs: Option<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>)>,
decoder_input_ids: Option<&Tensor>,
decoder_attention_mask: Option<&Tensor>,
old_layer_states: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
train: bool)
-> (Tensor, Tensor, Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
Option<Vec<Tensor>>, Option<Vec<Tensor>>,
Option<Vec<Tensor>>, Option<Vec<Tensor>>)
{
let (decoder_outputs, encoder_hidden_states, decoder_cache,
all_decoder_hidden_states, all_decoder_attentions,
all_encoder_hidden_states, all_encoder_attentions) =
self.base_model.forward_t(input_ids, attention_mask, decoder_input_ids, encoder_outputs, decoder_attention_mask, old_layer_states, train);
pub fn forward_t(
&self,
input_ids: Option<&Tensor>,
attention_mask: Option<&Tensor>,
encoder_outputs: Option<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>)>,
decoder_input_ids: Option<&Tensor>,
decoder_attention_mask: Option<&Tensor>,
old_layer_states: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
train: bool,
) -> (
Tensor,
Tensor,
Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
) {
let (
decoder_outputs,
encoder_hidden_states,
decoder_cache,
all_decoder_hidden_states,
all_decoder_attentions,
all_encoder_hidden_states,
all_encoder_attentions,
) = self.base_model.forward_t(
input_ids,
attention_mask,
decoder_input_ids,
encoder_outputs,
decoder_attention_mask,
old_layer_states,
train,
);
let lm_logits = decoder_outputs.linear::<Tensor>(&self.base_model.embeddings.ws, None);
(lm_logits, encoder_hidden_states, decoder_cache,
all_decoder_hidden_states, all_decoder_attentions,
all_encoder_hidden_states, all_encoder_attentions)
(
lm_logits,
encoder_hidden_states,
decoder_cache,
all_decoder_hidden_states,
all_decoder_attentions,
all_encoder_hidden_states,
all_encoder_attentions,
)
}
pub fn encode(&self, input_ids: &Tensor, attention_mask: Option<&Tensor>) -> Tensor {
let (encoder_hidden_states, _, _) = self.base_model.encoder.forward_t(input_ids, attention_mask, &self.base_model.embeddings, false);
let (encoder_hidden_states, _, _) = self.base_model.encoder.forward_t(
input_ids,
attention_mask,
&self.base_model.embeddings,
false,
);
encoder_hidden_states
}
}
pub struct BartClassificationHead {
@ -468,16 +579,29 @@ pub struct BartClassificationHead {
impl BartClassificationHead {
pub fn new(p: &nn::Path, config: &BartConfig) -> BartClassificationHead {
let dense = nn::linear(&(p / "dense"), config.d_model, config.d_model, Default::default());
let dense = nn::linear(
&(p / "dense"),
config.d_model,
config.d_model,
Default::default(),
);
let dropout = Dropout::new(config.classif_dropout);
let out_proj = nn::linear(&(p / "out_proj"), config.d_model, config.num_labels.unwrap(), Default::default());
let out_proj = nn::linear(
&(p / "out_proj"),
config.d_model,
config.num_labels.unwrap(),
Default::default(),
);
BartClassificationHead { dense, dropout, out_proj }
BartClassificationHead {
dense,
dropout,
out_proj,
}
}
pub fn forward_t(&self, x: &Tensor, train: bool) -> Tensor {
x
.apply_t(&self.dropout, train)
x.apply_t(&self.dropout, train)
.apply(&self.dense)
.tanh()
.apply_t(&self.dropout, train)
@ -497,7 +621,6 @@ pub struct BartForSequenceClassification {
eos_token_id: i64,
}
impl BartForSequenceClassification {
/// Build a new `BartForSequenceClassification`
///
@ -509,27 +632,31 @@ impl BartForSequenceClassification {
/// # Example
///
/// ```no_run
/// use tch::{nn, Device};
/// use rust_bert::bart::{BartConfig, BartForSequenceClassification};
/// use rust_bert::Config;
/// use std::path::Path;
/// use rust_bert::bart::{BartConfig, BartForSequenceClassification};
/// use tch::{nn, Device};
///
/// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu;
/// let p = nn::VarStore::new(device);
/// let config = BartConfig::from_file(config_path);
/// let generation_mode = true;
/// let bart: BartForSequenceClassification = BartForSequenceClassification::new(&(&p.root() / "bart"), &config);
/// let bart: BartForSequenceClassification =
/// BartForSequenceClassification::new(&(&p.root() / "bart"), &config);
/// ```
///
pub fn new(p: &nn::Path, config: &BartConfig) -> BartForSequenceClassification {
let base_model = BartModel::new(&(p / "model"), config, false);
let classification_head = BartClassificationHead::new(&(p / "classification_head"), config);
let eos_token_id = match config.eos_token_id {
Some(value) => value,
None => 3
None => 3,
};
BartForSequenceClassification { base_model, classification_head, eos_token_id }
BartForSequenceClassification {
base_model,
classification_head,
eos_token_id,
}
}
/// Forward pass through the model
@ -556,17 +683,17 @@ impl BartForSequenceClassification {
/// # Example
///
/// ```no_run
///# use tch::{nn, Device, Tensor, no_grad};
///# use rust_bert::Config;
///# use std::path::Path;
///# use tch::kind::Kind::{Int64, Double};
/// # use tch::{nn, Device, Tensor, no_grad};
/// # use rust_bert::Config;
/// # use std::path::Path;
/// # use tch::kind::Kind::{Int64, Double};
/// use rust_bert::bart::{BartConfig, BartForConditionalGeneration};
///# let config_path = Path::new("path/to/config.json");
///# let vocab_path = Path::new("path/to/vocab.txt");
///# let device = Device::Cpu;
///# let vs = nn::VarStore::new(device);
///# let config = BartConfig::from_file(config_path);
///# let bart_model: BartForConditionalGeneration = BartForConditionalGeneration::new(&vs.root(), &config, false);
/// # let config_path = Path::new("path/to/config.json");
/// # let vocab_path = Path::new("path/to/vocab.txt");
/// # let device = Device::Cpu;
/// # let vs = nn::VarStore::new(device);
/// # let config = BartConfig::from_file(config_path);
/// # let bart_model: BartForConditionalGeneration = BartForConditionalGeneration::new(&vs.root(), &config, false);
/// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56);
/// let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device));
/// let target_tensor = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device));
@ -585,36 +712,63 @@ impl BartForSequenceClassification {
/// None,
/// false)
/// });
///
/// ```
///
pub fn forward_t(&mut self,
input_ids: &Tensor,
attention_mask: Option<&Tensor>,
encoder_outputs: Option<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>)>,
decoder_input_ids: Option<&Tensor>,
decoder_attention_mask: Option<&Tensor>,
train: bool)
-> (Tensor, Tensor,
Option<Vec<Tensor>>, Option<Vec<Tensor>>,
Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
let (decoder_outputs, encoder_hidden_states, _,
all_decoder_hidden_states, all_decoder_attentions,
all_encoder_hidden_states, all_encoder_attentions) =
self.borrow_mut().base_model.forward_t(Some(input_ids), attention_mask, decoder_input_ids, encoder_outputs, decoder_attention_mask, None, train);
pub fn forward_t(
&mut self,
input_ids: &Tensor,
attention_mask: Option<&Tensor>,
encoder_outputs: Option<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>)>,
decoder_input_ids: Option<&Tensor>,
decoder_attention_mask: Option<&Tensor>,
train: bool,
) -> (
Tensor,
Tensor,
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
) {
let (
decoder_outputs,
encoder_hidden_states,
_,
all_decoder_hidden_states,
all_decoder_attentions,
all_encoder_hidden_states,
all_encoder_attentions,
) = self.borrow_mut().base_model.forward_t(
Some(input_ids),
attention_mask,
decoder_input_ids,
encoder_outputs,
decoder_attention_mask,
None,
train,
);
let eos_mask = input_ids.eq(self.eos_token_id);
let sentence_representation = decoder_outputs
.index_select(0, &eos_mask)
.view((decoder_outputs.size()[0], -1, *decoder_outputs.size().last().unwrap()))
.view((
decoder_outputs.size()[0],
-1,
*decoder_outputs.size().last().unwrap(),
))
.select(1, -1);
let logits = self.classification_head.forward_t(&sentence_representation, train);
(logits, encoder_hidden_states,
all_decoder_hidden_states, all_decoder_attentions,
all_encoder_hidden_states, all_encoder_attentions)
let logits = self
.classification_head
.forward_t(&sentence_representation, train);
(
logits,
encoder_hidden_states,
all_decoder_hidden_states,
all_decoder_attentions,
all_encoder_hidden_states,
all_encoder_attentions,
)
}
}
impl LMHeadModel for BartForConditionalGeneration {
@ -645,18 +799,18 @@ impl LMHeadModel for BartForConditionalGeneration {
/// # Example
///
/// ```no_run
///# use tch::{nn, Device, Tensor, no_grad};
///# use rust_bert::Config;
///# use std::path::Path;
///# use tch::kind::Kind::{Int64, Double};
/// # use tch::{nn, Device, Tensor, no_grad};
/// # use rust_bert::Config;
/// # use std::path::Path;
/// # use tch::kind::Kind::{Int64, Double};
/// use rust_bert::pipelines::generation::LMHeadModel;
/// use rust_bert::bart::{BartForConditionalGeneration, BartConfig};
///# let config_path = Path::new("path/to/config.json");
///# let vocab_path = Path::new("path/to/vocab.txt");
///# let device = Device::Cpu;
///# let vs = nn::VarStore::new(device);
///# let config = BartConfig::from_file(config_path);
///# let bart_model: BartForConditionalGeneration = BartForConditionalGeneration::new(&vs.root(), &config, false);
/// # let config_path = Path::new("path/to/config.json");
/// # let vocab_path = Path::new("path/to/vocab.txt");
/// # let device = Device::Cpu;
/// # let vs = nn::VarStore::new(device);
/// # let config = BartConfig::from_file(config_path);
/// # let bart_model: BartForConditionalGeneration = BartForConditionalGeneration::new(&vs.root(), &config, false);
/// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56);
/// let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device));
/// let target_tensor = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device));
@ -675,39 +829,58 @@ impl LMHeadModel for BartForConditionalGeneration {
/// None,
/// false)
/// });
///
/// ```
///
fn forward_t(&self,
input_ids: &Option<Tensor>,
cache: Cache,
attention_mask: &Option<Tensor>,
_token_type_ids: &Option<Tensor>,
_position_ids: &Option<Tensor>,
_input_embeds: &Option<Tensor>,
encoder_outputs: Option<&Tensor>,
decoder_input_ids: &Option<Tensor>,
train: bool) -> Result<(Tensor, Option<Tensor>, Cache, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
fn forward_t(
&self,
input_ids: &Option<Tensor>,
cache: Cache,
attention_mask: &Option<Tensor>,
_token_type_ids: &Option<Tensor>,
_position_ids: &Option<Tensor>,
_input_embeds: &Option<Tensor>,
encoder_outputs: Option<&Tensor>,
decoder_input_ids: &Option<Tensor>,
train: bool,
) -> Result<
(
Tensor,
Option<Tensor>,
Cache,
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
),
&'static str,
> {
let (decoder_output, encoder_hidden_states, new_cache, _, _, _, _) = match cache {
Cache::BARTCache(cached_layer_states) => self.base_model.forward_t(input_ids.as_ref(),
attention_mask.as_ref(),
decoder_input_ids.as_ref(),
Some((encoder_outputs.as_ref().unwrap().copy(), None, None)),
None,
cached_layer_states,
train),
Cache::BARTCache(cached_layer_states) => self.base_model.forward_t(
input_ids.as_ref(),
attention_mask.as_ref(),
decoder_input_ids.as_ref(),
Some((encoder_outputs.as_ref().unwrap().copy(), None, None)),
None,
cached_layer_states,
train,
),
Cache::None => self.base_model.forward_t(input_ids.as_ref(),
attention_mask.as_ref(),
decoder_input_ids.as_ref(),
Some((encoder_outputs.as_ref().unwrap().copy(), None, None)),
None,
None,
train),
_ => Err("Cache not compatible with BART Model")?
Cache::None => self.base_model.forward_t(
input_ids.as_ref(),
attention_mask.as_ref(),
decoder_input_ids.as_ref(),
Some((encoder_outputs.as_ref().unwrap().copy(), None, None)),
None,
None,
train,
),
_ => Err("Cache not compatible with BART Model")?,
};
let lm_logits = decoder_output.linear::<Tensor>(&self.base_model.embeddings.ws, None);
Ok((lm_logits, Some(encoder_hidden_states), Cache::BARTCache(new_cache), None, None))
Ok((
lm_logits,
Some(encoder_hidden_states),
Cache::BARTCache(new_cache),
None,
None,
))
}
}
}

View File

@ -11,15 +11,17 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use crate::bart::attention::{SelfAttention, LayerState};
use tch::{nn, Tensor};
use crate::common::dropout::Dropout;
use crate::bart::BartConfig;
use crate::bart::attention::{LayerState, SelfAttention};
use crate::bart::bart::Activation;
use crate::common::activations::{_gelu, _relu, _swish, _gelu_new, _tanh};
use tch::kind::Kind::Int64;
use crate::bart::embeddings::{
EmbeddingOption, LearnedPositionalEmbedding, SinusoidalPositionalEmbedding,
};
use crate::bart::BartConfig;
use crate::common::activations::{_gelu, _gelu_new, _relu, _swish, _tanh};
use crate::common::dropout::Dropout;
use std::borrow::BorrowMut;
use crate::bart::embeddings::{EmbeddingOption, LearnedPositionalEmbedding, SinusoidalPositionalEmbedding};
use tch::kind::Kind::Int64;
use tch::{nn, Tensor};
pub struct DecoderLayer {
self_attention: SelfAttention,
@ -36,51 +38,74 @@ pub struct DecoderLayer {
impl DecoderLayer {
pub fn new(p: nn::Path, config: &BartConfig) -> DecoderLayer {
let layer_norm_config = nn::LayerNormConfig { eps: 1e-5, ..Default::default() };
let layer_norm_config = nn::LayerNormConfig {
eps: 1e-5,
..Default::default()
};
let output_attention = match config.output_attentions {
Some(value) => value,
None => false
None => false,
};
let self_attention = SelfAttention::new(&p / "self_attn",
config.d_model,
config.decoder_attention_heads,
config.attention_dropout,
false,
true,
output_attention);
let encoder_attention = SelfAttention::new(&p / "encoder_attn",
config.d_model,
config.decoder_attention_heads,
config.attention_dropout,
true,
true,
output_attention);
let self_attention_layer_norm = nn::layer_norm(&p / "self_attn_layer_norm",
vec![config.d_model],
layer_norm_config);
let encoder_attention_layer_norm = nn::layer_norm(&p / "encoder_attn_layer_norm",
vec![config.d_model],
layer_norm_config);
let self_attention = SelfAttention::new(
&p / "self_attn",
config.d_model,
config.decoder_attention_heads,
config.attention_dropout,
false,
true,
output_attention,
);
let encoder_attention = SelfAttention::new(
&p / "encoder_attn",
config.d_model,
config.decoder_attention_heads,
config.attention_dropout,
true,
true,
output_attention,
);
let self_attention_layer_norm = nn::layer_norm(
&p / "self_attn_layer_norm",
vec![config.d_model],
layer_norm_config,
);
let encoder_attention_layer_norm = nn::layer_norm(
&p / "encoder_attn_layer_norm",
vec![config.d_model],
layer_norm_config,
);
let dropout = Dropout::new(config.dropout);
let activation_dropout = Dropout::new(config.activation_dropout);
let activation_function = match &config.activation_function {
Some(act_function) => act_function,
None => &Activation::gelu
None => &Activation::gelu,
};
let activation = Box::new(match activation_function {
Activation::gelu => _gelu,
Activation::relu => _relu,
Activation::swish => _swish,
Activation::gelu_new => _gelu_new,
Activation::tanh => _tanh
Activation::tanh => _tanh,
});
let fc1 = nn::linear(&p / "fc1", config.d_model, config.decoder_ffn_dim, Default::default());
let fc2 = nn::linear(&p / "fc2", config.decoder_ffn_dim, config.d_model, Default::default());
let fc1 = nn::linear(
&p / "fc1",
config.d_model,
config.decoder_ffn_dim,
Default::default(),
);
let fc2 = nn::linear(
&p / "fc2",
config.decoder_ffn_dim,
config.d_model,
Default::default(),
);
let final_layer_norm = nn::layer_norm(&p / "final_layer_norm",
vec![config.d_model],
layer_norm_config);
let final_layer_norm = nn::layer_norm(
&p / "final_layer_norm",
vec![config.d_model],
layer_norm_config,
);
DecoderLayer {
self_attention,
@ -96,18 +121,38 @@ impl DecoderLayer {
}
}
pub fn forward_t(&self,
x: &Tensor,
encoder_hidden_states: &Tensor,
encoder_attn_mask: Option<&Tensor>,
causal_mask: Option<&Tensor>,
decoder_padding_mask: Option<&Tensor>,
layer_states: (Option<LayerState>, Option<LayerState>),
train: bool) -> (Tensor, Option<Tensor>, (Option<LayerState>, Option<LayerState>)) {
let (output, attention_weights, new_self_layer_states) = self.self_attention.forward_t(x, Some(x), decoder_padding_mask, causal_mask, layer_states.0, train);
pub fn forward_t(
&self,
x: &Tensor,
encoder_hidden_states: &Tensor,
encoder_attn_mask: Option<&Tensor>,
causal_mask: Option<&Tensor>,
decoder_padding_mask: Option<&Tensor>,
layer_states: (Option<LayerState>, Option<LayerState>),
train: bool,
) -> (
Tensor,
Option<Tensor>,
(Option<LayerState>, Option<LayerState>),
) {
let (output, attention_weights, new_self_layer_states) = self.self_attention.forward_t(
x,
Some(x),
decoder_padding_mask,
causal_mask,
layer_states.0,
train,
);
let output: Tensor = output.apply_t(&self.dropout, train) + x;
let output = output.apply(&self.self_attention_layer_norm);
let (output1, _, new_encoder_layer_states) = self.encoder_attention.forward_t(&output, Some(encoder_hidden_states), encoder_attn_mask, None, layer_states.1, train);
let (output1, _, new_encoder_layer_states) = self.encoder_attention.forward_t(
&output,
Some(encoder_hidden_states),
encoder_attn_mask,
None,
layer_states.1,
train,
);
let output1: Tensor = output1.apply_t(&self.dropout, train) + output;
let output1 = output1.apply(&self.encoder_attention_layer_norm);
let output2 = (self.activation)(&output1.apply(&self.fc1));
@ -116,7 +161,11 @@ impl DecoderLayer {
.apply(&self.fc2)
.apply_t(&self.dropout, train);
let output2: Tensor = output2 + output1;
(output2.apply(&self.final_layer_norm), attention_weights, (new_self_layer_states, new_encoder_layer_states))
(
output2.apply(&self.final_layer_norm),
attention_weights,
(new_self_layer_states, new_encoder_layer_states),
)
}
}
@ -136,61 +185,76 @@ impl BartDecoder {
pub fn new(p: nn::Path, config: &BartConfig, generation_mode: bool) -> BartDecoder {
let output_past = match config.output_past {
Some(value) => value,
None => true
None => true,
};
let output_attentions = match config.output_attentions {
Some(value) => value,
None => false
None => false,
};
let output_hidden_states = match config.output_hidden_states {
Some(value) => value,
None => false
None => false,
};
let normalize_embedding = match config.normalize_embedding {
Some(value) => value,
None => true
None => true,
};
let static_position_embeddings = match config.static_position_embeddings {
Some(value) => value,
None => false
None => false,
};
let scale_embedding = match config.scale_embedding {
Some(value) => if value { (config.d_model as f64).sqrt() } else { 1.0 },
None => 1.0
Some(value) => {
if value {
(config.d_model as f64).sqrt()
} else {
1.0
}
}
None => 1.0,
};
let dropout = Dropout::new(config.dropout);
let layer_norm_embedding = if normalize_embedding {
let layer_norm_config = nn::LayerNormConfig { eps: 1e-5, ..Default::default() };
Some(nn::layer_norm(&p / "layernorm_embedding",
vec![config.d_model],
layer_norm_config))
let layer_norm_config = nn::LayerNormConfig {
eps: 1e-5,
..Default::default()
};
Some(nn::layer_norm(
&p / "layernorm_embedding",
vec![config.d_model],
layer_norm_config,
))
} else {
None
};
let pad_token_id = match config.pad_token_id {
Some(value) => value,
None => 1
None => 1,
};
let embed_positions = if static_position_embeddings {
EmbeddingOption::SinusoidalPositionalEmbedding(SinusoidalPositionalEmbedding::new(&p / "embed_positions",
config.max_position_embeddings,
config.d_model))
EmbeddingOption::SinusoidalPositionalEmbedding(SinusoidalPositionalEmbedding::new(
&p / "embed_positions",
config.max_position_embeddings,
config.d_model,
))
} else {
EmbeddingOption::LearnedPositionalEmbedding(LearnedPositionalEmbedding::new(&p / "embed_positions",
config.max_position_embeddings,
config.d_model,
pad_token_id))
EmbeddingOption::LearnedPositionalEmbedding(LearnedPositionalEmbedding::new(
&p / "embed_positions",
config.max_position_embeddings,
config.d_model,
pad_token_id,
))
};
let mut layers: Vec<DecoderLayer> = vec!();
let mut layers: Vec<DecoderLayer> = vec![];
let p_layers = &p / "layers";
for layer_index in 0..config.decoder_layers {
layers.push(DecoderLayer::new(&p_layers / layer_index, config));
};
}
BartDecoder {
dropout,
@ -205,44 +269,68 @@ impl BartDecoder {
}
}
pub fn forward_t(&self,
input_ids: &Tensor,
encoder_hidden_states: &Tensor,
encoder_padding_mask: Option<&Tensor>,
decoder_padding_mask: Option<&Tensor>,
decoder_causal_mask: Option<&Tensor>,
embeddings: &nn::Embedding,
old_layer_states: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
train: bool)
-> (Tensor,
(Option<Tensor>, Option<Vec<(Option<LayerState>, Option<LayerState>)>>),
Option<Vec<Tensor>>,
Option<Vec<Tensor>>) {
pub fn forward_t(
&self,
input_ids: &Tensor,
encoder_hidden_states: &Tensor,
encoder_padding_mask: Option<&Tensor>,
decoder_padding_mask: Option<&Tensor>,
decoder_causal_mask: Option<&Tensor>,
embeddings: &nn::Embedding,
old_layer_states: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
train: bool,
) -> (
Tensor,
(
Option<Tensor>,
Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
),
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
) {
let encoder_padding_mask = match encoder_padding_mask {
Some(mask) => Some(mask.eq(0).to_kind(Int64)),
None => None
None => None,
};
let positions = self.embed_positions.forward(input_ids, self.generation_mode);
let positions = self
.embed_positions
.forward(input_ids, self.generation_mode);
let x: Tensor = if self.generation_mode {
let end_inputs = input_ids.size()[1];
let end_positions = positions.size()[1];
input_ids.narrow(1, end_inputs - 1, 1).apply(embeddings) * self.scale_embedding + positions.narrow(1, end_positions - 1, 1)
input_ids.narrow(1, end_inputs - 1, 1).apply(embeddings) * self.scale_embedding
+ positions.narrow(1, end_positions - 1, 1)
} else {
input_ids.apply(embeddings) * self.scale_embedding + positions
};
let x = if let Some(layer_norm_embedding) = &self.layer_norm_embedding { x.apply(layer_norm_embedding) } else { x };
let mut hidden_state = x
.apply_t(&self.dropout, train)
.transpose(0, 1);
let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states { Some(Vec::with_capacity(self.layers.len())) } else { None };
let mut all_attentions: Option<Vec<Tensor>> = if self.output_attentions { Some(Vec::with_capacity(self.layers.len())) } else { None };
let mut next_decoder_cache: Option<Vec<(Option<LayerState>, Option<LayerState>)>> = if self.output_past {
if old_layer_states.is_some() { old_layer_states } else { Some(vec!((None, None); self.layers.len())) }
let x = if let Some(layer_norm_embedding) = &self.layer_norm_embedding {
x.apply(layer_norm_embedding)
} else {
x
};
let mut hidden_state = x.apply_t(&self.dropout, train).transpose(0, 1);
let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states {
Some(Vec::with_capacity(self.layers.len()))
} else {
None
};
let mut all_attentions: Option<Vec<Tensor>> = if self.output_attentions {
Some(Vec::with_capacity(self.layers.len()))
} else {
None
};
let mut next_decoder_cache: Option<Vec<(Option<LayerState>, Option<LayerState>)>> =
if self.output_past {
if old_layer_states.is_some() {
old_layer_states
} else {
Some(vec![(None, None); self.layers.len()])
}
} else {
None
};
let encoder_hidden_states = encoder_hidden_states.transpose(0, 1);
let mut attention_weights: Option<Tensor>;
let mut layers = self.layers.iter().enumerate();
@ -252,15 +340,17 @@ impl BartDecoder {
Some((layer_idx, layer)) => {
let layer_state = match &next_decoder_cache {
Some(values) => values[layer_idx].to_owned(),
None => (None, None)
None => (None, None),
};
let temp = layer.forward_t(&hidden_state,
&encoder_hidden_states,
encoder_padding_mask.as_ref(),
decoder_causal_mask,
decoder_padding_mask,
layer_state,
train);
let temp = layer.forward_t(
&hidden_state,
&encoder_hidden_states,
encoder_padding_mask.as_ref(),
decoder_causal_mask,
decoder_padding_mask,
layer_state,
train,
);
hidden_state = temp.0;
attention_weights = temp.1;
if let Some(hidden_states) = all_hidden_states.borrow_mut() {
@ -269,15 +359,19 @@ impl BartDecoder {
if let Some(attentions) = all_attentions.borrow_mut() {
attentions.push(attention_weights.as_ref().unwrap().copy());
};
if let Some(value) = &mut next_decoder_cache { value[layer_idx] = temp.2 };
if let Some(value) = &mut next_decoder_cache {
value[layer_idx] = temp.2
};
}
None => break
None => break,
};
};
}
(hidden_state.transpose(0, 1),
(encoder_padding_mask, next_decoder_cache),
all_hidden_states,
all_attentions)
(
hidden_state.transpose(0, 1),
(encoder_padding_mask, next_decoder_cache),
all_hidden_states,
all_attentions,
)
}
}
}

View File

@ -11,10 +11,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use tch::{nn, Tensor};
use tch::nn::{EmbeddingConfig, embedding};
use tch::kind::Kind::Int64;
use tch::nn::{embedding, EmbeddingConfig};
use tch::{nn, Tensor};
/// # Abstraction that holds a embeddings configuration
pub enum EmbeddingOption {
@ -27,8 +26,12 @@ impl EmbeddingOption {
/// Interface method to forward_t() of the particular models.
pub fn forward(&self, input: &Tensor, generation_mode: bool) -> Tensor {
match *self {
Self::LearnedPositionalEmbedding(ref embeddings) => embeddings.forward(input, generation_mode),
Self::SinusoidalPositionalEmbedding(ref embeddings) => embeddings.forward(input, generation_mode)
Self::LearnedPositionalEmbedding(ref embeddings) => {
embeddings.forward(input, generation_mode)
}
Self::SinusoidalPositionalEmbedding(ref embeddings) => {
embeddings.forward(input, generation_mode)
}
}
}
}
@ -40,15 +43,24 @@ pub struct LearnedPositionalEmbedding {
}
impl LearnedPositionalEmbedding {
pub fn new(p: nn::Path, num_embeddings: i64, embedding_dim: i64, padding_index: i64) -> LearnedPositionalEmbedding {
let embedding_config = EmbeddingConfig { padding_idx: padding_index, ..Default::default() };
pub fn new(
p: nn::Path,
num_embeddings: i64,
embedding_dim: i64,
padding_index: i64,
) -> LearnedPositionalEmbedding {
let embedding_config = EmbeddingConfig {
padding_idx: padding_index,
..Default::default()
};
let num_embeddings = num_embeddings + padding_index + 1;
let embedding: nn::Embedding = embedding(p,
num_embeddings,
embedding_dim,
embedding_config);
LearnedPositionalEmbedding { embedding, padding_index }
let embedding: nn::Embedding =
embedding(p, num_embeddings, embedding_dim, embedding_config);
LearnedPositionalEmbedding {
embedding,
padding_index,
}
}
pub fn forward(&self, input: &Tensor, generation_mode: bool) -> Tensor {
@ -74,11 +86,13 @@ pub struct SinusoidalPositionalEmbedding {
}
impl SinusoidalPositionalEmbedding {
pub fn new(p: nn::Path, num_embeddings: i64, embedding_dim: i64) -> SinusoidalPositionalEmbedding {
let embedding: nn::Embedding = embedding(p,
num_embeddings,
embedding_dim,
Default::default());
pub fn new(
p: nn::Path,
num_embeddings: i64,
embedding_dim: i64,
) -> SinusoidalPositionalEmbedding {
let embedding: nn::Embedding =
embedding(p, num_embeddings, embedding_dim, Default::default());
SinusoidalPositionalEmbedding { embedding }
}
@ -86,8 +100,8 @@ impl SinusoidalPositionalEmbedding {
let positions = if generation_mode {
Tensor::full(&[1, 1], input.size()[1] - 1, (Int64, input.device()))
} else {
Tensor::arange(input.size()[1],(Int64, input.device()))
Tensor::arange(input.size()[1], (Int64, input.device()))
};
positions.apply(&self.embedding)
}
}
}

View File

@ -12,14 +12,16 @@
// limitations under the License.
use crate::bart::attention::SelfAttention;
use tch::{nn, Tensor};
use crate::common::dropout::Dropout;
use crate::bart::BartConfig;
use crate::bart::bart::Activation;
use crate::common::activations::{_gelu, _relu, _swish, _gelu_new, _tanh};
use crate::bart::embeddings::{EmbeddingOption, LearnedPositionalEmbedding, SinusoidalPositionalEmbedding};
use tch::kind::Kind::Bool;
use crate::bart::embeddings::{
EmbeddingOption, LearnedPositionalEmbedding, SinusoidalPositionalEmbedding,
};
use crate::bart::BartConfig;
use crate::common::activations::{_gelu, _gelu_new, _relu, _swish, _tanh};
use crate::common::dropout::Dropout;
use std::borrow::BorrowMut;
use tch::kind::Kind::Bool;
use tch::{nn, Tensor};
pub struct EncoderLayer {
self_attention: SelfAttention,
@ -34,46 +36,81 @@ pub struct EncoderLayer {
impl EncoderLayer {
pub fn new(p: nn::Path, config: &BartConfig) -> EncoderLayer {
let layer_norm_config = nn::LayerNormConfig { eps: 1e-5, ..Default::default() };
let layer_norm_config = nn::LayerNormConfig {
eps: 1e-5,
..Default::default()
};
let output_attention = match config.output_attentions {
Some(value) => value,
None => false
None => false,
};
let self_attention = SelfAttention::new(&p / "self_attn",
config.d_model,
config.encoder_attention_heads,
config.attention_dropout,
false,
false,
output_attention);
let self_attention_layer_norm = nn::layer_norm(&p / "self_attn_layer_norm",
vec![config.d_model],
layer_norm_config);
let self_attention = SelfAttention::new(
&p / "self_attn",
config.d_model,
config.encoder_attention_heads,
config.attention_dropout,
false,
false,
output_attention,
);
let self_attention_layer_norm = nn::layer_norm(
&p / "self_attn_layer_norm",
vec![config.d_model],
layer_norm_config,
);
let dropout = Dropout::new(config.dropout);
let activation_dropout = Dropout::new(config.activation_dropout);
let activation_function = match &config.activation_function {
Some(act_function) => act_function,
None => &Activation::gelu
None => &Activation::gelu,
};
let activation = Box::new(match activation_function {
Activation::gelu => _gelu,
Activation::relu => _relu,
Activation::swish => _swish,
Activation::gelu_new => _gelu_new,
Activation::tanh => _tanh
Activation::tanh => _tanh,
});
let fc1 = nn::linear(&p / "fc1", config.d_model, config.encoder_ffn_dim, Default::default());
let fc2 = nn::linear(&p / "fc2", config.encoder_ffn_dim, config.d_model, Default::default());
let fc1 = nn::linear(
&p / "fc1",
config.d_model,
config.encoder_ffn_dim,
Default::default(),
);
let fc2 = nn::linear(
&p / "fc2",
config.encoder_ffn_dim,
config.d_model,
Default::default(),
);
let final_layer_norm = nn::layer_norm(&p / "final_layer_norm",
vec![config.d_model],
layer_norm_config);
let final_layer_norm = nn::layer_norm(
&p / "final_layer_norm",
vec![config.d_model],
layer_norm_config,
);
EncoderLayer { self_attention, self_attention_layer_norm, dropout, activation_dropout, activation, fc1, fc2, final_layer_norm }
EncoderLayer {
self_attention,
self_attention_layer_norm,
dropout,
activation_dropout,
activation,
fc1,
fc2,
final_layer_norm,
}
}
pub fn forward_t(&self, x: &Tensor, encoder_padding_mask: Option<&Tensor>, train: bool) -> (Tensor, Option<Tensor>) {
let (output, attention_weights, _) = self.self_attention.forward_t(x, None, encoder_padding_mask, None, None, train);
pub fn forward_t(
&self,
x: &Tensor,
encoder_padding_mask: Option<&Tensor>,
train: bool,
) -> (Tensor, Option<Tensor>) {
let (output, attention_weights, _) =
self.self_attention
.forward_t(x, None, encoder_padding_mask, None, None, train);
let output: Tensor = output.apply_t(&self.dropout, train) + x;
let output = output.apply(&self.self_attention_layer_norm);
@ -102,57 +139,72 @@ impl BartEncoder {
pub fn new(p: nn::Path, config: &BartConfig) -> BartEncoder {
let output_attentions = match config.output_attentions {
Some(value) => value,
None => false
None => false,
};
let output_hidden_states = match config.output_hidden_states {
Some(value) => value,
None => false
None => false,
};
let normalize_embedding = match config.normalize_embedding {
Some(value) => value,
None => true
None => true,
};
let static_position_embeddings = match config.static_position_embeddings {
Some(value) => value,
None => false
None => false,
};
let scale_embedding = match config.scale_embedding {
Some(value) => if value { (config.d_model as f64).sqrt() } else { 1.0 },
None => 1.0
Some(value) => {
if value {
(config.d_model as f64).sqrt()
} else {
1.0
}
}
None => 1.0,
};
let dropout = Dropout::new(config.dropout);
let layer_norm_embedding = if normalize_embedding {
let layer_norm_config = nn::LayerNormConfig { eps: 1e-5, ..Default::default() };
Some(nn::layer_norm(&p / "layernorm_embedding",
vec![config.d_model],
layer_norm_config))
let layer_norm_config = nn::LayerNormConfig {
eps: 1e-5,
..Default::default()
};
Some(nn::layer_norm(
&p / "layernorm_embedding",
vec![config.d_model],
layer_norm_config,
))
} else {
None
};
let pad_token_id = match config.pad_token_id {
Some(value) => value,
None => 1
None => 1,
};
let embed_positions = if static_position_embeddings {
EmbeddingOption::SinusoidalPositionalEmbedding(SinusoidalPositionalEmbedding::new(&p / "embed_positions",
config.max_position_embeddings,
config.d_model))
EmbeddingOption::SinusoidalPositionalEmbedding(SinusoidalPositionalEmbedding::new(
&p / "embed_positions",
config.max_position_embeddings,
config.d_model,
))
} else {
EmbeddingOption::LearnedPositionalEmbedding(LearnedPositionalEmbedding::new(&p / "embed_positions",
config.max_position_embeddings,
config.d_model,
pad_token_id))
EmbeddingOption::LearnedPositionalEmbedding(LearnedPositionalEmbedding::new(
&p / "embed_positions",
config.max_position_embeddings,
config.d_model,
pad_token_id,
))
};
let mut layers: Vec<EncoderLayer> = vec!();
let mut layers: Vec<EncoderLayer> = vec![];
let p_layers = &p / "layers";
for layer_index in 0..config.encoder_layers {
layers.push(EncoderLayer::new(&p_layers / layer_index, config));
};
}
BartEncoder {
dropout,
@ -165,26 +217,37 @@ impl BartEncoder {
}
}
pub fn forward_t(&self,
input_ids: &Tensor,
attention_mask: Option<&Tensor>,
embeddings: &nn::Embedding,
train: bool)
-> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
pub fn forward_t(
&self,
input_ids: &Tensor,
attention_mask: Option<&Tensor>,
embeddings: &nn::Embedding,
train: bool,
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
let attention_mask = match attention_mask {
Some(mask) => Some(mask.eq(0).to_kind(Bool)),
None => None
None => None,
};
let x = input_ids.apply(embeddings) * self.scale_embedding;
let x: Tensor = x + &self.embed_positions.forward(input_ids, false);
let x = if let Some(layer_norm_embedding) = &self.layer_norm_embedding { x.apply(layer_norm_embedding) } else { x };
let x = x
.apply_t(&self.dropout, train)
.transpose(0, 1);
let x = if let Some(layer_norm_embedding) = &self.layer_norm_embedding {
x.apply(layer_norm_embedding)
} else {
x
};
let x = x.apply_t(&self.dropout, train).transpose(0, 1);
let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states { Some(vec!()) } else { None };
let mut all_attentions: Option<Vec<Tensor>> = if self.output_attentions { Some(vec!()) } else { None };
let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states {
Some(vec![])
} else {
None
};
let mut all_attentions: Option<Vec<Tensor>> = if self.output_attentions {
Some(vec![])
} else {
None
};
let mut hidden_state = x.copy();
let mut attention_weights: Option<Tensor>;
@ -204,13 +267,17 @@ impl BartEncoder {
attentions.push(attention_weights.as_ref().unwrap().copy());
};
}
None => break
None => break,
};
};
}
if let Some(hidden_states) = all_hidden_states.borrow_mut() {
hidden_states.push(hidden_state.as_ref().copy().transpose(0, 1));
};
(hidden_state.transpose(0, 1), all_hidden_states, all_attentions)
(
hidden_state.transpose(0, 1),
all_hidden_states,
all_attentions,
)
}
}
}

View File

@ -15,19 +15,27 @@
//! Pretrained models are available and can be downloaded using RemoteResources.
//!
//! ```no_run
//!# fn main() -> failure::Fallible<()> {
//!#
//! # fn main() -> failure::Fallible<()> {
//! #
//! use rust_tokenizers::RobertaTokenizer;
//! use tch::{nn, Device};
//!# use std::path::PathBuf;
//! use rust_bert::Config;
//! # use std::path::PathBuf;
//! use rust_bert::bart::{BartConfig, BartModel};
//! use rust_bert::resources::{Resource, download_resource, LocalResource};
//! use rust_bert::resources::{download_resource, LocalResource, Resource};
//! use rust_bert::Config;
//!
//! let config_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/config.json")});
//! let vocab_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/vocab.txt")});
//! let merges_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/vocab.txt")});
//! let weights_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/model.ot")});
//! let config_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/config.json"),
//! });
//! let vocab_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/vocab.txt"),
//! });
//! let merges_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/vocab.txt"),
//! });
//! let weights_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/model.ot"),
//! });
//! let config_path = download_resource(&config_resource)?;
//! let vocab_path = download_resource(&vocab_resource)?;
//! let merges_path = download_resource(&merges_resource)?;
@ -35,21 +43,28 @@
//!
//! let device = Device::cuda_if_available();
//! let mut vs = nn::VarStore::new(device);
//! let tokenizer: RobertaTokenizer = RobertaTokenizer::from_file(vocab_path.to_str().unwrap(), merges_path.to_str().unwrap(), true);
//! let tokenizer: RobertaTokenizer = RobertaTokenizer::from_file(
//! vocab_path.to_str().unwrap(),
//! merges_path.to_str().unwrap(),
//! true,
//! );
//! let config = BartConfig::from_file(config_path);
//! let bart_model = BartModel::new(&vs.root(), &config, false);
//! vs.load(weights_path)?;
//!
//!# Ok(())
//!# }
//! # Ok(())
//! # }
//! ```
mod bart;
mod attention;
mod encoder;
mod bart;
mod decoder;
mod embeddings;
mod encoder;
pub use bart::{BartModelResources, BartConfigResources, BartVocabResources, BartMergesResources,
BartConfig, Activation, BartModel, BartForSequenceClassification, BartForConditionalGeneration};
pub use attention::LayerState;
pub use attention::LayerState;
pub use bart::{
Activation, BartConfig, BartConfigResources, BartForConditionalGeneration,
BartForSequenceClassification, BartMergesResources, BartModel, BartModelResources,
BartVocabResources,
};

View File

@ -11,11 +11,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use crate::common::dropout::Dropout;
use tch::{nn, Tensor};
use tch::kind::Kind::Float;
use crate::bert::bert::{Activation, BertConfig};
use crate::common::activations::{_gelu, _relu, _mish};
use crate::common::activations::{_gelu, _mish, _relu};
use crate::common::dropout::Dropout;
use tch::kind::Kind::Float;
use tch::{nn, Tensor};
#[derive(Debug)]
pub struct BertSelfAttention {
@ -30,17 +30,36 @@ pub struct BertSelfAttention {
impl BertSelfAttention {
pub fn new(p: nn::Path, config: &BertConfig) -> BertSelfAttention {
assert_eq!(config.hidden_size % config.num_attention_heads, 0, "Hidden size not a multiple of the number of attention heads");
assert_eq!(
config.hidden_size % config.num_attention_heads,
0,
"Hidden size not a multiple of the number of attention heads"
);
let query = nn::linear(&p / "query", config.hidden_size, config.hidden_size, Default::default());
let key = nn::linear(&p / "key", config.hidden_size, config.hidden_size, Default::default());
let value = nn::linear(&p / "value", config.hidden_size, config.hidden_size, Default::default());
let query = nn::linear(
&p / "query",
config.hidden_size,
config.hidden_size,
Default::default(),
);
let key = nn::linear(
&p / "key",
config.hidden_size,
config.hidden_size,
Default::default(),
);
let value = nn::linear(
&p / "value",
config.hidden_size,
config.hidden_size,
Default::default(),
);
let dropout = Dropout::new(config.attention_probs_dropout_prob);
let attention_head_size = config.hidden_size / config.num_attention_heads;
let output_attentions = match config.output_attentions {
Some(value) => value,
None => false
None => false,
};
BertSelfAttention {
@ -55,35 +74,44 @@ impl BertSelfAttention {
}
fn split_heads(&self, x: Tensor, bs: i64, dim_per_head: i64) -> Tensor {
x.view((bs, -1, self.num_attention_heads, dim_per_head)).transpose(1, 2)
x.view((bs, -1, self.num_attention_heads, dim_per_head))
.transpose(1, 2)
}
fn flatten(&self, x: Tensor, bs: i64, dim_per_head: i64) -> Tensor {
x.transpose(1, 2).contiguous().view((bs, -1, &self.num_attention_heads * dim_per_head))
x.transpose(1, 2)
.contiguous()
.view((bs, -1, &self.num_attention_heads * dim_per_head))
}
pub fn forward_t(&self,
hidden_states: &Tensor,
mask: &Option<Tensor>,
encoder_hidden_states: &Option<Tensor>,
encoder_mask: &Option<Tensor>,
train: bool) -> (Tensor, Option<Tensor>) {
pub fn forward_t(
&self,
hidden_states: &Tensor,
mask: &Option<Tensor>,
encoder_hidden_states: &Option<Tensor>,
encoder_mask: &Option<Tensor>,
train: bool,
) -> (Tensor, Option<Tensor>) {
let (key_layer, value_layer, mask) = match encoder_hidden_states {
Some(encoder_hidden_state_values) => {
(encoder_hidden_state_values.apply(&self.key),
encoder_hidden_state_values.apply(&self.value),
encoder_mask)
}
None => {
(hidden_states.apply(&self.key),
hidden_states.apply(&self.value),
mask)
}
Some(encoder_hidden_state_values) => (
encoder_hidden_state_values.apply(&self.key),
encoder_hidden_state_values.apply(&self.value),
encoder_mask,
),
None => (
hidden_states.apply(&self.key),
hidden_states.apply(&self.value),
mask,
),
};
let bs = hidden_states.size()[0];
let query_layer = self.split_heads(hidden_states.apply(&self.query), bs, self.attention_head_size);
let query_layer = self.split_heads(
hidden_states.apply(&self.query),
bs,
self.attention_head_size,
);
let key_layer = self.split_heads(key_layer, bs, self.attention_head_size);
let value_layer = self.split_heads(value_layer, bs, self.attention_head_size);
let query_layer: Tensor = query_layer / (self.attention_head_size as f64).sqrt();
@ -114,16 +142,32 @@ pub struct BertSelfOutput {
impl BertSelfOutput {
pub fn new(p: &nn::Path, config: &BertConfig) -> BertSelfOutput {
let linear = nn::linear(p / "dense", config.hidden_size, config.hidden_size, Default::default());
let layer_norm_config = nn::LayerNormConfig { eps: 1e-12, ..Default::default() };
let layer_norm = nn::layer_norm(p / "LayerNorm", vec![config.hidden_size], layer_norm_config);
let linear = nn::linear(
p / "dense",
config.hidden_size,
config.hidden_size,
Default::default(),
);
let layer_norm_config = nn::LayerNormConfig {
eps: 1e-12,
..Default::default()
};
let layer_norm =
nn::layer_norm(p / "LayerNorm", vec![config.hidden_size], layer_norm_config);
let dropout = Dropout::new(config.hidden_dropout_prob);
BertSelfOutput { linear, layer_norm, dropout }
BertSelfOutput {
linear,
layer_norm,
dropout,
}
}
pub fn forward_t(&self, hidden_states: &Tensor, input_tensor: &Tensor, train: bool) -> Tensor {
let hidden_states: Tensor = input_tensor + hidden_states.apply(&self.linear).apply_t(&self.dropout, train);
let hidden_states: Tensor = input_tensor
+ hidden_states
.apply(&self.linear)
.apply_t(&self.dropout, train);
hidden_states.apply(&self.layer_norm)
}
}
@ -141,14 +185,21 @@ impl BertAttention {
BertAttention { _self, output }
}
pub fn forward_t(&self,
hidden_states: &Tensor,
mask: &Option<Tensor>,
encoder_hidden_states: &Option<Tensor>,
encoder_mask: &Option<Tensor>,
train: bool) -> (Tensor, Option<Tensor>) {
let (self_output, attention_weights) = self._self.
forward_t(hidden_states, mask, encoder_hidden_states, encoder_mask, train);
pub fn forward_t(
&self,
hidden_states: &Tensor,
mask: &Option<Tensor>,
encoder_hidden_states: &Option<Tensor>,
encoder_mask: &Option<Tensor>,
train: bool,
) -> (Tensor, Option<Tensor>) {
let (self_output, attention_weights) = self._self.forward_t(
hidden_states,
mask,
encoder_hidden_states,
encoder_mask,
train,
);
let self_output = self.output.forward_t(&self_output, hidden_states, train);
(self_output, attention_weights)
@ -162,11 +213,16 @@ pub struct BertIntermediate {
impl BertIntermediate {
pub fn new(p: &nn::Path, config: &BertConfig) -> BertIntermediate {
let lin = nn::linear(p / "dense", config.hidden_size, config.intermediate_size, Default::default());
let lin = nn::linear(
p / "dense",
config.hidden_size,
config.intermediate_size,
Default::default(),
);
let activation = Box::new(match &config.hidden_act {
Activation::gelu => _gelu,
Activation::relu => _relu,
Activation::mish => _mish
Activation::mish => _mish,
});
BertIntermediate { lin, activation }
}
@ -184,17 +240,30 @@ pub struct BertOutput {
impl BertOutput {
pub fn new(p: &nn::Path, config: &BertConfig) -> BertOutput {
let lin = nn::linear(p / "dense", config.intermediate_size, config.hidden_size, Default::default());
let layer_norm_config = nn::LayerNormConfig { eps: 1e-12, ..Default::default() };
let layer_norm = nn::layer_norm(p / "LayerNorm", vec![config.hidden_size], layer_norm_config);
let lin = nn::linear(
p / "dense",
config.intermediate_size,
config.hidden_size,
Default::default(),
);
let layer_norm_config = nn::LayerNormConfig {
eps: 1e-12,
..Default::default()
};
let layer_norm =
nn::layer_norm(p / "LayerNorm", vec![config.hidden_size], layer_norm_config);
let dropout = Dropout::new(config.hidden_dropout_prob);
BertOutput { lin, layer_norm, dropout }
BertOutput {
lin,
layer_norm,
dropout,
}
}
pub fn forward_t(&self, hidden_states: &Tensor, input_tensor: &Tensor, train: bool) -> Tensor {
let hidden_states: Tensor = input_tensor + hidden_states.apply(&self.lin).apply_t(&self.dropout, train);
let hidden_states: Tensor =
input_tensor + hidden_states.apply(&self.lin).apply_t(&self.dropout, train);
hidden_states.apply(&self.layer_norm)
}
}

File diff suppressed because it is too large Load Diff

View File

@ -11,22 +11,24 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use tch::{nn, Tensor, Kind};
use tch::nn::{EmbeddingConfig, embedding};
use crate::common::dropout::Dropout;
use crate::bert::bert::BertConfig;
use crate::common::dropout::Dropout;
use tch::nn::{embedding, EmbeddingConfig};
use tch::{nn, Kind, Tensor};
/// # BertEmbedding trait (for use in BertModel or RoBERTaModel)
/// Defines an interface for the embedding layers in BERT-based models
pub trait BertEmbedding {
fn new(p: &nn::Path, config: &BertConfig) -> Self;
fn forward_t(&self,
input_ids: Option<Tensor>,
token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool) -> Result<Tensor, &'static str>;
fn forward_t(
&self,
input_ids: Option<Tensor>,
token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool,
) -> Result<Tensor, &'static str>;
}
#[derive(Debug)]
@ -51,10 +53,10 @@ impl BertEmbedding for BertEmbeddings {
/// # Example
///
/// ```no_run
/// use rust_bert::bert::{BertConfig, BertEmbeddings, BertEmbedding};
/// use tch::{nn, Device};
/// use rust_bert::bert::{BertConfig, BertEmbedding, BertEmbeddings};
/// use rust_bert::Config;
/// use std::path::Path;
/// use tch::{nn, Device};
///
/// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu;
@ -62,29 +64,47 @@ impl BertEmbedding for BertEmbeddings {
/// let config = BertConfig::from_file(config_path);
/// let bert_embeddings = BertEmbeddings::new(&(&p.root() / "bert_embeddings"), &config);
/// ```
///
fn new(p: &nn::Path, config: &BertConfig) -> BertEmbeddings {
let embedding_config = EmbeddingConfig { padding_idx: 0, ..Default::default() };
let embedding_config = EmbeddingConfig {
padding_idx: 0,
..Default::default()
};
let word_embeddings: nn::Embedding = embedding(p / "word_embeddings",
config.vocab_size,
config.hidden_size,
embedding_config);
let word_embeddings: nn::Embedding = embedding(
p / "word_embeddings",
config.vocab_size,
config.hidden_size,
embedding_config,
);
let position_embeddings: nn::Embedding = embedding(p / "position_embeddings",
config.max_position_embeddings,
config.hidden_size,
Default::default());
let position_embeddings: nn::Embedding = embedding(
p / "position_embeddings",
config.max_position_embeddings,
config.hidden_size,
Default::default(),
);
let token_type_embeddings: nn::Embedding = embedding(p / "token_type_embeddings",
config.type_vocab_size,
config.hidden_size,
Default::default());
let token_type_embeddings: nn::Embedding = embedding(
p / "token_type_embeddings",
config.type_vocab_size,
config.hidden_size,
Default::default(),
);
let layer_norm_config = nn::LayerNormConfig { eps: 1e-12, ..Default::default() };
let layer_norm: nn::LayerNorm = nn::layer_norm(p / "LayerNorm", vec![config.hidden_size], layer_norm_config);
let layer_norm_config = nn::LayerNormConfig {
eps: 1e-12,
..Default::default()
};
let layer_norm: nn::LayerNorm =
nn::layer_norm(p / "LayerNorm", vec![config.hidden_size], layer_norm_config);
let dropout: Dropout = Dropout::new(config.hidden_dropout_prob);
BertEmbeddings { word_embeddings, position_embeddings, token_type_embeddings, layer_norm, dropout }
BertEmbeddings {
word_embeddings,
position_embeddings,
token_type_embeddings,
layer_norm,
dropout,
}
}
/// Forward pass through the embedding layer
@ -104,50 +124,62 @@ impl BertEmbedding for BertEmbeddings {
/// # Example
///
/// ```no_run
///# use rust_bert::bert::{BertConfig, BertEmbeddings, BertEmbedding};
///# use tch::{nn, Device, Tensor, no_grad};
///# use rust_bert::Config;
///# use std::path::Path;
///# use tch::kind::Kind::Int64;
///# let config_path = Path::new("path/to/config.json");
///# let vocab_path = Path::new("path/to/vocab.txt");
///# let device = Device::Cpu;
///# let vs = nn::VarStore::new(device);
///# let config = BertConfig::from_file(config_path);
///# let bert_embeddings = BertEmbeddings::new(&vs.root(), &config);
/// let (batch_size, sequence_length) = (64, 128);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true);
/// # use rust_bert::bert::{BertConfig, BertEmbeddings, BertEmbedding};
/// # use tch::{nn, Device, Tensor, no_grad};
/// # use rust_bert::Config;
/// # use std::path::Path;
/// # use tch::kind::Kind::Int64;
/// # let config_path = Path::new("path/to/config.json");
/// # let vocab_path = Path::new("path/to/vocab.txt");
/// # let device = Device::Cpu;
/// # let vs = nn::VarStore::new(device);
/// # let config = BertConfig::from_file(config_path);
/// # let bert_embeddings = BertEmbeddings::new(&vs.root(), &config);
/// let (batch_size, sequence_length) = (64, 128);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
/// .expand(&[batch_size, sequence_length], true);
///
/// let embedded_output = no_grad(|| {
/// bert_embeddings
/// .forward_t(Some(input_tensor),
/// Some(token_type_ids),
/// Some(position_ids),
/// None,
/// false).unwrap()
/// });
/// let embedded_output = no_grad(|| {
/// bert_embeddings
/// .forward_t(
/// Some(input_tensor),
/// Some(token_type_ids),
/// Some(position_ids),
/// None,
/// false,
/// )
/// .unwrap()
/// });
/// ```
///
fn forward_t(&self,
input_ids: Option<Tensor>,
token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool) -> Result<Tensor, &'static str> {
fn forward_t(
&self,
input_ids: Option<Tensor>,
token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool,
) -> Result<Tensor, &'static str> {
let (input_embeddings, input_shape) = match input_ids {
Some(input_value) => match input_embeds {
Some(_) => { return Err("Only one of input ids or input embeddings may be set"); }
None => (input_value.apply_t(&self.word_embeddings, train), input_value.size())
}
Some(_) => {
return Err("Only one of input ids or input embeddings may be set");
}
None => (
input_value.apply_t(&self.word_embeddings, train),
input_value.size(),
),
},
None => match input_embeds {
Some(embeds) => {
let size = vec!(embeds.size()[0], embeds.size()[1]);
let size = vec![embeds.size()[0], embeds.size()[1]];
(embeds, size)
},
None => { return Err("Only one of input ids or input embeddings may be set"); }
}
}
None => {
return Err("Only one of input ids or input embeddings may be set");
}
},
};
let seq_length = input_embeddings.as_ref().size()[1].to_owned();
@ -155,19 +187,22 @@ impl BertEmbedding for BertEmbeddings {
let position_ids = match position_ids {
Some(value) => value,
None => Tensor::arange(seq_length, (Kind::Int64, input_embeddings.device()))
.unsqueeze(0).
expand(&input_shape, true)
.unsqueeze(0)
.expand(&input_shape, true),
};
let token_type_ids = match token_type_ids {
Some(value) => value,
None => Tensor::zeros(&input_shape, (Kind::Int64, input_embeddings.device()))
None => Tensor::zeros(&input_shape, (Kind::Int64, input_embeddings.device())),
};
let position_embeddings = position_ids.apply(&self.position_embeddings);
let token_type_embeddings = token_type_ids.apply(&self.token_type_embeddings);
let input_embeddings: Tensor = input_embeddings + position_embeddings + token_type_embeddings;
Ok(input_embeddings.apply(&self.layer_norm).apply_t(&self.dropout, train))
let input_embeddings: Tensor =
input_embeddings + position_embeddings + token_type_embeddings;
Ok(input_embeddings
.apply(&self.layer_norm)
.apply_t(&self.dropout, train))
}
}
}

View File

@ -11,10 +11,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use tch::{Tensor, nn};
use crate::bert::attention::{BertAttention, BertIntermediate, BertOutput};
use std::borrow::BorrowMut;
use crate::bert::bert::BertConfig;
use std::borrow::BorrowMut;
use tch::{nn, Tensor};
pub struct BertLayer {
attention: BertAttention,
@ -30,37 +30,57 @@ impl BertLayer {
let (is_decoder, cross_attention) = match config.is_decoder {
Some(value) => {
if value == true {
(value, Some(BertAttention::new(&(p / "cross_attention"), &config)))
(
value,
Some(BertAttention::new(&(p / "cross_attention"), &config)),
)
} else {
(value, None)
}
}
None => (false, None)
None => (false, None),
};
let intermediate = BertIntermediate::new(&(p / "intermediate"), &config);
let output = BertOutput::new(&(p / "output"), &config);
BertLayer { attention, is_decoder, cross_attention, intermediate, output }
BertLayer {
attention,
is_decoder,
cross_attention,
intermediate,
output,
}
}
pub fn forward_t(&self,
hidden_states: &Tensor,
mask: &Option<Tensor>,
encoder_hidden_states: &Option<Tensor>,
encoder_mask: &Option<Tensor>,
train: bool) -> (Tensor, Option<Tensor>, Option<Tensor>) {
let (attention_output, attention_weights, cross_attention_weights) = if self.is_decoder & encoder_hidden_states.is_some() {
let (attention_output, attention_weights) =
self.attention.forward_t(hidden_states, mask, &None, &None, train);
let (attention_output, cross_attention_weights) =
self.cross_attention.as_ref().unwrap().forward_t(&attention_output, mask, encoder_hidden_states, encoder_mask, train);
(attention_output, attention_weights, cross_attention_weights)
} else {
let (attention_output, attention_weights) =
self.attention.forward_t(hidden_states, mask, &None, &None, train);
(attention_output, attention_weights, None)
};
pub fn forward_t(
&self,
hidden_states: &Tensor,
mask: &Option<Tensor>,
encoder_hidden_states: &Option<Tensor>,
encoder_mask: &Option<Tensor>,
train: bool,
) -> (Tensor, Option<Tensor>, Option<Tensor>) {
let (attention_output, attention_weights, cross_attention_weights) =
if self.is_decoder & encoder_hidden_states.is_some() {
let (attention_output, attention_weights) =
self.attention
.forward_t(hidden_states, mask, &None, &None, train);
let (attention_output, cross_attention_weights) =
self.cross_attention.as_ref().unwrap().forward_t(
&attention_output,
mask,
encoder_hidden_states,
encoder_mask,
train,
);
(attention_output, attention_weights, cross_attention_weights)
} else {
let (attention_output, attention_weights) =
self.attention
.forward_t(hidden_states, mask, &None, &None, train);
(attention_output, attention_weights, None)
};
let output = self.intermediate.forward(&attention_output);
let output = self.output.forward_t(&output, &attention_output, train);
@ -78,26 +98,47 @@ pub struct BertEncoder {
impl BertEncoder {
pub fn new(p: &nn::Path, config: &BertConfig) -> BertEncoder {
let p = &(p / "layer");
let output_attentions = if let Some(value) = config.output_attentions { value } else { false };
let output_hidden_states = if let Some(value) = config.output_hidden_states { value } else { false };
let mut layers: Vec<BertLayer> = vec!();
for layer_index in 0..config.num_hidden_layers {
layers.push(BertLayer::new(&(p / layer_index), config));
let output_attentions = if let Some(value) = config.output_attentions {
value
} else {
false
};
let output_hidden_states = if let Some(value) = config.output_hidden_states {
value
} else {
false
};
BertEncoder { output_attentions, output_hidden_states, layers }
let mut layers: Vec<BertLayer> = vec![];
for layer_index in 0..config.num_hidden_layers {
layers.push(BertLayer::new(&(p / layer_index), config));
}
BertEncoder {
output_attentions,
output_hidden_states,
layers,
}
}
pub fn forward_t(&self,
hidden_states: &Tensor,
mask: &Option<Tensor>,
encoder_hidden_states: &Option<Tensor>,
encoder_mask: &Option<Tensor>,
train: bool)
-> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states { Some(vec!()) } else { None };
let mut all_attentions: Option<Vec<Tensor>> = if self.output_attentions { Some(vec!()) } else { None };
pub fn forward_t(
&self,
hidden_states: &Tensor,
mask: &Option<Tensor>,
encoder_hidden_states: &Option<Tensor>,
encoder_mask: &Option<Tensor>,
train: bool,
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states {
Some(vec![])
} else {
None
};
let mut all_attentions: Option<Vec<Tensor>> = if self.output_attentions {
Some(vec![])
} else {
None
};
let mut hidden_state = hidden_states.copy();
let mut attention_weights: Option<Tensor>;
@ -109,16 +150,22 @@ impl BertEncoder {
hidden_states.push(hidden_state.as_ref().copy());
};
let temp = layer.forward_t(&hidden_state, &mask, encoder_hidden_states, encoder_mask, train);
let temp = layer.forward_t(
&hidden_state,
&mask,
encoder_hidden_states,
encoder_mask,
train,
);
hidden_state = temp.0;
attention_weights = temp.1;
if let Some(attentions) = all_attentions.borrow_mut() {
attentions.push(attention_weights.as_ref().unwrap().copy());
};
}
None => break
None => break,
};
};
}
(hidden_state, all_hidden_states, all_attentions)
}
@ -130,14 +177,16 @@ pub struct BertPooler {
impl BertPooler {
pub fn new(p: &nn::Path, config: &BertConfig) -> BertPooler {
let lin = nn::linear(&(p / "dense"), config.hidden_size, config.hidden_size, Default::default());
let lin = nn::linear(
&(p / "dense"),
config.hidden_size,
config.hidden_size,
Default::default(),
);
BertPooler { lin }
}
pub fn forward(&self, hidden_states: &Tensor) -> Tensor {
hidden_states
.select(1, 0)
.apply(&self.lin)
.tanh()
hidden_states.select(1, 0).apply(&self.lin).tanh()
}
}
}

View File

@ -19,18 +19,24 @@
//! Pretrained models are available and can be downloaded using RemoteResources.
//!
//! ```no_run
//!# fn main() -> failure::Fallible<()> {
//!#
//! # fn main() -> failure::Fallible<()> {
//! #
//! use rust_tokenizers::BertTokenizer;
//! use tch::{nn, Device};
//!# use std::path::PathBuf;
//! use rust_bert::bert::{BertForMaskedLM, BertConfig};
//! # use std::path::PathBuf;
//! use rust_bert::bert::{BertConfig, BertForMaskedLM};
//! use rust_bert::resources::{download_resource, LocalResource, Resource};
//! use rust_bert::Config;
//! use rust_bert::resources::{Resource, download_resource, LocalResource};
//!
//! let config_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/config.json")});
//! let vocab_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/vocab.txt")});
//! let weights_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/model.ot")});
//! let config_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/config.json"),
//! });
//! let vocab_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/vocab.txt"),
//! });
//! let weights_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/model.ot"),
//! });
//! let config_path = download_resource(&config_resource)?;
//! let vocab_path = download_resource(&vocab_resource)?;
//! let weights_path = download_resource(&weights_resource)?;
@ -41,17 +47,18 @@
//! let bert_model = BertForMaskedLM::new(&vs.root(), &config);
//! vs.load(weights_path)?;
//!
//!# Ok(())
//!# }
//! # Ok(())
//! # }
//! ```
mod attention;
mod bert;
mod embeddings;
mod attention;
pub(crate) mod encoder;
pub use bert::{BertModelResources, BertConfigResources, BertVocabResources,
BertConfig, Activation, BertModel, BertForTokenClassification, BertForMultipleChoice,
BertForMaskedLM, BertForSequenceClassification, BertForQuestionAnswering};
pub use embeddings::{BertEmbedding, BertEmbeddings};
pub use bert::{
Activation, BertConfig, BertConfigResources, BertForMaskedLM, BertForMultipleChoice,
BertForQuestionAnswering, BertForSequenceClassification, BertForTokenClassification, BertModel,
BertModelResources, BertVocabResources,
};
pub use embeddings::{BertEmbedding, BertEmbeddings};

View File

@ -1,14 +1,26 @@
use tch::Tensor;
use std::f64::consts::PI;
use tch::Tensor;
pub fn _gelu(x: &Tensor) -> Tensor { x * 0.5 * (1.0 + (x / ((2.0 as f64).sqrt())).erf()) }
pub fn _gelu(x: &Tensor) -> Tensor {
x * 0.5 * (1.0 + (x / ((2.0 as f64).sqrt())).erf())
}
pub fn _relu(x: &Tensor) -> Tensor { x.relu() }
pub fn _relu(x: &Tensor) -> Tensor {
x.relu()
}
pub fn _swish(x: &Tensor) -> Tensor { x * x.sigmoid() }
pub fn _swish(x: &Tensor) -> Tensor {
x * x.sigmoid()
}
pub fn _mish(x: &Tensor) -> Tensor { x * (x.softplus().tanh()) }
pub fn _mish(x: &Tensor) -> Tensor {
x * (x.softplus().tanh())
}
pub fn _gelu_new(x: &Tensor) -> Tensor { x * 0.5 * (((x.pow(3.0f64) * 0.044715 + x) * ((2f64 / PI).sqrt())).tanh() + 1) }
pub fn _gelu_new(x: &Tensor) -> Tensor {
x * 0.5 * (((x.pow(3.0f64) * 0.044715 + x) * ((2f64 / PI).sqrt())).tanh() + 1)
}
pub fn _tanh(x: &Tensor) -> Tensor { x.tanh() }
pub fn _tanh(x: &Tensor) -> Tensor {
x.tanh()
}

View File

@ -9,15 +9,16 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::path::Path;
use serde::Deserialize;
use std::fs::File;
use std::io::BufReader;
use serde::Deserialize;
use std::path::Path;
/// # Utility to deserialize JSON config files
pub trait Config<T>
where for<'de> T: Deserialize<'de> {
where
for<'de> T: Deserialize<'de>,
{
/// Loads a `Config` object from a JSON file. The format is expected to be aligned with the [Transformers library](https://github.com/huggingface/transformers) configuration files for each model.
/// The parsing will fail if non-optional keys expected by the model are missing.
///
@ -28,18 +29,17 @@ pub trait Config<T>
/// # Example
///
/// ```no_run
/// use rust_bert::gpt2::Gpt2Config;
/// use rust_bert::Config;
/// use std::path::Path;
/// use rust_bert::gpt2::Gpt2Config;
///
/// let config_path = Path::new("path/to/config.json");
/// let config = Gpt2Config::from_file(config_path);
/// ```
///
fn from_file(path: &Path) -> T {
let f = File::open(path).expect("Could not open configuration file.");
let br = BufReader::new(f);
let config: T = serde_json::from_reader(br).expect("could not parse configuration");
config
}
}
}

View File

@ -27,4 +27,4 @@ impl ModuleT for Dropout {
fn forward_t(&self, input: &Tensor, train: bool) -> Tensor {
input.dropout(self.dropout_prob, train)
}
}
}

View File

@ -10,9 +10,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use tch::nn::{Init, Path, Module};
use tch::Tensor;
use std::borrow::Borrow;
use tch::nn::{Init, Module, Path};
use tch::Tensor;
#[derive(Debug, Clone, Copy)]
pub struct LinearNoBiasConfig {
@ -32,7 +32,6 @@ pub struct LinearNoBias {
pub ws: Tensor,
}
pub fn linear_no_bias<'a, T: Borrow<Path<'a>>>(
vs: T,
in_dim: i64,
@ -49,4 +48,4 @@ impl Module for LinearNoBias {
fn forward(&self, xs: &Tensor) -> Tensor {
xs.matmul(&self.ws.tr())
}
}
}

View File

@ -1,7 +1,7 @@
pub mod config;
pub mod resources;
pub(crate) mod dropout;
pub(crate) mod activations;
pub mod config;
pub(crate) mod dropout;
pub(crate) mod linear;
pub mod resources;
pub use config::Config;
pub use config::Config;

View File

@ -18,9 +18,9 @@
//! pre-trained models in each model module.
use lazy_static::lazy_static;
use std::path::PathBuf;
use reqwest::Client;
use std::{fs, env};
use std::path::PathBuf;
use std::{env, fs};
use tokio::prelude::*;
use tokio::runtime::Runtime;
use tokio::task;
@ -47,12 +47,13 @@ impl Resource {
/// # Example
///
/// ```no_run
/// use rust_bert::resources::{Resource, LocalResource};
/// use rust_bert::resources::{LocalResource, Resource};
/// use std::path::PathBuf;
/// let config_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/config.json")});
/// let config_resource = Resource::Local(LocalResource {
/// local_path: PathBuf::from("path/to/config.json"),
/// });
/// let config_path = config_resource.get_local_path();
/// ```
///
pub fn get_local_path(&self) -> &PathBuf {
match self {
Resource::Local(resource) => &resource.local_path,
@ -65,7 +66,7 @@ impl Resource {
#[derive(PartialEq, Clone)]
pub struct LocalResource {
/// Local path for the resource
pub local_path: PathBuf
pub local_path: PathBuf,
}
/// # Remote resource
@ -93,13 +94,18 @@ impl RemoteResource {
/// # Example
///
/// ```no_run
/// use rust_bert::resources::{Resource, RemoteResource};
/// use rust_bert::resources::{RemoteResource, Resource};
/// use std::path::PathBuf;
/// let config_resource = Resource::Remote(RemoteResource::new("http://config_json_location", PathBuf::from("path/to/config.json")));
/// let config_resource = Resource::Remote(RemoteResource::new(
/// "http://config_json_location",
/// PathBuf::from("path/to/config.json"),
/// ));
/// ```
///
pub fn new(url: &str, target: PathBuf) -> RemoteResource {
RemoteResource { url: url.to_string(), local_path: target }
RemoteResource {
url: url.to_string(),
local_path: target,
}
}
/// Creates a new RemoteResource from an URL and local name. Will define a local path pointing to
@ -117,14 +123,12 @@ impl RemoteResource {
/// # Example
///
/// ```no_run
/// use rust_bert::resources::{Resource, RemoteResource};
/// let model_resource = Resource::Remote(RemoteResource::from_pretrained(
/// ("distilbert-sst2/model.ot",
/// "https://cdn.huggingface.co/distilbert-base-uncased-finetuned-sst-2-english-rust_model.ot"
/// )
/// ));
/// use rust_bert::resources::{RemoteResource, Resource};
/// let model_resource = Resource::Remote(RemoteResource::from_pretrained((
/// "distilbert-sst2/model.ot",
/// "https://cdn.huggingface.co/distilbert-base-uncased-finetuned-sst-2-english-rust_model.ot",
/// )));
/// ```
///
pub fn from_pretrained(name_url_tuple: (&str, &str)) -> RemoteResource {
let name = name_url_tuple.0;
let url = name_url_tuple.1.to_string();
@ -171,15 +175,13 @@ fn _get_cache_directory() -> PathBuf {
/// # Example
///
/// ```no_run
/// use rust_bert::resources::{Resource, RemoteResource, download_resource};
/// let model_resource = Resource::Remote(RemoteResource::from_pretrained(
/// ("distilbert-sst2/model.ot",
/// "https://cdn.huggingface.co/distilbert-base-uncased-finetuned-sst-2-english-rust_model.ot"
/// )
/// ));
/// use rust_bert::resources::{download_resource, RemoteResource, Resource};
/// let model_resource = Resource::Remote(RemoteResource::from_pretrained((
/// "distilbert-sst2/model.ot",
/// "https://cdn.huggingface.co/distilbert-base-uncased-finetuned-sst-2-english-rust_model.ot",
/// )));
/// let local_path = download_resource(&model_resource);
/// ```
///
pub fn download_resource(resource: &Resource) -> failure::Fallible<&PathBuf> {
match resource {
Resource::Remote(remote_resource) => {
@ -202,8 +204,6 @@ pub fn download_resource(resource: &Resource) -> failure::Fallible<&PathBuf> {
Ok(resource.get_local_path())
}
Resource::Local(_) => {
Ok(resource.get_local_path())
}
Resource::Local(_) => Ok(resource.get_local_path()),
}
}
}

View File

@ -16,7 +16,11 @@ extern crate tch;
pub fn main() -> failure::Fallible<()> {
let args: Vec<_> = std::env::args().collect();
ensure!(args.len() == 3, "usage: {} source.npz destination.ot", args[0]);
ensure!(
args.len() == 3,
"usage: {} source.npz destination.ot",
args[0]
);
let source_file = &args[1];
let destination_file = &args[2];
@ -24,4 +28,4 @@ pub fn main() -> failure::Fallible<()> {
tch::Tensor::save_multi(&tensors, destination_file)?;
Ok(())
}
}

View File

@ -10,11 +10,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use tch::{nn, Tensor};
use crate::common::dropout::Dropout;
use crate::distilbert::distilbert::DistilBertConfig;
use tch::kind::Kind::Float;
use crate::common::dropout::Dropout;
use tch::{nn, Tensor};
#[derive(Debug)]
pub struct MultiHeadSelfAttention {
@ -39,7 +38,7 @@ impl MultiHeadSelfAttention {
let output_attentions = match config.output_attentions {
Some(value) => value,
None => false
None => false,
};
MultiHeadSelfAttention {
@ -59,10 +58,19 @@ impl MultiHeadSelfAttention {
}
fn flatten(&self, x: Tensor, bs: i64, dim_per_head: i64) -> Tensor {
x.transpose(1, 2).contiguous().view((bs, -1, &self.n_heads * dim_per_head))
x.transpose(1, 2)
.contiguous()
.view((bs, -1, &self.n_heads * dim_per_head))
}
pub fn forward_t(&self, query: &Tensor, key: &Tensor, value: &Tensor, mask: &Option<Tensor>, train: bool) -> (Tensor, Option<Tensor>) {
pub fn forward_t(
&self,
query: &Tensor,
key: &Tensor,
value: &Tensor,
mask: &Option<Tensor>,
train: bool,
) -> (Tensor, Option<Tensor>) {
let bs = query.size()[0];
let k_length = key.size()[1];
@ -73,14 +81,19 @@ impl MultiHeadSelfAttention {
let scores = if let Some(mask) = mask {
let unmasked_scores = q.matmul(&k.transpose(2, 3));
let mask = mask.le1(&(mask.zeros_like() + 0.1)).view((bs, 1i64, 1i64, k_length)).expand_as(&unmasked_scores);
let mask = mask
.le1(&(mask.zeros_like() + 0.1))
.view((bs, 1i64, 1i64, k_length))
.expand_as(&unmasked_scores);
unmasked_scores.masked_fill(&mask, std::f64::NEG_INFINITY)
} else {
q.matmul(&k.transpose(2, 3))
};
let weights = scores.softmax(-1, Float).apply_t(&self.dropout, train);
let context = self.flatten(weights.matmul(&v), bs, self.dim_per_head).apply(&self.out_lin);
let context = self
.flatten(weights.matmul(&v), bs, self.dim_per_head)
.apply(&self.out_lin);
if !self.output_attentions {
(context, None)

View File

@ -12,13 +12,13 @@
extern crate tch;
use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use crate::distilbert::embeddings::DistilBertEmbedding;
use crate::distilbert::transformer::Transformer;
use self::tch::{nn, Tensor};
use crate::common::dropout::Dropout;
use crate::distilbert::embeddings::DistilBertEmbedding;
use crate::distilbert::transformer::Transformer;
use crate::Config;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// # DistilBERT Pretrained model weight files
pub struct DistilBertModelResources;
@ -31,29 +31,56 @@ pub struct DistilBertVocabResources;
impl DistilBertModelResources {
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
pub const DISTIL_BERT_SST2: (&'static str, &'static str) = ("distilbert-sst2/model.ot", "https://cdn.huggingface.co/distilbert-base-uncased-finetuned-sst-2-english-rust_model.ot");
pub const DISTIL_BERT_SST2: (&'static str, &'static str) = (
"distilbert-sst2/model.ot",
"https://cdn.huggingface.co/distilbert-base-uncased-finetuned-sst-2-english-rust_model.ot",
);
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
pub const DISTIL_BERT: (&'static str, &'static str) = ("distilbert/model.ot", "https://cdn.huggingface.co/distilbert-base-uncased-rust_model.ot");
pub const DISTIL_BERT: (&'static str, &'static str) = (
"distilbert/model.ot",
"https://cdn.huggingface.co/distilbert-base-uncased-rust_model.ot",
);
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
pub const DISTIL_BERT_SQUAD: (&'static str, &'static str) = ("distilbert-qa/model.ot", "https://cdn.huggingface.co/distilbert-base-cased-distilled-squad-rust_model.ot");
pub const DISTIL_BERT_SQUAD: (&'static str, &'static str) = (
"distilbert-qa/model.ot",
"https://cdn.huggingface.co/distilbert-base-cased-distilled-squad-rust_model.ot",
);
}
impl DistilBertConfigResources {
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
pub const DISTIL_BERT_SST2: (&'static str, &'static str) = ("distilbert-sst2/config.json", "https://cdn.huggingface.co/distilbert-base-uncased-finetuned-sst-2-english-config.json");
pub const DISTIL_BERT_SST2: (&'static str, &'static str) = (
"distilbert-sst2/config.json",
"https://cdn.huggingface.co/distilbert-base-uncased-finetuned-sst-2-english-config.json",
);
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
pub const DISTIL_BERT: (&'static str, &'static str) = ("distilbert/config.json", "https://cdn.huggingface.co/distilbert-base-uncased-config.json");
pub const DISTIL_BERT: (&'static str, &'static str) = (
"distilbert/config.json",
"https://cdn.huggingface.co/distilbert-base-uncased-config.json",
);
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
pub const DISTIL_BERT_SQUAD: (&'static str, &'static str) = ("distilbert-qa/config.json", "https://cdn.huggingface.co/distilbert-base-cased-distilled-squad-config.json");
pub const DISTIL_BERT_SQUAD: (&'static str, &'static str) = (
"distilbert-qa/config.json",
"https://cdn.huggingface.co/distilbert-base-cased-distilled-squad-config.json",
);
}
impl DistilBertVocabResources {
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
pub const DISTIL_BERT_SST2: (&'static str, &'static str) = ("distilbert-sst2/vocab.txt", "https://cdn.huggingface.co/distilbert-base-uncased-finetuned-sst-2-english-vocab.txt");
pub const DISTIL_BERT_SST2: (&'static str, &'static str) = (
"distilbert-sst2/vocab.txt",
"https://cdn.huggingface.co/distilbert-base-uncased-finetuned-sst-2-english-vocab.txt",
);
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
pub const DISTIL_BERT: (&'static str, &'static str) = ("distilbert/vocab.txt", "https://cdn.huggingface.co/bert-base-uncased-vocab.txt");
pub const DISTIL_BERT: (&'static str, &'static str) = (
"distilbert/vocab.txt",
"https://cdn.huggingface.co/bert-base-uncased-vocab.txt",
);
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
pub const DISTIL_BERT_SQUAD: (&'static str, &'static str) = ("distilbert-qa/vocab.txt", "https://cdn.huggingface.co/bert-large-cased-vocab.txt");
pub const DISTIL_BERT_SQUAD: (&'static str, &'static str) = (
"distilbert-qa/vocab.txt",
"https://cdn.huggingface.co/bert-large-cased-vocab.txt",
);
}
#[allow(non_camel_case_types)]
@ -118,10 +145,10 @@ impl DistilBertModel {
/// # Example
///
/// ```no_run
/// use tch::{nn, Device};
/// use rust_bert::distilbert::{DistilBertConfig, DistilBertModel};
/// use rust_bert::Config;
/// use std::path::Path;
/// use rust_bert::distilbert::{DistilBertConfig, DistilBertModel};
/// use tch::{nn, Device};
///
/// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu;
@ -129,12 +156,14 @@ impl DistilBertModel {
/// let config = DistilBertConfig::from_file(config_path);
/// let distil_bert: DistilBertModel = DistilBertModel::new(&(&p.root() / "distilbert"), &config);
/// ```
///
pub fn new(p: &nn::Path, config: &DistilBertConfig) -> DistilBertModel {
let p = &(p / "distilbert");
let embeddings = DistilBertEmbedding::new(&(p / "embeddings"), config);
let transformer = Transformer::new(&(p / "transformer"), config);
DistilBertModel { embeddings, transformer }
DistilBertModel {
embeddings,
transformer,
}
}
/// Forward pass through the model
@ -155,45 +184,49 @@ impl DistilBertModel {
/// # Example
///
/// ```no_run
///# use tch::{nn, Device, Tensor, no_grad};
///# use rust_bert::Config;
///# use std::path::Path;
///# use tch::kind::Kind::Int64;
/// # use tch::{nn, Device, Tensor, no_grad};
/// # use rust_bert::Config;
/// # use std::path::Path;
/// # use tch::kind::Kind::Int64;
/// use rust_bert::distilbert::{DistilBertConfig, DistilBertModel};
///# let config_path = Path::new("path/to/config.json");
///# let vocab_path = Path::new("path/to/vocab.txt");
///# let device = Device::Cpu;
///# let vs = nn::VarStore::new(device);
///# let config = DistilBertConfig::from_file(config_path);
///# let distilbert_model: DistilBertModel = DistilBertModel::new(&vs.root(), &config);
/// let (batch_size, sequence_length) = (64, 128);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
///
/// let (output, all_hidden_states, all_attentions) = no_grad(|| {
/// distilbert_model
/// .forward_t(Some(input_tensor),
/// Some(mask),
/// None,
/// false).unwrap()
/// });
/// # let config_path = Path::new("path/to/config.json");
/// # let vocab_path = Path::new("path/to/vocab.txt");
/// # let device = Device::Cpu;
/// # let vs = nn::VarStore::new(device);
/// # let config = DistilBertConfig::from_file(config_path);
/// # let distilbert_model: DistilBertModel = DistilBertModel::new(&vs.root(), &config);
/// let (batch_size, sequence_length) = (64, 128);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
///
/// let (output, all_hidden_states, all_attentions) = no_grad(|| {
/// distilbert_model
/// .forward_t(Some(input_tensor), Some(mask), None, false)
/// .unwrap()
/// });
/// ```
///
pub fn forward_t(&self, input: Option<Tensor>, mask: Option<Tensor>, input_embeds: Option<Tensor>, train: bool)
-> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
pub fn forward_t(
&self,
input: Option<Tensor>,
mask: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool,
) -> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
let input_embeddings = match input {
Some(input_value) => match input_embeds {
Some(_) => { return Err("Only one of input ids or input embeddings may be set"); }
None => input_value.apply_t(&self.embeddings, train)
}
Some(_) => {
return Err("Only one of input ids or input embeddings may be set");
}
None => input_value.apply_t(&self.embeddings, train),
},
None => match input_embeds {
Some(embeds) => embeds,
None => { return Err("At least one of input ids or input embeddings must be set"); }
}
None => {
return Err("At least one of input ids or input embeddings must be set");
}
},
};
let transformer_output = (&self.transformer).forward_t(&input_embeddings, mask, train);
Ok(transformer_output)
}
@ -223,28 +256,47 @@ impl DistilBertModelClassifier {
/// # Example
///
/// ```no_run
/// use tch::{nn, Device};
/// use rust_bert::distilbert::{DistilBertConfig, DistilBertModelClassifier};
/// use rust_bert::Config;
/// use std::path::Path;
/// use rust_bert::distilbert::{DistilBertConfig, DistilBertModelClassifier};
/// use tch::{nn, Device};
///
/// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu;
/// let p = nn::VarStore::new(device);
/// let config = DistilBertConfig::from_file(config_path);
/// let distil_bert: DistilBertModelClassifier = DistilBertModelClassifier::new(&(&p.root() / "distilbert"), &config);
/// let distil_bert: DistilBertModelClassifier =
/// DistilBertModelClassifier::new(&(&p.root() / "distilbert"), &config);
/// ```
///
pub fn new(p: &nn::Path, config: &DistilBertConfig) -> DistilBertModelClassifier {
let distil_bert_model = DistilBertModel::new(&p, config);
let num_labels = config.id2label.as_ref().expect("id2label must be provided for classifiers").len() as i64;
let num_labels = config
.id2label
.as_ref()
.expect("id2label must be provided for classifiers")
.len() as i64;
let pre_classifier = nn::linear(&(p / "pre_classifier"), config.dim, config.dim, Default::default());
let classifier = nn::linear(&(p / "classifier"), config.dim, num_labels, Default::default());
let pre_classifier = nn::linear(
&(p / "pre_classifier"),
config.dim,
config.dim,
Default::default(),
);
let classifier = nn::linear(
&(p / "classifier"),
config.dim,
num_labels,
Default::default(),
);
let dropout = Dropout::new(config.seq_classif_dropout);
DistilBertModelClassifier { distil_bert_model, pre_classifier, classifier, dropout }
DistilBertModelClassifier {
distil_bert_model,
pre_classifier,
classifier,
dropout,
}
}
/// Forward pass through the model
@ -265,17 +317,17 @@ impl DistilBertModelClassifier {
/// # Example
///
/// ```no_run
///# use tch::{nn, Device, Tensor, no_grad};
///# use rust_bert::Config;
///# use std::path::Path;
///# use tch::kind::Kind::Int64;
/// # use tch::{nn, Device, Tensor, no_grad};
/// # use rust_bert::Config;
/// # use std::path::Path;
/// # use tch::kind::Kind::Int64;
/// use rust_bert::distilbert::{DistilBertConfig, DistilBertModelClassifier};
///# let config_path = Path::new("path/to/config.json");
///# let vocab_path = Path::new("path/to/vocab.txt");
///# let device = Device::Cpu;
///# let vs = nn::VarStore::new(device);
///# let config = DistilBertConfig::from_file(config_path);
///# let distilbert_model: DistilBertModelClassifier = DistilBertModelClassifier::new(&vs.root(), &config);
/// # let config_path = Path::new("path/to/config.json");
/// # let vocab_path = Path::new("path/to/vocab.txt");
/// # let device = Device::Cpu;
/// # let vs = nn::VarStore::new(device);
/// # let config = DistilBertConfig::from_file(config_path);
/// # let distilbert_model: DistilBertModelClassifier = DistilBertModelClassifier::new(&vs.root(), &config);
/// let (batch_size, sequence_length) = (64, 128);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
@ -287,15 +339,22 @@ impl DistilBertModelClassifier {
/// None,
/// false).unwrap()
/// });
///
/// ```
///
pub fn forward_t(&self, input: Option<Tensor>, mask: Option<Tensor>, input_embeds: Option<Tensor>, train: bool)
-> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
let (output, all_hidden_states, all_attentions) = match self.distil_bert_model.forward_t(input, mask, input_embeds, train) {
Ok(value) => value,
Err(err) => return Err(err)
};
pub fn forward_t(
&self,
input: Option<Tensor>,
mask: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool,
) -> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
let (output, all_hidden_states, all_attentions) =
match self
.distil_bert_model
.forward_t(input, mask, input_embeds, train)
{
Ok(value) => value,
Err(err) => return Err(err),
};
let output = output
.select(1, 0)
@ -322,7 +381,6 @@ pub struct DistilBertModelMaskedLM {
vocab_projector: nn::Linear,
}
impl DistilBertModelMaskedLM {
/// Build a new `DistilBertModelMaskedLM` for sequence classification
///
@ -334,10 +392,10 @@ impl DistilBertModelMaskedLM {
/// # Example
///
/// ```no_run
/// use tch::{nn, Device};
/// use rust_bert::distilbert::{DistilBertConfig, DistilBertModelMaskedLM};
/// use rust_bert::Config;
/// use std::path::Path;
/// use rust_bert::distilbert::{DistilBertConfig, DistilBertModelMaskedLM};
/// use tch::{nn, Device};
///
/// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu;
@ -345,15 +403,33 @@ impl DistilBertModelMaskedLM {
/// let config = DistilBertConfig::from_file(config_path);
/// let distil_bert = DistilBertModelMaskedLM::new(&(&p.root() / "distilbert"), &config);
/// ```
///
pub fn new(p: &nn::Path, config: &DistilBertConfig) -> DistilBertModelMaskedLM {
let distil_bert_model = DistilBertModel::new(&p, config);
let vocab_transform = nn::linear(&(p / "vocab_transform"), config.dim, config.dim, Default::default());
let layer_norm_config = nn::LayerNormConfig { eps: 1e-12, ..Default::default() };
let vocab_layer_norm = nn::layer_norm(p / "vocab_layer_norm", vec![config.dim], layer_norm_config);
let vocab_projector = nn::linear(&(p / "vocab_projector"), config.dim, config.vocab_size, Default::default());
let vocab_transform = nn::linear(
&(p / "vocab_transform"),
config.dim,
config.dim,
Default::default(),
);
let layer_norm_config = nn::LayerNormConfig {
eps: 1e-12,
..Default::default()
};
let vocab_layer_norm =
nn::layer_norm(p / "vocab_layer_norm", vec![config.dim], layer_norm_config);
let vocab_projector = nn::linear(
&(p / "vocab_projector"),
config.dim,
config.vocab_size,
Default::default(),
);
DistilBertModelMaskedLM { distil_bert_model, vocab_transform, vocab_layer_norm, vocab_projector }
DistilBertModelMaskedLM {
distil_bert_model,
vocab_transform,
vocab_layer_norm,
vocab_projector,
}
}
/// Forward pass through the model
@ -374,37 +450,42 @@ impl DistilBertModelMaskedLM {
/// # Example
///
/// ```no_run
///# use tch::{nn, Device, Tensor, no_grad};
///# use rust_bert::Config;
///# use std::path::Path;
///# use tch::kind::Kind::Int64;
/// # use tch::{nn, Device, Tensor, no_grad};
/// # use rust_bert::Config;
/// # use std::path::Path;
/// # use tch::kind::Kind::Int64;
/// use rust_bert::distilbert::{DistilBertConfig, DistilBertModelMaskedLM};
///# let config_path = Path::new("path/to/config.json");
///# let vocab_path = Path::new("path/to/vocab.txt");
///# let device = Device::Cpu;
///# let vs = nn::VarStore::new(device);
///# let config = DistilBertConfig::from_file(config_path);
///# let distilbert_model = DistilBertModelMaskedLM::new(&vs.root(), &config);
/// let (batch_size, sequence_length) = (64, 128);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
///
/// let (output, all_hidden_states, all_attentions) = no_grad(|| {
/// distilbert_model
/// .forward_t(Some(input_tensor),
/// Some(mask),
/// None,
/// false).unwrap()
/// });
/// # let config_path = Path::new("path/to/config.json");
/// # let vocab_path = Path::new("path/to/vocab.txt");
/// # let device = Device::Cpu;
/// # let vs = nn::VarStore::new(device);
/// # let config = DistilBertConfig::from_file(config_path);
/// # let distilbert_model = DistilBertModelMaskedLM::new(&vs.root(), &config);
/// let (batch_size, sequence_length) = (64, 128);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
///
/// let (output, all_hidden_states, all_attentions) = no_grad(|| {
/// distilbert_model
/// .forward_t(Some(input_tensor), Some(mask), None, false)
/// .unwrap()
/// });
/// ```
///
pub fn forward_t(&self, input: Option<Tensor>, mask: Option<Tensor>, input_embeds: Option<Tensor>, train: bool)
-> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
let (output, all_hidden_states, all_attentions) = match self.distil_bert_model.forward_t(input, mask, input_embeds, train) {
Ok(value) => value,
Err(err) => return Err(err)
};
pub fn forward_t(
&self,
input: Option<Tensor>,
mask: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool,
) -> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
let (output, all_hidden_states, all_attentions) =
match self
.distil_bert_model
.forward_t(input, mask, input_embeds, train)
{
Ok(value) => value,
Err(err) => return Err(err),
};
let output = output
.apply(&self.vocab_transform)
@ -440,10 +521,10 @@ impl DistilBertForQuestionAnswering {
/// # Example
///
/// ```no_run
/// use tch::{nn, Device};
/// use rust_bert::distilbert::{DistilBertConfig, DistilBertForQuestionAnswering};
/// use rust_bert::Config;
/// use std::path::Path;
/// use rust_bert::distilbert::{DistilBertConfig, DistilBertForQuestionAnswering};
/// use tch::{nn, Device};
///
/// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu;
@ -451,13 +532,16 @@ impl DistilBertForQuestionAnswering {
/// let config = DistilBertConfig::from_file(config_path);
/// let distil_bert = DistilBertForQuestionAnswering::new(&(&p.root() / "distilbert"), &config);
/// ```
///
pub fn new(p: &nn::Path, config: &DistilBertConfig) -> DistilBertForQuestionAnswering {
let distil_bert_model = DistilBertModel::new(&p, config);
let qa_outputs = nn::linear(&(p / "qa_outputs"), config.dim, 2, Default::default());
let dropout = Dropout::new(config.qa_dropout);
DistilBertForQuestionAnswering { distil_bert_model, qa_outputs, dropout }
DistilBertForQuestionAnswering {
distil_bert_model,
qa_outputs,
dropout,
}
}
/// Forward pass through the model
@ -479,52 +563,50 @@ impl DistilBertForQuestionAnswering {
/// # Example
///
/// ```no_run
///# use tch::{nn, Device, Tensor, no_grad};
///# use rust_bert::Config;
///# use std::path::Path;
///# use tch::kind::Kind::Int64;
/// # use tch::{nn, Device, Tensor, no_grad};
/// # use rust_bert::Config;
/// # use std::path::Path;
/// # use tch::kind::Kind::Int64;
/// use rust_bert::distilbert::{DistilBertConfig, DistilBertForQuestionAnswering};
///# let config_path = Path::new("path/to/config.json");
///# let vocab_path = Path::new("path/to/vocab.txt");
///# let device = Device::Cpu;
///# let vs = nn::VarStore::new(device);
///# let config = DistilBertConfig::from_file(config_path);
///# let distilbert_model = DistilBertForQuestionAnswering::new(&vs.root(), &config);
/// let (batch_size, sequence_length) = (64, 128);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
///
/// let (start_scores, end_score, all_hidden_states, all_attentions) = no_grad(|| {
/// distilbert_model
/// .forward_t(Some(input_tensor),
/// Some(mask),
/// None,
/// false).unwrap()
/// });
/// # let config_path = Path::new("path/to/config.json");
/// # let vocab_path = Path::new("path/to/vocab.txt");
/// # let device = Device::Cpu;
/// # let vs = nn::VarStore::new(device);
/// # let config = DistilBertConfig::from_file(config_path);
/// # let distilbert_model = DistilBertForQuestionAnswering::new(&vs.root(), &config);
/// let (batch_size, sequence_length) = (64, 128);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
///
/// let (start_scores, end_score, all_hidden_states, all_attentions) = no_grad(|| {
/// distilbert_model
/// .forward_t(Some(input_tensor), Some(mask), None, false)
/// .unwrap()
/// });
/// ```
///
pub fn forward_t(&self,
input: Option<Tensor>,
mask: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool)
-> Result<(Tensor, Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
let (output, all_hidden_states, all_attentions) = match self.distil_bert_model.forward_t(input, mask, input_embeds, train) {
Ok(value) => value,
Err(err) => return Err(err)
};
pub fn forward_t(
&self,
input: Option<Tensor>,
mask: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool,
) -> Result<(Tensor, Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
let (output, all_hidden_states, all_attentions) =
match self
.distil_bert_model
.forward_t(input, mask, input_embeds, train)
{
Ok(value) => value,
Err(err) => return Err(err),
};
let output = output
.apply_t(&self.dropout, train)
.apply(&self.qa_outputs);
let output = output.apply_t(&self.dropout, train).apply(&self.qa_outputs);
let logits = output.split(1, -1);
let (start_logits, end_logits) = (&logits[0], &logits[1]);
let start_logits = start_logits.squeeze1(-1);
let end_logits = end_logits.squeeze1(-1);
Ok((start_logits, end_logits, all_hidden_states, all_attentions))
}
}
@ -552,10 +634,10 @@ impl DistilBertForTokenClassification {
/// # Example
///
/// ```no_run
/// use tch::{nn, Device};
/// use rust_bert::distilbert::{DistilBertConfig, DistilBertForTokenClassification};
/// use rust_bert::Config;
/// use std::path::Path;
/// use rust_bert::distilbert::{DistilBertConfig, DistilBertForTokenClassification};
/// use tch::{nn, Device};
///
/// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu;
@ -563,16 +645,28 @@ impl DistilBertForTokenClassification {
/// let config = DistilBertConfig::from_file(config_path);
/// let distil_bert = DistilBertForTokenClassification::new(&(&p.root() / "distilbert"), &config);
/// ```
///
pub fn new(p: &nn::Path, config: &DistilBertConfig) -> DistilBertForTokenClassification {
let distil_bert_model = DistilBertModel::new(&p, config);
let num_labels = config.id2label.as_ref().expect("id2label must be provided for classifiers").len() as i64;
let num_labels = config
.id2label
.as_ref()
.expect("id2label must be provided for classifiers")
.len() as i64;
let classifier = nn::linear(&(p / "classifier"), config.dim, num_labels, Default::default());
let classifier = nn::linear(
&(p / "classifier"),
config.dim,
num_labels,
Default::default(),
);
let dropout = Dropout::new(config.seq_classif_dropout);
DistilBertForTokenClassification { distil_bert_model, classifier, dropout }
DistilBertForTokenClassification {
distil_bert_model,
classifier,
dropout,
}
}
/// Forward pass through the model
@ -593,41 +687,44 @@ impl DistilBertForTokenClassification {
/// # Example
///
/// ```no_run
///# use tch::{nn, Device, Tensor, no_grad};
///# use rust_bert::Config;
///# use std::path::Path;
///# use tch::kind::Kind::Int64;
/// # use tch::{nn, Device, Tensor, no_grad};
/// # use rust_bert::Config;
/// # use std::path::Path;
/// # use tch::kind::Kind::Int64;
/// use rust_bert::distilbert::{DistilBertConfig, DistilBertForTokenClassification};
///# let config_path = Path::new("path/to/config.json");
///# let vocab_path = Path::new("path/to/vocab.txt");
///# let device = Device::Cpu;
///# let vs = nn::VarStore::new(device);
///# let config = DistilBertConfig::from_file(config_path);
///# let distilbert_model = DistilBertForTokenClassification::new(&vs.root(), &config);
/// let (batch_size, sequence_length) = (64, 128);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
///
/// let (output, all_hidden_states, all_attentions) = no_grad(|| {
/// distilbert_model
/// .forward_t(Some(input_tensor),
/// Some(mask),
/// None,
/// false).unwrap()
/// });
/// # let config_path = Path::new("path/to/config.json");
/// # let vocab_path = Path::new("path/to/vocab.txt");
/// # let device = Device::Cpu;
/// # let vs = nn::VarStore::new(device);
/// # let config = DistilBertConfig::from_file(config_path);
/// # let distilbert_model = DistilBertForTokenClassification::new(&vs.root(), &config);
/// let (batch_size, sequence_length) = (64, 128);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
///
/// let (output, all_hidden_states, all_attentions) = no_grad(|| {
/// distilbert_model
/// .forward_t(Some(input_tensor), Some(mask), None, false)
/// .unwrap()
/// });
/// ```
///
pub fn forward_t(&self, input: Option<Tensor>, mask: Option<Tensor>, input_embeds: Option<Tensor>, train: bool)
-> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
let (output, all_hidden_states, all_attentions) = match self.distil_bert_model.forward_t(input, mask, input_embeds, train) {
Ok(value) => value,
Err(err) => return Err(err)
};
pub fn forward_t(
&self,
input: Option<Tensor>,
mask: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool,
) -> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
let (output, all_hidden_states, all_attentions) =
match self
.distil_bert_model
.forward_t(input, mask, input_embeds, train)
{
Ok(value) => value,
Err(err) => return Err(err),
};
let output = output
.apply_t(&self.dropout, train)
.apply(&self.classifier);
let output = output.apply_t(&self.dropout, train).apply(&self.classifier);
Ok((output, all_hidden_states, all_attentions))
}

View File

@ -10,22 +10,26 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use tch::{nn, Tensor, Kind, Device};
use tch::nn::{ModuleT, embedding, EmbeddingConfig};
use crate::distilbert::distilbert::DistilBertConfig;
use crate::common::dropout::Dropout;
use crate::distilbert::distilbert::DistilBertConfig;
use tch::kind::Kind::Float;
use tch::nn::{embedding, EmbeddingConfig, ModuleT};
use tch::{nn, Device, Kind, Tensor};
fn create_sinusoidal_embeddings(config: &DistilBertConfig, device: Device) -> nn::Embedding {
let mut sinusoidal_embedding: Vec<Tensor> = Vec::with_capacity(config.max_position_embeddings as usize);
let mut sinusoidal_embedding: Vec<Tensor> =
Vec::with_capacity(config.max_position_embeddings as usize);
for pos in 0..config.max_position_embeddings {
let mut temp_vec: Vec<f64> = Vec::with_capacity(config.dim as usize);
for j in 0..config.dim {
if j % 2 == 0 {
temp_vec.push((pos as f64 / 10000f64.powf((2 * (j / 2)) as f64 / config.dim as f64)).sin());
temp_vec.push(
(pos as f64 / 10000f64.powf((2 * (j / 2)) as f64 / config.dim as f64)).sin(),
);
} else {
temp_vec.push((pos as f64 / 10000f64.powf((2 * (j / 2)) as f64 / config.dim as f64)).cos());
temp_vec.push(
(pos as f64 / 10000f64.powf((2 * (j / 2)) as f64 / config.dim as f64)).cos(),
);
}
}
let temp_vec = Tensor::of_slice(&temp_vec);
@ -35,17 +39,21 @@ fn create_sinusoidal_embeddings(config: &DistilBertConfig, device: Device) -> nn
.to_kind(Float)
.to_device(device);
let embedding_config = EmbeddingConfig { padding_idx: 0, ..Default::default() };
let mut embeddings = embedding(&nn::VarStore::new(device).root(),
config.max_position_embeddings,
config.dim,
embedding_config);
let embedding_config = EmbeddingConfig {
padding_idx: 0,
..Default::default()
};
let mut embeddings = embedding(
&nn::VarStore::new(device).root(),
config.max_position_embeddings,
config.dim,
embedding_config,
);
embeddings.ws = sinusoidal_embedding;
embeddings
}
#[derive(Debug)]
pub struct DistilBertEmbedding {
word_embeddings: nn::Embedding,
@ -56,24 +64,40 @@ pub struct DistilBertEmbedding {
impl DistilBertEmbedding {
pub fn new(p: &nn::Path, config: &DistilBertConfig) -> DistilBertEmbedding {
let embedding_config = EmbeddingConfig { padding_idx: 0, ..Default::default() };
let word_embeddings: nn::Embedding = embedding(p / "word_embeddings",
config.vocab_size,
config.dim,
embedding_config);
let position_embeddings: nn::Embedding = match config.sinusoidal_pos_embds {
false => embedding(p / "position_embeddings",
config.max_position_embeddings,
config.dim,
embedding_config),
true => create_sinusoidal_embeddings(&config, p.device())
let embedding_config = EmbeddingConfig {
padding_idx: 0,
..Default::default()
};
let layer_norm_config = nn::LayerNormConfig { eps: 1e-12, ..Default::default() };
let layer_norm: nn::LayerNorm = nn::layer_norm(p / "LayerNorm", vec![config.dim], layer_norm_config);
let word_embeddings: nn::Embedding = embedding(
p / "word_embeddings",
config.vocab_size,
config.dim,
embedding_config,
);
let position_embeddings: nn::Embedding = match config.sinusoidal_pos_embds {
false => embedding(
p / "position_embeddings",
config.max_position_embeddings,
config.dim,
embedding_config,
),
true => create_sinusoidal_embeddings(&config, p.device()),
};
let layer_norm_config = nn::LayerNormConfig {
eps: 1e-12,
..Default::default()
};
let layer_norm: nn::LayerNorm =
nn::layer_norm(p / "LayerNorm", vec![config.dim], layer_norm_config);
let dropout: Dropout = Dropout::new(config.dropout);
DistilBertEmbedding { word_embeddings, position_embeddings, layer_norm, dropout }
DistilBertEmbedding {
word_embeddings,
position_embeddings,
layer_norm,
dropout,
}
}
pub fn _get_word_embeddings(&self) -> &nn::Embedding {
@ -94,10 +118,12 @@ impl ModuleT for DistilBertEmbedding {
let word_embed = input.apply(&self.word_embeddings);
let position_embed = position_ids.apply(&self.position_embeddings);
// position_embed.get(0).get(0).print();
// position_embed.get(0).get(0).print();
let embeddings = word_embed + position_embed;
let embeddings = embeddings.apply(&self.layer_norm).apply_t(&self.dropout, train);
let embeddings = embeddings
.apply(&self.layer_norm)
.apply_t(&self.dropout, train);
embeddings
}
}
}

View File

@ -18,18 +18,27 @@
//! Pretrained models are available and can be downloaded using RemoteResources.
//!
//! ```no_run
//!# fn main() -> failure::Fallible<()> {
//!#
//! # fn main() -> failure::Fallible<()> {
//! #
//! use rust_tokenizers::BertTokenizer;
//! use tch::{nn, Device};
//!# use std::path::PathBuf;
//! # use std::path::PathBuf;
//! use rust_bert::distilbert::{
//! DistilBertConfig, DistilBertConfigResources, DistilBertModelMaskedLM,
//! DistilBertModelResources, DistilBertVocabResources,
//! };
//! use rust_bert::resources::{download_resource, LocalResource, RemoteResource, Resource};
//! use rust_bert::Config;
//! use rust_bert::distilbert::{DistilBertModelMaskedLM, DistilBertConfig, DistilBertConfigResources, DistilBertVocabResources, DistilBertModelResources};
//! use rust_bert::resources::{Resource, download_resource, RemoteResource, LocalResource};
//!
//! let config_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/config.json")});
//! let vocab_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/vocab.txt")});
//! let weights_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/model.ot")});
//! let config_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/config.json"),
//! });
//! let vocab_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/vocab.txt"),
//! });
//! let weights_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/model.ot"),
//! });
//! let config_path = download_resource(&config_resource)?;
//! let vocab_path = download_resource(&vocab_resource)?;
//! let weights_path = download_resource(&weights_resource)?;
@ -40,17 +49,17 @@
//! let bert_model = DistilBertModelMaskedLM::new(&vs.root(), &config);
//! vs.load(weights_path)?;
//!
//!# Ok(())
//!# }
//! # Ok(())
//! # }
//! ```
mod attention;
mod distilbert;
mod embeddings;
mod attention;
mod transformer;
pub use distilbert::{DistilBertModelResources, DistilBertConfigResources, DistilBertVocabResources,
DistilBertConfig, Activation, DistilBertModel, DistilBertForQuestionAnswering, DistilBertForTokenClassification,
DistilBertModelMaskedLM, DistilBertModelClassifier};
pub use distilbert::{
Activation, DistilBertConfig, DistilBertConfigResources, DistilBertForQuestionAnswering,
DistilBertForTokenClassification, DistilBertModel, DistilBertModelClassifier,
DistilBertModelMaskedLM, DistilBertModelResources, DistilBertVocabResources,
};

View File

@ -10,13 +10,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use tch::{Tensor, nn};
use crate::distilbert::distilbert::{DistilBertConfig, Activation};
use crate::distilbert::attention::MultiHeadSelfAttention;
use tch::nn::LayerNorm;
use std::borrow::BorrowMut;
use crate::common::dropout::Dropout;
use crate::common::activations::{_gelu, _relu};
use crate::common::dropout::Dropout;
use crate::distilbert::attention::MultiHeadSelfAttention;
use crate::distilbert::distilbert::{Activation, DistilBertConfig};
use std::borrow::BorrowMut;
use tch::nn::LayerNorm;
use tch::{nn, Tensor};
pub struct FeedForwardNetwork {
lin1: nn::Linear,
@ -27,18 +27,35 @@ pub struct FeedForwardNetwork {
impl FeedForwardNetwork {
pub fn new(p: nn::Path, config: &DistilBertConfig) -> FeedForwardNetwork {
let lin1 = nn::linear(&p / "lin1", config.dim, config.hidden_dim, Default::default());
let lin2 = nn::linear(&p / "lin2", config.hidden_dim, config.dim, Default::default());
let lin1 = nn::linear(
&p / "lin1",
config.dim,
config.hidden_dim,
Default::default(),
);
let lin2 = nn::linear(
&p / "lin2",
config.hidden_dim,
config.dim,
Default::default(),
);
let dropout = Dropout::new(config.dropout);
let activation = Box::new(match &config.activation {
Activation::gelu => _gelu,
Activation::relu => _relu
Activation::relu => _relu,
});
FeedForwardNetwork { lin1, lin2, dropout, activation }
FeedForwardNetwork {
lin1,
lin2,
dropout,
activation,
}
}
pub fn forward_t(&self, input: &Tensor, train: bool) -> Tensor {
(self.activation)(&input.apply(&self.lin1)).apply(&self.lin2).apply_t(&self.dropout, train)
(self.activation)(&input.apply(&self.lin1))
.apply(&self.lin2)
.apply_t(&self.dropout, train)
}
}
@ -52,10 +69,15 @@ pub struct TransformerBlock {
impl TransformerBlock {
pub fn new(p: &nn::Path, config: &DistilBertConfig) -> TransformerBlock {
let attention = MultiHeadSelfAttention::new(p / "attention", &config);
let layer_norm_config = nn::LayerNormConfig { eps: 1e-12, ..Default::default() };
let sa_layer_norm = nn::layer_norm(p / "sa_layer_norm", vec![config.dim], layer_norm_config);
let layer_norm_config = nn::LayerNormConfig {
eps: 1e-12,
..Default::default()
};
let sa_layer_norm =
nn::layer_norm(p / "sa_layer_norm", vec![config.dim], layer_norm_config);
let ffn = FeedForwardNetwork::new(p / "ffn", &config);
let output_layer_norm = nn::layer_norm(p / "output_layer_norm", vec![config.dim], layer_norm_config);
let output_layer_norm =
nn::layer_norm(p / "output_layer_norm", vec![config.dim], layer_norm_config);
TransformerBlock {
attention,
@ -65,8 +87,15 @@ impl TransformerBlock {
}
}
pub fn forward_t(&self, input: &Tensor, mask: &Option<Tensor>, train: bool) -> (Tensor, Option<Tensor>) {
let (output, sa_weights) = self.attention.forward_t(&input, &input, &input, mask, train);
pub fn forward_t(
&self,
input: &Tensor,
mask: &Option<Tensor>,
train: bool,
) -> (Tensor, Option<Tensor>) {
let (output, sa_weights) = self
.attention
.forward_t(&input, &input, &input, mask, train);
let output = (input + &output).apply(&self.sa_layer_norm);
let output = (&output + self.ffn.forward_t(&output, train)).apply(&self.output_layer_norm);
(output, sa_weights)
@ -84,25 +113,41 @@ impl Transformer {
let p = &(p / "layer");
let output_attentions = match config.output_attentions {
Some(value) => value,
None => false
None => false,
};
let output_hidden_states = match config.output_hidden_states {
Some(value) => value,
None => false
None => false,
};
let mut layers: Vec<TransformerBlock> = vec!();
let mut layers: Vec<TransformerBlock> = vec![];
for layer_index in 0..config.n_layers {
layers.push(TransformerBlock::new(&(p / layer_index), config));
};
}
Transformer { output_attentions, output_hidden_states, layers }
Transformer {
output_attentions,
output_hidden_states,
layers,
}
}
pub fn forward_t(&self, input: &Tensor, mask: Option<Tensor>, train: bool)
-> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states { Some(vec!()) } else { None };
let mut all_attentions: Option<Vec<Tensor>> = if self.output_attentions { Some(vec!()) } else { None };
pub fn forward_t(
&self,
input: &Tensor,
mask: Option<Tensor>,
train: bool,
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states {
Some(vec![])
} else {
None
};
let mut all_attentions: Option<Vec<Tensor>> = if self.output_attentions {
Some(vec![])
} else {
None
};
let mut hidden_state = input.copy();
let mut attention_weights: Option<Tensor>;
@ -121,10 +166,10 @@ impl Transformer {
attentions.push(attention_weights.as_ref().unwrap().copy());
};
}
None => break
None => break,
};
};
}
(hidden_state, all_hidden_states, all_attentions)
}
}
}

View File

@ -12,15 +12,15 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use crate::bert::encoder::BertEncoder;
use crate::bert::{Activation, BertConfig};
use crate::common::activations::{_gelu, _mish, _relu};
use crate::common::dropout::Dropout;
use crate::electra::embeddings::ElectraEmbeddings;
use crate::Config;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use crate::bert::{Activation, BertConfig};
use crate::Config;
use crate::electra::embeddings::ElectraEmbeddings;
use tch::{nn, Tensor, Kind};
use crate::bert::encoder::BertEncoder;
use crate::common::activations::{_gelu, _relu, _mish};
use crate::common::dropout::Dropout;
use tch::{nn, Kind, Tensor};
/// # Electra Pretrained model weight files
pub struct ElectraModelResources;
@ -33,23 +33,41 @@ pub struct ElectraVocabResources;
impl ElectraModelResources {
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/electra. Modified with conversion to C-array format.
pub const BASE_GENERATOR: (&'static str, &'static str) = ("electra-base-generator/model.ot", "https://cdn.huggingface.co/google/electra-base-generator/rust_model.ot");
pub const BASE_GENERATOR: (&'static str, &'static str) = (
"electra-base-generator/model.ot",
"https://cdn.huggingface.co/google/electra-base-generator/rust_model.ot",
);
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/electra. Modified with conversion to C-array format.
pub const BASE_DISCRIMINATOR: (&'static str, &'static str) = ("electra-base-discriminator/model.ot", "https://cdn.huggingface.co/google/electra-base-discriminator/rust_model.ot");
pub const BASE_DISCRIMINATOR: (&'static str, &'static str) = (
"electra-base-discriminator/model.ot",
"https://cdn.huggingface.co/google/electra-base-discriminator/rust_model.ot",
);
}
impl ElectraConfigResources {
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/electra. Modified with conversion to C-array format.
pub const BASE_GENERATOR: (&'static str, &'static str) = ("electra-base-generator/config.json", "https://cdn.huggingface.co/google/electra-base-generator/config.json");
pub const BASE_GENERATOR: (&'static str, &'static str) = (
"electra-base-generator/config.json",
"https://cdn.huggingface.co/google/electra-base-generator/config.json",
);
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/electra. Modified with conversion to C-array format.
pub const BASE_DISCRIMINATOR: (&'static str, &'static str) = ("electra-base-discriminator/config.json", "https://cdn.huggingface.co/google/electra-base-discriminator/config.json");
pub const BASE_DISCRIMINATOR: (&'static str, &'static str) = (
"electra-base-discriminator/config.json",
"https://cdn.huggingface.co/google/electra-base-discriminator/config.json",
);
}
impl ElectraVocabResources {
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/electra. Modified with conversion to C-array format.
pub const BASE_GENERATOR: (&'static str, &'static str) = ("electra-base-generator/vocab.txt", "https://cdn.huggingface.co/google/electra-base-generator/vocab.txt");
pub const BASE_GENERATOR: (&'static str, &'static str) = (
"electra-base-generator/vocab.txt",
"https://cdn.huggingface.co/google/electra-base-generator/vocab.txt",
);
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/electra. Modified with conversion to C-array format.
pub const BASE_DISCRIMINATOR: (&'static str, &'static str) = ("electra-base-discriminator/vocab.txt", "https://cdn.huggingface.co/google/electra-base-discriminator/vocab.txt");
pub const BASE_DISCRIMINATOR: (&'static str, &'static str) = (
"electra-base-discriminator/vocab.txt",
"https://cdn.huggingface.co/google/electra-base-discriminator/vocab.txt",
);
}
#[derive(Debug, Serialize, Deserialize)]
@ -103,10 +121,10 @@ impl ElectraModel {
/// # Example
///
/// ```no_run
/// use rust_bert::electra::{ElectraModel, ElectraConfig};
/// use tch::{nn, Device};
/// use rust_bert::electra::{ElectraConfig, ElectraModel};
/// use rust_bert::Config;
/// use std::path::Path;
/// use tch::{nn, Device};
///
/// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu;
@ -114,11 +132,15 @@ impl ElectraModel {
/// let config = ElectraConfig::from_file(config_path);
/// let electra_model: ElectraModel = ElectraModel::new(&(&p.root() / "electra"), &config);
/// ```
///
pub fn new(p: &nn::Path, config: &ElectraConfig) -> ElectraModel {
let embeddings = ElectraEmbeddings::new(&(p / "embeddings"), config);
let embeddings_project = if config.embedding_size != config.hidden_size {
Some(nn::linear(&(p / "embeddings_project"), config.embedding_size, config.hidden_size, Default::default()))
Some(nn::linear(
&(p / "embeddings_project"),
config.embedding_size,
config.hidden_size,
Default::default(),
))
} else {
None
};
@ -141,7 +163,11 @@ impl ElectraModel {
label2id: config.label2id.clone(),
};
let encoder = BertEncoder::new(&(p / "encoder"), &bert_config);
ElectraModel { embeddings, embeddings_project, encoder }
ElectraModel {
embeddings,
embeddings_project,
encoder,
}
}
/// Forward pass through the model
@ -164,80 +190,98 @@ impl ElectraModel {
/// # Example
///
/// ```no_run
///# use rust_bert::electra::{ElectraModel, ElectraConfig};
///# use tch::{nn, Device, Tensor, no_grad};
///# use rust_bert::Config;
///# use std::path::Path;
///# use tch::kind::Kind::Int64;
///# let config_path = Path::new("path/to/config.json");
///# let device = Device::Cpu;
///# let vs = nn::VarStore::new(device);
///# let config = ElectraConfig::from_file(config_path);
///# let electra_model: ElectraModel = ElectraModel::new(&vs.root(), &config);
/// let (batch_size, sequence_length) = (64, 128);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true);
///
/// let (output, all_hidden_states, all_attentions) = no_grad(|| {
/// electra_model
/// .forward_t(Some(input_tensor),
/// Some(mask),
/// Some(token_type_ids),
/// Some(position_ids),
/// None,
/// false).unwrap()
/// });
/// # use rust_bert::electra::{ElectraModel, ElectraConfig};
/// # use tch::{nn, Device, Tensor, no_grad};
/// # use rust_bert::Config;
/// # use std::path::Path;
/// # use tch::kind::Kind::Int64;
/// # let config_path = Path::new("path/to/config.json");
/// # let device = Device::Cpu;
/// # let vs = nn::VarStore::new(device);
/// # let config = ElectraConfig::from_file(config_path);
/// # let electra_model: ElectraModel = ElectraModel::new(&vs.root(), &config);
/// let (batch_size, sequence_length) = (64, 128);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
/// .expand(&[batch_size, sequence_length], true);
///
/// let (output, all_hidden_states, all_attentions) = no_grad(|| {
/// electra_model
/// .forward_t(
/// Some(input_tensor),
/// Some(mask),
/// Some(token_type_ids),
/// Some(position_ids),
/// None,
/// false,
/// )
/// .unwrap()
/// });
/// ```
///
pub fn forward_t(&self,
input_ids: Option<Tensor>,
mask: Option<Tensor>,
token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool)
-> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
pub fn forward_t(
&self,
input_ids: Option<Tensor>,
mask: Option<Tensor>,
token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool,
) -> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
let (input_shape, device) = match &input_ids {
Some(input_value) => match &input_embeds {
Some(_) => { return Err("Only one of input ids or input embeddings may be set"); }
None => (input_value.size(), input_value.device())
}
Some(_) => {
return Err("Only one of input ids or input embeddings may be set");
}
None => (input_value.size(), input_value.device()),
},
None => match &input_embeds {
Some(embeds) => (vec!(embeds.size()[0], embeds.size()[1]), embeds.device()),
None => { return Err("At least one of input ids or input embeddings must be set"); }
}
Some(embeds) => (vec![embeds.size()[0], embeds.size()[1]], embeds.device()),
None => {
return Err("At least one of input ids or input embeddings must be set");
}
},
};
let mask = match mask {
Some(value) => value,
None => Tensor::ones(&input_shape, (Kind::Int64, device))
None => Tensor::ones(&input_shape, (Kind::Int64, device)),
};
let extended_attention_mask = match mask.dim() {
3 => mask.unsqueeze(1),
2 => mask.unsqueeze(1).unsqueeze(1),
_ => { return Err("Invalid attention mask dimension, must be 2 or 3"); }
_ => {
return Err("Invalid attention mask dimension, must be 2 or 3");
}
};
let hidden_states = match self.embeddings.forward_t(input_ids, token_type_ids, position_ids, input_embeds, train) {
let hidden_states = match self.embeddings.forward_t(
input_ids,
token_type_ids,
position_ids,
input_embeds,
train,
) {
Ok(value) => value,
Err(e) => { return Err(e); }
Err(e) => {
return Err(e);
}
};
let hidden_states = match &self.embeddings_project {
Some(layer) => hidden_states.apply(layer),
None => hidden_states
None => hidden_states,
};
let (hidden_state, all_hidden_states, all_attentions) =
self.encoder.forward_t(&hidden_states,
&Some(extended_attention_mask),
&None,
&None,
train);
let (hidden_state, all_hidden_states, all_attentions) = self.encoder.forward_t(
&hidden_states,
&Some(extended_attention_mask),
&None,
&None,
train,
);
Ok((hidden_state, all_hidden_states, all_attentions))
}
@ -268,9 +312,9 @@ impl ElectraDiscriminatorHead {
///
/// ```no_run
/// use rust_bert::electra::{ElectraConfig, ElectraDiscriminatorHead};
/// use tch::{nn, Device};
/// use rust_bert::Config;
/// use std::path::Path;
/// use tch::{nn, Device};
///
/// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu;
@ -278,16 +322,29 @@ impl ElectraDiscriminatorHead {
/// let config = ElectraConfig::from_file(config_path);
/// let discriminator_head = ElectraDiscriminatorHead::new(&(&p.root() / "electra"), &config);
/// ```
///
pub fn new(p: &nn::Path, config: &ElectraConfig) -> ElectraDiscriminatorHead {
let dense = nn::linear(&(p / "dense"), config.hidden_size, config.hidden_size, Default::default());
let dense_prediction = nn::linear(&(p / "dense_prediction"), config.hidden_size, 1, Default::default());
let dense = nn::linear(
&(p / "dense"),
config.hidden_size,
config.hidden_size,
Default::default(),
);
let dense_prediction = nn::linear(
&(p / "dense_prediction"),
config.hidden_size,
1,
Default::default(),
);
let activation = Box::new(match &config.hidden_act {
Activation::gelu => _gelu,
Activation::relu => _relu,
Activation::mish => _mish
Activation::mish => _mish,
});
ElectraDiscriminatorHead { dense, dense_prediction, activation }
ElectraDiscriminatorHead {
dense,
dense_prediction,
activation,
}
}
/// Forward pass through the discriminator head
@ -305,25 +362,24 @@ impl ElectraDiscriminatorHead {
/// # Example
///
/// ```no_run
///# use rust_bert::electra::{ElectraConfig, ElectraDiscriminatorHead};
///# use tch::{nn, Device, Tensor, no_grad};
///# use rust_bert::Config;
///# use std::path::Path;
///# use tch::kind::Kind::Float;
///# let config_path = Path::new("path/to/config.json");
///# let device = Device::Cpu;
///# let vs = nn::VarStore::new(device);
///# let config = ElectraConfig::from_file(config_path);
///# let discriminator_head = ElectraDiscriminatorHead::new(&vs.root(), &config);
/// let (batch_size, sequence_length) = (64, 128);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length, config.hidden_size], (Float, device));
///
/// let output = no_grad(|| {
/// discriminator_head.forward(&input_tensor)
/// });
/// # use rust_bert::electra::{ElectraConfig, ElectraDiscriminatorHead};
/// # use tch::{nn, Device, Tensor, no_grad};
/// # use rust_bert::Config;
/// # use std::path::Path;
/// # use tch::kind::Kind::Float;
/// # let config_path = Path::new("path/to/config.json");
/// # let device = Device::Cpu;
/// # let vs = nn::VarStore::new(device);
/// # let config = ElectraConfig::from_file(config_path);
/// # let discriminator_head = ElectraDiscriminatorHead::new(&vs.root(), &config);
/// let (batch_size, sequence_length) = (64, 128);
/// let input_tensor = Tensor::rand(
/// &[batch_size, sequence_length, config.hidden_size],
/// (Float, device),
/// );
///
/// let output = no_grad(|| discriminator_head.forward(&input_tensor));
/// ```
///
pub fn forward(&self, encoder_hidden_states: &Tensor) -> Tensor {
let output = encoder_hidden_states.apply(&self.dense);
let output = (self.activation)(&output);
@ -356,9 +412,9 @@ impl ElectraGeneratorHead {
///
/// ```no_run
/// use rust_bert::electra::{ElectraConfig, ElectraGeneratorHead};
/// use tch::{nn, Device};
/// use rust_bert::Config;
/// use std::path::Path;
/// use tch::{nn, Device};
///
/// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu;
@ -366,13 +422,25 @@ impl ElectraGeneratorHead {
/// let config = ElectraConfig::from_file(config_path);
/// let generator_head = ElectraGeneratorHead::new(&(&p.root() / "electra"), &config);
/// ```
///
pub fn new(p: &nn::Path, config: &ElectraConfig) -> ElectraGeneratorHead {
let layer_norm = nn::layer_norm(p / "LayerNorm", vec![config.embedding_size], Default::default());
let dense = nn::linear(&(p / "dense"), config.hidden_size, config.embedding_size, Default::default());
let layer_norm = nn::layer_norm(
p / "LayerNorm",
vec![config.embedding_size],
Default::default(),
);
let dense = nn::linear(
&(p / "dense"),
config.hidden_size,
config.embedding_size,
Default::default(),
);
let activation = Box::new(_gelu);
ElectraGeneratorHead { layer_norm, dense, activation }
ElectraGeneratorHead {
layer_norm,
dense,
activation,
}
}
/// Forward pass through the generator head
@ -388,25 +456,24 @@ impl ElectraGeneratorHead {
/// # Example
///
/// ```no_run
///# use rust_bert::electra::{ElectraConfig, ElectraGeneratorHead};
///# use tch::{nn, Device, Tensor, no_grad};
///# use rust_bert::Config;
///# use std::path::Path;
///# use tch::kind::Kind::Float;
///# let config_path = Path::new("path/to/config.json");
///# let device = Device::Cpu;
///# let vs = nn::VarStore::new(device);
///# let config = ElectraConfig::from_file(config_path);
///# let generator_head = ElectraGeneratorHead::new(&vs.root(), &config);
/// let (batch_size, sequence_length) = (64, 128);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length, config.hidden_size], (Float, device));
///
/// let output = no_grad(|| {
/// generator_head.forward(&input_tensor)
/// });
/// # use rust_bert::electra::{ElectraConfig, ElectraGeneratorHead};
/// # use tch::{nn, Device, Tensor, no_grad};
/// # use rust_bert::Config;
/// # use std::path::Path;
/// # use tch::kind::Kind::Float;
/// # let config_path = Path::new("path/to/config.json");
/// # let device = Device::Cpu;
/// # let vs = nn::VarStore::new(device);
/// # let config = ElectraConfig::from_file(config_path);
/// # let generator_head = ElectraGeneratorHead::new(&vs.root(), &config);
/// let (batch_size, sequence_length) = (64, 128);
/// let input_tensor = Tensor::rand(
/// &[batch_size, sequence_length, config.hidden_size],
/// (Float, device),
/// );
///
/// let output = no_grad(|| generator_head.forward(&input_tensor));
/// ```
///
pub fn forward(&self, encoder_hidden_states: &Tensor) -> Tensor {
let output = encoder_hidden_states.apply(&self.dense);
let output = (self.activation)(&output);
@ -438,10 +505,10 @@ impl ElectraForMaskedLM {
/// # Example
///
/// ```no_run
/// use rust_bert::electra::{ElectraForMaskedLM, ElectraConfig};
/// use tch::{nn, Device};
/// use rust_bert::electra::{ElectraConfig, ElectraForMaskedLM};
/// use rust_bert::Config;
/// use std::path::Path;
/// use tch::{nn, Device};
///
/// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu;
@ -449,13 +516,21 @@ impl ElectraForMaskedLM {
/// let config = ElectraConfig::from_file(config_path);
/// let electra_model: ElectraForMaskedLM = ElectraForMaskedLM::new(&p.root(), &config);
/// ```
///
pub fn new(p: &nn::Path, config: &ElectraConfig) -> ElectraForMaskedLM {
let electra = ElectraModel::new(&(p / "electra"), config);
let generator_head = ElectraGeneratorHead::new(&(p / "generator_predictions"), config);
let lm_head = nn::linear(&(p / "generator_lm_head"), config.embedding_size, config.vocab_size, Default::default());
let lm_head = nn::linear(
&(p / "generator_lm_head"),
config.embedding_size,
config.vocab_size,
Default::default(),
);
ElectraForMaskedLM { electra, generator_head, lm_head }
ElectraForMaskedLM {
electra,
generator_head,
lm_head,
}
}
/// Forward pass through the model
@ -478,46 +553,53 @@ impl ElectraForMaskedLM {
/// # Example
///
/// ```no_run
///# use rust_bert::electra::{ElectraForMaskedLM, ElectraConfig};
///# use tch::{nn, Device, Tensor, no_grad};
///# use rust_bert::Config;
///# use std::path::Path;
///# use tch::kind::Kind::Int64;
///# let config_path = Path::new("path/to/config.json");
///# let device = Device::Cpu;
///# let vs = nn::VarStore::new(device);
///# let config = ElectraConfig::from_file(config_path);
///# let electra_model: ElectraForMaskedLM = ElectraForMaskedLM::new(&vs.root(), &config);
/// let (batch_size, sequence_length) = (64, 128);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true);
///
/// let (output, all_hidden_states, all_attentions) = no_grad(|| {
/// electra_model
/// .forward_t(Some(input_tensor),
/// Some(mask),
/// Some(token_type_ids),
/// Some(position_ids),
/// None,
/// false)
/// });
/// # use rust_bert::electra::{ElectraForMaskedLM, ElectraConfig};
/// # use tch::{nn, Device, Tensor, no_grad};
/// # use rust_bert::Config;
/// # use std::path::Path;
/// # use tch::kind::Kind::Int64;
/// # let config_path = Path::new("path/to/config.json");
/// # let device = Device::Cpu;
/// # let vs = nn::VarStore::new(device);
/// # let config = ElectraConfig::from_file(config_path);
/// # let electra_model: ElectraForMaskedLM = ElectraForMaskedLM::new(&vs.root(), &config);
/// let (batch_size, sequence_length) = (64, 128);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
/// .expand(&[batch_size, sequence_length], true);
///
/// let (output, all_hidden_states, all_attentions) = no_grad(|| {
/// electra_model.forward_t(
/// Some(input_tensor),
/// Some(mask),
/// Some(token_type_ids),
/// Some(position_ids),
/// None,
/// false,
/// )
/// });
/// ```
///
pub fn forward_t(&self,
input_ids: Option<Tensor>,
mask: Option<Tensor>,
token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool)
-> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
let (hidden_states,
all_hidden_states,
all_attentions) = self.electra
.forward_t(input_ids, mask, token_type_ids, position_ids, input_embeds, train)
pub fn forward_t(
&self,
input_ids: Option<Tensor>,
mask: Option<Tensor>,
token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool,
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
let (hidden_states, all_hidden_states, all_attentions) = self
.electra
.forward_t(
input_ids,
mask,
token_type_ids,
position_ids,
input_embeds,
train,
)
.unwrap();
let hidden_states = self.generator_head.forward(&hidden_states);
let hidden_states = hidden_states.apply(&self.lm_head);
@ -547,10 +629,10 @@ impl ElectraDiscriminator {
/// # Example
///
/// ```no_run
/// use rust_bert::electra::{ElectraDiscriminator, ElectraConfig};
/// use tch::{nn, Device};
/// use rust_bert::electra::{ElectraConfig, ElectraDiscriminator};
/// use rust_bert::Config;
/// use std::path::Path;
/// use tch::{nn, Device};
///
/// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu;
@ -558,12 +640,15 @@ impl ElectraDiscriminator {
/// let config = ElectraConfig::from_file(config_path);
/// let electra_model: ElectraDiscriminator = ElectraDiscriminator::new(&p.root(), &config);
/// ```
///
pub fn new(p: &nn::Path, config: &ElectraConfig) -> ElectraDiscriminator {
let electra = ElectraModel::new(&(p / "electra"), config);
let discriminator_head = ElectraDiscriminatorHead::new(&(p / "discriminator_predictions"), config);
let discriminator_head =
ElectraDiscriminatorHead::new(&(p / "discriminator_predictions"), config);
ElectraDiscriminator { electra, discriminator_head }
ElectraDiscriminator {
electra,
discriminator_head,
}
}
/// Forward pass through the model
@ -586,16 +671,16 @@ impl ElectraDiscriminator {
/// # Example
///
/// ```no_run
///# use rust_bert::electra::{ElectraDiscriminator, ElectraConfig};
///# use tch::{nn, Device, Tensor, no_grad};
///# use rust_bert::Config;
///# use std::path::Path;
///# use tch::kind::Kind::Int64;
///# let config_path = Path::new("path/to/config.json");
///# let device = Device::Cpu;
///# let vs = nn::VarStore::new(device);
///# let config = ElectraConfig::from_file(config_path);
///# let electra_model: ElectraDiscriminator = ElectraDiscriminator::new(&vs.root(), &config);
/// # use rust_bert::electra::{ElectraDiscriminator, ElectraConfig};
/// # use tch::{nn, Device, Tensor, no_grad};
/// # use rust_bert::Config;
/// # use std::path::Path;
/// # use tch::kind::Kind::Int64;
/// # let config_path = Path::new("path/to/config.json");
/// # let device = Device::Cpu;
/// # let vs = nn::VarStore::new(device);
/// # let config = ElectraConfig::from_file(config_path);
/// # let electra_model: ElectraDiscriminator = ElectraDiscriminator::new(&vs.root(), &config);
/// let (batch_size, sequence_length) = (64, 128);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
@ -611,21 +696,26 @@ impl ElectraDiscriminator {
/// None,
/// false)
/// });
///
/// ```
///
pub fn forward_t(&self,
input_ids: Option<Tensor>,
mask: Option<Tensor>,
token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool)
-> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
let (hidden_states,
all_hidden_states,
all_attentions) = self.electra
.forward_t(input_ids, mask, token_type_ids, position_ids, input_embeds, train)
pub fn forward_t(
&self,
input_ids: Option<Tensor>,
mask: Option<Tensor>,
token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool,
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
let (hidden_states, all_hidden_states, all_attentions) = self
.electra
.forward_t(
input_ids,
mask,
token_type_ids,
position_ids,
input_embeds,
train,
)
.unwrap();
let probabilities = self.discriminator_head.forward(&hidden_states).sigmoid();
(probabilities, all_hidden_states, all_attentions)
@ -656,24 +746,37 @@ impl ElectraForTokenClassification {
/// # Example
///
/// ```no_run
/// use rust_bert::electra::{ElectraForTokenClassification, ElectraConfig};
/// use tch::{nn, Device};
/// use rust_bert::electra::{ElectraConfig, ElectraForTokenClassification};
/// use rust_bert::Config;
/// use std::path::Path;
/// use tch::{nn, Device};
/// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu;
/// let p = nn::VarStore::new(device);
/// let config = ElectraConfig::from_file(config_path);
/// let electra_model: ElectraForTokenClassification = ElectraForTokenClassification::new(&p.root(), &config);
/// let electra_model: ElectraForTokenClassification =
/// ElectraForTokenClassification::new(&p.root(), &config);
/// ```
///
pub fn new(p: &nn::Path, config: &ElectraConfig) -> ElectraForTokenClassification {
let electra = ElectraModel::new(&(p / "electra"), config);
let dropout = Dropout::new(config.hidden_dropout_prob);
let num_labels = config.id2label.as_ref().expect("id2label must be provided for classifiers").len() as i64;
let classifier = nn::linear(&(p / "classifier"), config.hidden_size, num_labels, Default::default());
let num_labels = config
.id2label
.as_ref()
.expect("id2label must be provided for classifiers")
.len() as i64;
let classifier = nn::linear(
&(p / "classifier"),
config.hidden_size,
num_labels,
Default::default(),
);
ElectraForTokenClassification { electra, dropout, classifier }
ElectraForTokenClassification {
electra,
dropout,
classifier,
}
}
/// Forward pass through the model
@ -696,16 +799,16 @@ impl ElectraForTokenClassification {
/// # Example
///
/// ```no_run
///# use rust_bert::electra::{ElectraForTokenClassification, ElectraConfig};
///# use tch::{nn, Device, Tensor, no_grad};
///# use rust_bert::Config;
///# use std::path::Path;
///# use tch::kind::Kind::Int64;
///# let config_path = Path::new("path/to/config.json");
///# let device = Device::Cpu;
///# let vs = nn::VarStore::new(device);
///# let config = ElectraConfig::from_file(config_path);
///# let electra_model: ElectraForTokenClassification = ElectraForTokenClassification::new(&vs.root(), &config);
/// # use rust_bert::electra::{ElectraForTokenClassification, ElectraConfig};
/// # use tch::{nn, Device, Tensor, no_grad};
/// # use rust_bert::Config;
/// # use std::path::Path;
/// # use tch::kind::Kind::Int64;
/// # let config_path = Path::new("path/to/config.json");
/// # let device = Device::Cpu;
/// # let vs = nn::VarStore::new(device);
/// # let config = ElectraConfig::from_file(config_path);
/// # let electra_model: ElectraForTokenClassification = ElectraForTokenClassification::new(&vs.root(), &config);
/// let (batch_size, sequence_length) = (64, 128);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
@ -721,21 +824,26 @@ impl ElectraForTokenClassification {
/// None,
/// false)
/// });
///
/// ```
///
pub fn forward_t(&self,
input_ids: Option<Tensor>,
mask: Option<Tensor>,
token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool)
-> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
let (hidden_states,
all_hidden_states,
all_attentions) = self.electra
.forward_t(input_ids, mask, token_type_ids, position_ids, input_embeds, train)
pub fn forward_t(
&self,
input_ids: Option<Tensor>,
mask: Option<Tensor>,
token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool,
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
let (hidden_states, all_hidden_states, all_attentions) = self
.electra
.forward_t(
input_ids,
mask,
token_type_ids,
position_ids,
input_embeds,
train,
)
.unwrap();
let output = hidden_states
.apply_t(&self.dropout, train)

View File

@ -12,10 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use tch::{nn, Tensor, Kind};
use crate::common::dropout::Dropout;
use crate::electra::electra::ElectraConfig;
use tch::nn::{EmbeddingConfig, embedding};
use tch::nn::{embedding, EmbeddingConfig};
use tch::{nn, Kind, Tensor};
#[derive(Debug)]
/// # Embeddings implementation for Electra model
@ -34,49 +34,77 @@ impl ElectraEmbeddings {
..Default::default()
};
let word_embeddings: nn::Embedding = embedding(p / "word_embeddings",
config.vocab_size,
config.embedding_size,
embedding_config);
let word_embeddings: nn::Embedding = embedding(
p / "word_embeddings",
config.vocab_size,
config.embedding_size,
embedding_config,
);
let position_embeddings: nn::Embedding = embedding(p / "position_embeddings",
config.max_position_embeddings,
config.embedding_size,
Default::default());
let position_embeddings: nn::Embedding = embedding(
p / "position_embeddings",
config.max_position_embeddings,
config.embedding_size,
Default::default(),
);
let token_type_embeddings: nn::Embedding = embedding(p / "token_type_embeddings",
config.type_vocab_size,
config.embedding_size,
Default::default());
let token_type_embeddings: nn::Embedding = embedding(
p / "token_type_embeddings",
config.type_vocab_size,
config.embedding_size,
Default::default(),
);
let layer_norm_eps = match config.layer_norm_eps {
Some(value) => value,
None => 1e-12
None => 1e-12,
};
let layer_norm_config = nn::LayerNormConfig { eps: layer_norm_eps, ..Default::default() };
let layer_norm: nn::LayerNorm = nn::layer_norm(p / "LayerNorm", vec![config.embedding_size], layer_norm_config);
let layer_norm_config = nn::LayerNormConfig {
eps: layer_norm_eps,
..Default::default()
};
let layer_norm: nn::LayerNorm = nn::layer_norm(
p / "LayerNorm",
vec![config.embedding_size],
layer_norm_config,
);
let dropout: Dropout = Dropout::new(config.hidden_dropout_prob);
ElectraEmbeddings { word_embeddings, position_embeddings, token_type_embeddings, layer_norm, dropout}
ElectraEmbeddings {
word_embeddings,
position_embeddings,
token_type_embeddings,
layer_norm,
dropout,
}
}
pub fn forward_t(&self,
input_ids: Option<Tensor>,
token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool) -> Result<Tensor, &'static str> {
pub fn forward_t(
&self,
input_ids: Option<Tensor>,
token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool,
) -> Result<Tensor, &'static str> {
let (input_embeddings, input_shape) = match input_ids {
Some(input_value) => match input_embeds {
Some(_) => { return Err("Only one of input ids or input embeddings may be set"); }
None => (input_value.apply_t(&self.word_embeddings, train), input_value.size())
}
Some(_) => {
return Err("Only one of input ids or input embeddings may be set");
}
None => (
input_value.apply_t(&self.word_embeddings, train),
input_value.size(),
),
},
None => match input_embeds {
Some(embeds) => {
let size = vec!(embeds.size()[0], embeds.size()[1]);
let size = vec![embeds.size()[0], embeds.size()[1]];
(embeds, size)
},
None => { return Err("Only one of input ids or input embeddings may be set"); }
}
}
None => {
return Err("Only one of input ids or input embeddings may be set");
}
},
};
let seq_length = input_embeddings.as_ref().size()[1].to_owned();
@ -84,19 +112,22 @@ impl ElectraEmbeddings {
let position_ids = match position_ids {
Some(value) => value,
None => Tensor::arange(seq_length, (Kind::Int64, input_embeddings.device()))
.unsqueeze(0).
expand(&input_shape, true)
.unsqueeze(0)
.expand(&input_shape, true),
};
let token_type_ids = match token_type_ids {
Some(value) => value,
None => Tensor::zeros(&input_shape, (Kind::Int64, input_embeddings.device()))
None => Tensor::zeros(&input_shape, (Kind::Int64, input_embeddings.device())),
};
let position_embeddings = position_ids.apply(&self.position_embeddings);
let token_type_embeddings = token_type_ids.apply(&self.token_type_embeddings);
let input_embeddings: Tensor = input_embeddings + position_embeddings + token_type_embeddings;
Ok(input_embeddings.apply(&self.layer_norm).apply_t(&self.dropout, train))
let input_embeddings: Tensor =
input_embeddings + position_embeddings + token_type_embeddings;
Ok(input_embeddings
.apply(&self.layer_norm)
.apply_t(&self.dropout, train))
}
}

View File

@ -23,18 +23,24 @@
//! Pretrained models are available and can be downloaded using RemoteResources.
//!
//! ```no_run
//!# fn main() -> failure::Fallible<()> {
//!#
//! # fn main() -> failure::Fallible<()> {
//! #
//! use rust_tokenizers::BertTokenizer;
//! use tch::{nn, Device};
//!# use std::path::PathBuf;
//! use rust_bert::electra::{ElectraForMaskedLM, ElectraConfig};
//! # use std::path::PathBuf;
//! use rust_bert::electra::{ElectraConfig, ElectraForMaskedLM};
//! use rust_bert::resources::{download_resource, LocalResource, Resource};
//! use rust_bert::Config;
//! use rust_bert::resources::{Resource, download_resource, LocalResource};
//!
//! let config_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/config.json")});
//! let vocab_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/vocab.txt")});
//! let weights_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/model.ot")});
//! let config_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/config.json"),
//! });
//! let vocab_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/vocab.txt"),
//! });
//! let weights_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/model.ot"),
//! });
//! let config_path = download_resource(&config_resource)?;
//! let vocab_path = download_resource(&vocab_resource)?;
//! let weights_path = download_resource(&weights_resource)?;
@ -45,13 +51,15 @@
//! let electra_model = ElectraForMaskedLM::new(&vs.root(), &config);
//! vs.load(weights_path)?;
//!
//!# Ok(())
//!# }
//! # Ok(())
//! # }
//! ```
mod embeddings;
mod electra;
mod embeddings;
pub use electra::{ElectraModelResources, ElectraVocabResources, ElectraConfigResources, ElectraConfig,
ElectraModel, ElectraDiscriminator, ElectraForMaskedLM, ElectraDiscriminatorHead, ElectraGeneratorHead, ElectraForTokenClassification};
pub use electra::{
ElectraConfig, ElectraConfigResources, ElectraDiscriminator, ElectraDiscriminatorHead,
ElectraForMaskedLM, ElectraForTokenClassification, ElectraGeneratorHead, ElectraModel,
ElectraModelResources, ElectraVocabResources,
};

View File

@ -12,11 +12,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use tch::{Tensor, nn};
use crate::common::dropout::Dropout;
use crate::gpt2::gpt2::Gpt2Config;
use tch::kind::Kind::Float;
use tch::nn::{Init, Module};
use tch::{nn, Tensor};
#[derive(Debug)]
pub struct GPTConv1D {
@ -26,7 +26,14 @@ pub struct GPTConv1D {
impl GPTConv1D {
pub fn new(p: &nn::Path, nf: i64, nx: i64) -> GPTConv1D {
let weight = p.var("weight", &[nx, nf], Init::Randn { mean: 0., stdev: 0.02 });
let weight = p.var(
"weight",
&[nx, nf],
Init::Randn {
mean: 0.,
stdev: 0.02,
},
);
let bias = p.var("bias", &[nf], Init::Const(0.));
GPTConv1D { weight, bias }
}
@ -38,7 +45,6 @@ impl Module for GPTConv1D {
}
}
pub struct Attention {
bias: Tensor,
c_attn: GPTConv1D,
@ -62,23 +68,27 @@ impl Attention {
let attn_pdrop = match config.attn_pdrop {
Some(value) => value,
None => 0.1
None => 0.1,
};
let resid_pdrop = match config.resid_pdrop {
Some(value) => value,
None => 0.1
None => 0.1,
};
let output_attentions = match config.output_attentions {
Some(value) => value,
None => false
None => false,
};
let attn_dropout = Dropout::new(attn_pdrop);
let resid_dropout = Dropout::new(resid_pdrop);
assert_eq!(config.n_embd % config.n_head, 0, "Attention hidden states not a multiple of the number of heads");
assert_eq!(
config.n_embd % config.n_head,
0,
"Attention hidden states not a multiple of the number of heads"
);
let dim_per_head = config.n_embd / config.n_head;
Attention {
@ -105,19 +115,31 @@ impl Attention {
}
fn flatten(&self, x: Tensor) -> Tensor {
x.transpose(1, 2).contiguous().view((x.size()[0], -1, &self.n_head * self.dim_per_head))
x.transpose(1, 2)
.contiguous()
.view((x.size()[0], -1, &self.n_head * self.dim_per_head))
}
fn attention(&self, q: &Tensor, k: &Tensor, v: &Tensor, attention_mask: &Option<Tensor>, train: bool)
-> (Tensor, Option<Tensor>) {
fn attention(
&self,
q: &Tensor,
k: &Tensor,
v: &Tensor,
attention_mask: &Option<Tensor>,
train: bool,
) -> (Tensor, Option<Tensor>) {
let mut w = q.matmul(&k);
if self.scale { w = w / (*v.size().last().unwrap() as f64).sqrt(); }
if self.scale {
w = w / (*v.size().last().unwrap() as f64).sqrt();
}
let (nd, ns) = (w.size()[2], w.size()[3]);
let b = self.bias.narrow(2, ns - nd, nd).narrow(3, 0, ns);
let mut w: Tensor = w * &b + 1e4 * (&b - 1);
if let Some(mask) = attention_mask { w = w + mask; }
if let Some(mask) = attention_mask {
w = w + mask;
}
w = w.softmax(-1, Float).apply_t(&self.attn_dropout, train);
let output = w.matmul(&v);
@ -128,29 +150,36 @@ impl Attention {
}
}
pub fn forward_t(&self, x: &Tensor, layer_past: &Option<Tensor>, attention_mask: &Option<Tensor>, train: bool)
-> (Tensor, Tensor, Option<Tensor>) {
pub fn forward_t(
&self,
x: &Tensor,
layer_past: &Option<Tensor>,
attention_mask: &Option<Tensor>,
train: bool,
) -> (Tensor, Tensor, Option<Tensor>) {
let x = x.apply(&self.c_attn).split(self.n_state, 2);
let (query, key, value) =
(
self.split_heads(&x[0], false),
self.split_heads(&x[1], true),
self.split_heads(&x[2], false)
);
let (query, key, value) = (
self.split_heads(&x[0], false),
self.split_heads(&x[1], true),
self.split_heads(&x[2], false),
);
let (key, value) = match layer_past {
Some(past) => {
let key = Tensor::cat(&[past.get(0).transpose(-2, -1), key], -1);
let value = Tensor::cat(&[past.get(1), value], -2);
(key, value)
}
None => (key, value)
None => (key, value),
};
let present = Tensor::stack(&[key.transpose(-2, -1), value.copy()], 0);
let (a, attentions) = self.attention(&query, &key, &value, &attention_mask, train);
let a = self.flatten(a).apply(&self.c_proj).apply_t(&self.resid_dropout, train);
let a = self
.flatten(a)
.apply(&self.c_proj)
.apply_t(&self.resid_dropout, train);
(a, present, attentions)
}
}
}

View File

@ -12,16 +12,16 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use serde::{Deserialize, Serialize};
use tch::{nn, Tensor};
use crate::common::dropout::Dropout;
use tch::nn::embedding;
use crate::common::linear::{linear_no_bias, LinearNoBias};
use crate::gpt2::transformer::Block;
use tch::kind::Kind::Int64;
use std::borrow::BorrowMut;
use crate::common::linear::{LinearNoBias, linear_no_bias};
use crate::pipelines::generation::{Cache, LMHeadModel};
use crate::Config;
use crate::pipelines::generation::{LMHeadModel, Cache};
use serde::{Deserialize, Serialize};
use std::borrow::BorrowMut;
use tch::kind::Kind::Int64;
use tch::nn::embedding;
use tch::{nn, Tensor};
/// # GPT2 Pretrained model weight files
pub struct Gpt2ModelResources;
@ -37,54 +37,114 @@ pub struct Gpt2MergesResources;
impl Gpt2ModelResources {
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
pub const GPT2: (&'static str, &'static str) = ("gpt2/model.ot", "https://cdn.huggingface.co/gpt2-rust_model.ot");
pub const GPT2: (&'static str, &'static str) = (
"gpt2/model.ot",
"https://cdn.huggingface.co/gpt2-rust_model.ot",
);
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
pub const GPT2_MEDIUM: (&'static str, &'static str) = ("gpt2-medium/model.ot", "https://cdn.huggingface.co/gpt2-medium-rust_model.ot");
pub const GPT2_MEDIUM: (&'static str, &'static str) = (
"gpt2-medium/model.ot",
"https://cdn.huggingface.co/gpt2-medium-rust_model.ot",
);
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
pub const GPT2_LARGE: (&'static str, &'static str) = ("gpt2-large/model.ot", "https://cdn.huggingface.co/gpt2-large-rust_model.ot");
pub const GPT2_LARGE: (&'static str, &'static str) = (
"gpt2-large/model.ot",
"https://cdn.huggingface.co/gpt2-large-rust_model.ot",
);
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
pub const GPT2_XL: (&'static str, &'static str) = ("gpt2-xl/model.ot", "https://cdn.huggingface.co/gpt2-xl-rust_model.ot");
pub const GPT2_XL: (&'static str, &'static str) = (
"gpt2-xl/model.ot",
"https://cdn.huggingface.co/gpt2-xl-rust_model.ot",
);
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
pub const DISTIL_GPT2: (&'static str, &'static str) = ("distilgpt2/model.ot", "https://cdn.huggingface.co/distilgpt2-rust_model.ot");
pub const DISTIL_GPT2: (&'static str, &'static str) = (
"distilgpt2/model.ot",
"https://cdn.huggingface.co/distilgpt2-rust_model.ot",
);
}
impl Gpt2ConfigResources {
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
pub const GPT2: (&'static str, &'static str) = ("gpt2/config.json", "https://cdn.huggingface.co/gpt2-config.json");
pub const GPT2: (&'static str, &'static str) = (
"gpt2/config.json",
"https://cdn.huggingface.co/gpt2-config.json",
);
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
pub const GPT2_MEDIUM: (&'static str, &'static str) = ("gpt2-medium/config.json", "https://cdn.huggingface.co/gpt2-medium-config.json");
pub const GPT2_MEDIUM: (&'static str, &'static str) = (
"gpt2-medium/config.json",
"https://cdn.huggingface.co/gpt2-medium-config.json",
);
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
pub const GPT2_LARGE: (&'static str, &'static str) = ("gpt2-large/config.json", "https://cdn.huggingface.co/gpt2-large-config.json");
pub const GPT2_LARGE: (&'static str, &'static str) = (
"gpt2-large/config.json",
"https://cdn.huggingface.co/gpt2-large-config.json",
);
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
pub const GPT2_XL: (&'static str, &'static str) = ("gpt2-xl/config.json", "https://cdn.huggingface.co/gpt2-xl-config.json");
pub const GPT2_XL: (&'static str, &'static str) = (
"gpt2-xl/config.json",
"https://cdn.huggingface.co/gpt2-xl-config.json",
);
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
pub const DISTIL_GPT2: (&'static str, &'static str) = ("distilgpt2/config.json", "https://cdn.huggingface.co/distilgpt2-config.json");
pub const DISTIL_GPT2: (&'static str, &'static str) = (
"distilgpt2/config.json",
"https://cdn.huggingface.co/distilgpt2-config.json",
);
}
impl Gpt2VocabResources {
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
pub const GPT2: (&'static str, &'static str) = ("gpt2/vocab.txt", "https://cdn.huggingface.co/gpt2-vocab.json");
pub const GPT2: (&'static str, &'static str) = (
"gpt2/vocab.txt",
"https://cdn.huggingface.co/gpt2-vocab.json",
);
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
pub const GPT2_MEDIUM: (&'static str, &'static str) = ("gpt2-medium/vocab.txt", "https://cdn.huggingface.co/gpt2-medium-vocab.json");
pub const GPT2_MEDIUM: (&'static str, &'static str) = (
"gpt2-medium/vocab.txt",
"https://cdn.huggingface.co/gpt2-medium-vocab.json",
);
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
pub const GPT2_LARGE: (&'static str, &'static str) = ("gpt2-large/vocab.txt", "https://cdn.huggingface.co/gpt2-large-vocab.json");
pub const GPT2_LARGE: (&'static str, &'static str) = (
"gpt2-large/vocab.txt",
"https://cdn.huggingface.co/gpt2-large-vocab.json",
);
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
pub const GPT2_XL: (&'static str, &'static str) = ("gpt2-xl/vocab.txt", "https://cdn.huggingface.co/gpt2-xl-vocab.json");
pub const GPT2_XL: (&'static str, &'static str) = (
"gpt2-xl/vocab.txt",
"https://cdn.huggingface.co/gpt2-xl-vocab.json",
);
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
pub const DISTIL_GPT2: (&'static str, &'static str) = ("distilgpt2/vocab.txt", "https://cdn.huggingface.co/distilgpt2-vocab.json");
pub const DISTIL_GPT2: (&'static str, &'static str) = (
"distilgpt2/vocab.txt",
"https://cdn.huggingface.co/distilgpt2-vocab.json",
);
}
impl Gpt2MergesResources {
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
pub const GPT2: (&'static str, &'static str) = ("gpt2/merges.txt", "https://cdn.huggingface.co/gpt2-merges.txt");
pub const GPT2: (&'static str, &'static str) = (
"gpt2/merges.txt",
"https://cdn.huggingface.co/gpt2-merges.txt",
);
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
pub const GPT2_MEDIUM: (&'static str, &'static str) = ("gpt2-medium/merges.txt", "https://cdn.huggingface.co/gpt2-medium-merges.txt");
pub const GPT2_MEDIUM: (&'static str, &'static str) = (
"gpt2-medium/merges.txt",
"https://cdn.huggingface.co/gpt2-medium-merges.txt",
);
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
pub const GPT2_LARGE: (&'static str, &'static str) = ("gpt2-large/merges.txt", "https://cdn.huggingface.co/gpt2-large-merges.txt");
pub const GPT2_LARGE: (&'static str, &'static str) = (
"gpt2-large/merges.txt",
"https://cdn.huggingface.co/gpt2-large-merges.txt",
);
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
pub const GPT2_XL: (&'static str, &'static str) = ("gpt2-xl/merges.txt", "https://cdn.huggingface.co/gpt2-xl-merges.txt");
pub const GPT2_XL: (&'static str, &'static str) = (
"gpt2-xl/merges.txt",
"https://cdn.huggingface.co/gpt2-xl-merges.txt",
);
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
pub const DISTIL_GPT2: (&'static str, &'static str) = ("distilgpt2/merges.txt", "https://cdn.huggingface.co/distilgpt2-merges.txt");
pub const DISTIL_GPT2: (&'static str, &'static str) = (
"distilgpt2/merges.txt",
"https://cdn.huggingface.co/distilgpt2-merges.txt",
);
}
#[allow(non_camel_case_types)]
@ -156,10 +216,10 @@ impl Gpt2Model {
/// # Example
///
/// ```no_run
/// use tch::{nn, Device};
/// use rust_bert::gpt2::{Gpt2Config, Gpt2Model};
/// use rust_bert::Config;
/// use std::path::Path;
/// use rust_bert::gpt2::{Gpt2Config, Gpt2Model};
/// use tch::{nn, Device};
///
/// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu;
@ -167,37 +227,58 @@ impl Gpt2Model {
/// let config = Gpt2Config::from_file(config_path);
/// let gpt2: Gpt2Model = Gpt2Model::new(&(&p.root() / "gpt2"), &config);
/// ```
///
pub fn new(p: &nn::Path, config: &Gpt2Config) -> Gpt2Model {
let p = &(p / "transformer");
let wte = embedding(&(p / "wte"), config.vocab_size, config.n_embd, Default::default());
let wpe = embedding(&(p / "wpe"), config.n_positions, config.n_embd, Default::default());
let wte = embedding(
&(p / "wte"),
config.vocab_size,
config.n_embd,
Default::default(),
);
let wpe = embedding(
&(p / "wpe"),
config.n_positions,
config.n_embd,
Default::default(),
);
let embd_pdrop = match config.embd_pdrop {
Some(value) => value,
None => 0.1
None => 0.1,
};
let drop = Dropout::new(embd_pdrop);
let layer_norm_config = nn::LayerNormConfig { eps: config.layer_norm_epsilon, ..Default::default() };
let layer_norm_config = nn::LayerNormConfig {
eps: config.layer_norm_epsilon,
..Default::default()
};
let ln_f = nn::layer_norm(p / "ln_f", vec![config.n_embd], layer_norm_config);
let mut h: Vec<Block> = vec!();
let mut h: Vec<Block> = vec![];
let h_path = &(p / "h");
for layer_index in 0..config.n_layer {
h.push(Block::new(&(h_path / layer_index), config, true));
};
}
let output_attentions = match config.output_attentions {
Some(value) => value,
None => false
None => false,
};
let output_past = match config.output_past {
Some(value) => value,
None => true
None => true,
};
let output_hidden_states = match config.output_hidden_states {
Some(value) => value,
None => false
None => false,
};
Gpt2Model { wte, wpe, drop, ln_f, h, output_past, output_hidden_states, output_attentions }
Gpt2Model {
wte,
wpe,
drop,
ln_f,
h,
output_past,
output_hidden_states,
output_attentions,
}
}
/// Forward pass through the model
@ -222,63 +303,101 @@ impl Gpt2Model {
/// # Example
///
/// ```no_run
///# use tch::{nn, Device, Tensor, no_grad};
///# use rust_bert::Config;
///# use std::path::Path;
///# use tch::kind::Kind::{Int64, Double};
/// use rust_bert::gpt2::{Gpt2Model, Gpt2Config};
///# let config_path = Path::new("path/to/config.json");
///# let vocab_path = Path::new("path/to/vocab.txt");
///# let device = Device::Cpu;
///# let vs = nn::VarStore::new(device);
///# let config = Gpt2Config::from_file(config_path);
///# let gpt2_model: Gpt2Model = Gpt2Model::new(&vs.root(), &config);
/// let (batch_size, sequence_length, past_sequence_length) = (64, 128, 56);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let mut past: Vec<Tensor> = Vec::with_capacity(config.n_layer as usize);
/// for _ in 0..config.n_layer as usize {
/// past.push(Tensor::rand(&[2, batch_size, config.n_head, past_sequence_length, config.n_embd / config.n_head], (Double, device)))
/// # use tch::{nn, Device, Tensor, no_grad};
/// # use rust_bert::Config;
/// # use std::path::Path;
/// # use tch::kind::Kind::{Int64, Double};
/// use rust_bert::gpt2::{Gpt2Config, Gpt2Model};
/// # let config_path = Path::new("path/to/config.json");
/// # let vocab_path = Path::new("path/to/vocab.txt");
/// # let device = Device::Cpu;
/// # let vs = nn::VarStore::new(device);
/// # let config = Gpt2Config::from_file(config_path);
/// # let gpt2_model: Gpt2Model = Gpt2Model::new(&vs.root(), &config);
/// let (batch_size, sequence_length, past_sequence_length) = (64, 128, 56);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let mut past: Vec<Tensor> = Vec::with_capacity(config.n_layer as usize);
/// for _ in 0..config.n_layer as usize {
/// past.push(Tensor::rand(
/// &[
/// 2,
/// batch_size,
/// config.n_head,
/// past_sequence_length,
/// config.n_embd / config.n_head,
/// ],
/// (Double, device),
/// ))
/// }
/// let attention_mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let token_type_ids = Tensor::ones(&[batch_size, sequence_length], (Int64, device));
/// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true);
///
/// let (output, past, hidden_states, attentions) = no_grad(|| {
/// gpt2_model
/// .forward_t(&Some(input_tensor),
/// &Some(past),
/// &Some(attention_mask),
/// &Some(token_type_ids),
/// &Some(position_ids),
/// &None,
/// false).unwrap()
/// });
/// let attention_mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let token_type_ids = Tensor::ones(&[batch_size, sequence_length], (Int64, device));
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
/// .expand(&[batch_size, sequence_length], true);
///
/// let (output, past, hidden_states, attentions) = no_grad(|| {
/// gpt2_model
/// .forward_t(
/// &Some(input_tensor),
/// &Some(past),
/// &Some(attention_mask),
/// &Some(token_type_ids),
/// &Some(position_ids),
/// &None,
/// false,
/// )
/// .unwrap()
/// });
/// ```
///
pub fn forward_t(&self,
input_ids: &Option<Tensor>,
layer_past: &Option<Vec<Tensor>>,
attention_mask: &Option<Tensor>,
token_type_ids: &Option<Tensor>,
position_ids: &Option<Tensor>,
input_embeds: &Option<Tensor>,
train: bool) -> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
pub fn forward_t(
&self,
input_ids: &Option<Tensor>,
layer_past: &Option<Vec<Tensor>>,
attention_mask: &Option<Tensor>,
token_type_ids: &Option<Tensor>,
position_ids: &Option<Tensor>,
input_embeds: &Option<Tensor>,
train: bool,
) -> Result<
(
Tensor,
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
),
&'static str,
> {
let (input_embeddings, seq_length) = match input_ids {
Some(input_value) => match input_embeds {
Some(_) => { return Err("Only one of input ids or input embeddings may be set"); }
None => (input_value.apply(&self.wte), *input_value.size().last().unwrap())
}
Some(_) => {
return Err("Only one of input ids or input embeddings may be set");
}
None => (
input_value.apply(&self.wte),
*input_value.size().last().unwrap(),
),
},
None => match input_embeds {
Some(embeds) => (embeds.copy(), embeds.size()[1]),
None => { return Err("At least one of input ids or input embeddings must be set"); }
}
None => {
return Err("At least one of input ids or input embeddings must be set");
}
},
};
let (layer_past, layer_past_length) = match layer_past {
Some(value) => {
assert_eq!(value.len(), self.h.len(), "Past activations vector must be of length equal to the number of layers");
(value.iter().map(|v| Some(v.copy())).collect::<Vec<Option<Tensor>>>(), value[0].size()[3])
assert_eq!(
value.len(),
self.h.len(),
"Past activations vector must be of length equal to the number of layers"
);
(
value
.iter()
.map(|v| Some(v.copy()))
.collect::<Vec<Option<Tensor>>>(),
value[0].size()[3],
)
}
None => {
let mut out = Vec::with_capacity(self.h.len());
@ -289,31 +408,45 @@ impl Gpt2Model {
let position_ids = match position_ids {
Some(value) => value.copy(),
None => Tensor::arange1(layer_past_length, seq_length + layer_past_length, (Int64, input_embeddings.device())).unsqueeze(0)
None => Tensor::arange1(
layer_past_length,
seq_length + layer_past_length,
(Int64, input_embeddings.device()),
)
.unsqueeze(0),
};
let attention_mask: Option<Tensor> = match attention_mask {
Some(value) => {
Some(
(value
.view((input_embeddings.size()[0], -1))
.unsqueeze(1)
.unsqueeze(2)
- 1.0
) * 10000.0)
}
None => None
Some(value) => Some(
(value
.view((input_embeddings.size()[0], -1))
.unsqueeze(1)
.unsqueeze(2)
- 1.0)
* 10000.0,
),
None => None,
};
let position_embeds = position_ids.apply(&self.wpe);
let token_type_embeds = match token_type_ids {
Some(value) => value.apply(&self.wte),
None => Tensor::zeros_like(&position_embeds)
None => Tensor::zeros_like(&position_embeds),
};
let mut hidden_state: Tensor =
(input_embeddings + position_embeds + token_type_embeds).apply_t(&self.drop, train);
let mut all_presents: Option<Vec<Tensor>> =
if self.output_past { Some(vec![]) } else { None };
let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states {
Some(vec![])
} else {
None
};
let mut all_attentions: Option<Vec<Tensor>> = if self.output_attentions {
Some(vec![])
} else {
None
};
let mut hidden_state: Tensor = (input_embeddings + position_embeds + token_type_embeds).apply_t(&self.drop, train);
let mut all_presents: Option<Vec<Tensor>> = if self.output_past { Some(vec!()) } else { None };
let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states { Some(vec!()) } else { None };
let mut all_attentions: Option<Vec<Tensor>> = if self.output_attentions { Some(vec!()) } else { None };
let mut layer_iter = self.h.iter().zip(layer_past);
loop {
@ -333,11 +466,16 @@ impl Gpt2Model {
attentions.push(temp.2.as_ref().unwrap().copy());
};
}
None => break
None => break,
};
};
}
Ok((hidden_state.apply(&self.ln_f), all_presents, all_hidden_states, all_attentions))
Ok((
hidden_state.apply(&self.ln_f),
all_presents,
all_hidden_states,
all_attentions,
))
}
}
@ -362,10 +500,10 @@ impl GPT2LMHeadModel {
/// # Example
///
/// ```no_run
/// use tch::{nn, Device};
/// use rust_bert::gpt2::{GPT2LMHeadModel, Gpt2Config};
/// use rust_bert::Config;
/// use std::path::Path;
/// use rust_bert::gpt2::{Gpt2Config, GPT2LMHeadModel};
/// use tch::{nn, Device};
///
/// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu;
@ -373,11 +511,18 @@ impl GPT2LMHeadModel {
/// let config = Gpt2Config::from_file(config_path);
/// let gpt2: GPT2LMHeadModel = GPT2LMHeadModel::new(&(&p.root() / "gpt2"), &config);
/// ```
///
pub fn new(p: &nn::Path, config: &Gpt2Config) -> GPT2LMHeadModel {
let transformer = Gpt2Model::new(&p, config);
let lm_head = linear_no_bias(&(p / "lm_head"), config.n_embd, config.vocab_size, Default::default());
GPT2LMHeadModel { transformer, lm_head }
let lm_head = linear_no_bias(
&(p / "lm_head"),
config.n_embd,
config.vocab_size,
Default::default(),
);
GPT2LMHeadModel {
transformer,
lm_head,
}
}
}
@ -408,75 +553,104 @@ impl LMHeadModel for GPT2LMHeadModel {
/// # Example
///
/// ```no_run
///# use tch::{nn, Device, Tensor, no_grad};
///# use rust_bert::Config;
///# use std::path::Path;
///# use tch::kind::Kind::{Int64, Double};
/// use rust_bert::gpt2::{Gpt2Config, GPT2LMHeadModel};
/// use rust_bert::pipelines::generation::{LMHeadModel, Cache};
///# let config_path = Path::new("path/to/config.json");
///# let vocab_path = Path::new("path/to/vocab.txt");
///# let device = Device::Cpu;
///# let vs = nn::VarStore::new(device);
///# let config = Gpt2Config::from_file(config_path);
///# let mut gpt2_model: GPT2LMHeadModel = GPT2LMHeadModel::new(&vs.root(), &config);
/// let (batch_size, sequence_length, past_sequence_length) = (64, 128, 56);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let mut past: Vec<Tensor> = Vec::with_capacity(config.n_layer as usize);
/// for _ in 0..config.n_layer as usize {
/// past.push(Tensor::rand(&[2, batch_size, config.n_head, past_sequence_length, config.n_embd / config.n_head], (Double, device)))
/// # use tch::{nn, Device, Tensor, no_grad};
/// # use rust_bert::Config;
/// # use std::path::Path;
/// # use tch::kind::Kind::{Int64, Double};
/// use rust_bert::gpt2::{GPT2LMHeadModel, Gpt2Config};
/// use rust_bert::pipelines::generation::{Cache, LMHeadModel};
/// # let config_path = Path::new("path/to/config.json");
/// # let vocab_path = Path::new("path/to/vocab.txt");
/// # let device = Device::Cpu;
/// # let vs = nn::VarStore::new(device);
/// # let config = Gpt2Config::from_file(config_path);
/// # let mut gpt2_model: GPT2LMHeadModel = GPT2LMHeadModel::new(&vs.root(), &config);
/// let (batch_size, sequence_length, past_sequence_length) = (64, 128, 56);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let mut past: Vec<Tensor> = Vec::with_capacity(config.n_layer as usize);
/// for _ in 0..config.n_layer as usize {
/// past.push(Tensor::rand(
/// &[
/// 2,
/// batch_size,
/// config.n_head,
/// past_sequence_length,
/// config.n_embd / config.n_head,
/// ],
/// (Double, device),
/// ))
/// }
/// let attention_mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let token_type_ids = Tensor::ones(&[batch_size, sequence_length], (Int64, device));
/// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true);
///
/// let (output, _, past, hidden_states, attentions) = no_grad(|| {
/// gpt2_model
/// .forward_t(&Some(input_tensor),
/// Cache::GPT2Cache(Some(past)),
/// &Some(attention_mask),
/// &Some(token_type_ids),
/// &Some(position_ids),
/// &None,
/// None,
/// &None,
/// false).unwrap()
/// });
/// let attention_mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let token_type_ids = Tensor::ones(&[batch_size, sequence_length], (Int64, device));
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
/// .expand(&[batch_size, sequence_length], true);
///
/// let (output, _, past, hidden_states, attentions) = no_grad(|| {
/// gpt2_model
/// .forward_t(
/// &Some(input_tensor),
/// Cache::GPT2Cache(Some(past)),
/// &Some(attention_mask),
/// &Some(token_type_ids),
/// &Some(position_ids),
/// &None,
/// None,
/// &None,
/// false,
/// )
/// .unwrap()
/// });
/// ```
///
fn forward_t(&self,
input_ids: &Option<Tensor>,
layer_past: Cache,
attention_mask: &Option<Tensor>,
token_type_ids: &Option<Tensor>,
position_ids: &Option<Tensor>,
input_embeds: &Option<Tensor>,
_encoder_outputs: Option<&Tensor>,
_decoder_input_ids: &Option<Tensor>,
train: bool) -> Result<(Tensor, Option<Tensor>, Cache, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
let (output,
past,
all_hidden_states,
all_attentions) = match layer_past {
Cache::GPT2Cache(layer_past) => Ok(self.transformer.forward_t(input_ids,
&layer_past,
attention_mask,
token_type_ids,
position_ids,
input_embeds,
train)?),
Cache::None => Ok(self.transformer.forward_t(input_ids,
&None,
attention_mask,
token_type_ids,
position_ids,
input_embeds,
train)?),
_ => Err("Cache not compatible with GPT2 model")
fn forward_t(
&self,
input_ids: &Option<Tensor>,
layer_past: Cache,
attention_mask: &Option<Tensor>,
token_type_ids: &Option<Tensor>,
position_ids: &Option<Tensor>,
input_embeds: &Option<Tensor>,
_encoder_outputs: Option<&Tensor>,
_decoder_input_ids: &Option<Tensor>,
train: bool,
) -> Result<
(
Tensor,
Option<Tensor>,
Cache,
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
),
&'static str,
> {
let (output, past, all_hidden_states, all_attentions) = match layer_past {
Cache::GPT2Cache(layer_past) => Ok(self.transformer.forward_t(
input_ids,
&layer_past,
attention_mask,
token_type_ids,
position_ids,
input_embeds,
train,
)?),
Cache::None => Ok(self.transformer.forward_t(
input_ids,
&None,
attention_mask,
token_type_ids,
position_ids,
input_embeds,
train,
)?),
_ => Err("Cache not compatible with GPT2 model"),
}?;
let lm_logits = output.apply(&self.lm_head);
Ok((lm_logits, None, Cache::GPT2Cache(past), all_hidden_states, all_attentions))
Ok((
lm_logits,
None,
Cache::GPT2Cache(past),
all_hidden_states,
all_attentions,
))
}
}

View File

@ -14,19 +14,27 @@
//! Pretrained models are available and can be downloaded using RemoteResources.
//!
//! ```no_run
//!# fn main() -> failure::Fallible<()> {
//!#
//! # fn main() -> failure::Fallible<()> {
//! #
//! use rust_tokenizers::Gpt2Tokenizer;
//! use tch::{nn, Device};
//!# use std::path::PathBuf;
//! # use std::path::PathBuf;
//! use rust_bert::gpt2::{GPT2LMHeadModel, Gpt2Config};
//! use rust_bert::resources::{download_resource, LocalResource, Resource};
//! use rust_bert::Config;
//! use rust_bert::gpt2::{Gpt2Config, GPT2LMHeadModel};
//! use rust_bert::resources::{Resource, download_resource, LocalResource};
//!
//! let config_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/config.json")});
//! let vocab_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/vocab.txt")});
//! let merges_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/vocab.txt")});
//! let weights_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/model.ot")});
//! let config_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/config.json"),
//! });
//! let vocab_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/vocab.txt"),
//! });
//! let merges_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/vocab.txt"),
//! });
//! let weights_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/model.ot"),
//! });
//! let config_path = download_resource(&config_resource)?;
//! let vocab_path = download_resource(&vocab_resource)?;
//! let merges_path = download_resource(&merges_resource)?;
@ -34,18 +42,24 @@
//!
//! let device = Device::cuda_if_available();
//! let mut vs = nn::VarStore::new(device);
//! let tokenizer: Gpt2Tokenizer = Gpt2Tokenizer::from_file(vocab_path.to_str().unwrap(), merges_path.to_str().unwrap(), true);
//! let tokenizer: Gpt2Tokenizer = Gpt2Tokenizer::from_file(
//! vocab_path.to_str().unwrap(),
//! merges_path.to_str().unwrap(),
//! true,
//! );
//! let config = Gpt2Config::from_file(config_path);
//! let gpt2_model = GPT2LMHeadModel::new(&vs.root(), &config);
//! vs.load(weights_path)?;
//!
//!# Ok(())
//!# }
//! # Ok(())
//! # }
//! ```
mod gpt2;
pub(crate) mod attention;
mod gpt2;
pub(crate) mod transformer;
pub use gpt2::{Gpt2ModelResources, Gpt2ConfigResources, Gpt2VocabResources, Gpt2MergesResources,
Gpt2Config, Gpt2Model, GptActivation, GPT2LMHeadModel};
pub use gpt2::{
GPT2LMHeadModel, Gpt2Config, Gpt2ConfigResources, Gpt2MergesResources, Gpt2Model,
Gpt2ModelResources, Gpt2VocabResources, GptActivation,
};

View File

@ -12,11 +12,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use crate::gpt2::attention::{GPTConv1D, Attention};
use tch::{Tensor, nn};
use crate::common::dropout::Dropout;
use crate::gpt2::gpt2::{Gpt2Config, GptActivation};
use crate::common::activations::{_gelu_new, _relu, _swish};
use crate::common::dropout::Dropout;
use crate::gpt2::attention::{Attention, GPTConv1D};
use crate::gpt2::gpt2::{Gpt2Config, GptActivation};
use tch::{nn, Tensor};
pub struct MLP {
c_fc: GPTConv1D,
@ -35,14 +35,19 @@ impl MLP {
GptActivation::relu => _relu,
GptActivation::swish => _swish,
},
None => _gelu_new
None => _gelu_new,
});
let resid_pdrop = match config.resid_pdrop {
Some(value) => value,
None => 0.1
None => 0.1,
};
let dropout = Dropout::new(resid_pdrop);
MLP { c_fc, c_proj, activation, dropout }
MLP {
c_fc,
c_proj,
activation,
dropout,
}
}
pub fn forward_t(&self, x: &Tensor, train: bool) -> Tensor {
@ -60,21 +65,36 @@ pub struct Block {
impl Block {
pub fn new(p: &nn::Path, config: &Gpt2Config, scale: bool) -> Block {
let layer_norm_config = nn::LayerNormConfig { eps: config.layer_norm_epsilon, ..Default::default() };
let layer_norm_config = nn::LayerNormConfig {
eps: config.layer_norm_epsilon,
..Default::default()
};
let ln_1 = nn::layer_norm(p / "ln_1", vec![config.n_embd], layer_norm_config);
let ln_2 = nn::layer_norm(p / "ln_2", vec![config.n_embd], layer_norm_config);
let attn = Attention::new(&(p / "attn"), config, scale);
let mlp = MLP::new(&(p / "mlp"), config);
Block { ln_1, attn, ln_2, mlp }
Block {
ln_1,
attn,
ln_2,
mlp,
}
}
pub fn forward_t(&self, x: &Tensor, layer_past: &Option<Tensor>, attention_mask: &Option<Tensor>, train: bool)
-> (Tensor, Tensor, Option<Tensor>) {
let (output, present, attentions) = self.attn.forward_t(&x.apply(&self.ln_1), layer_past, attention_mask, train);
pub fn forward_t(
&self,
x: &Tensor,
layer_past: &Option<Tensor>,
attention_mask: &Option<Tensor>,
train: bool,
) -> (Tensor, Tensor, Option<Tensor>) {
let (output, present, attentions) =
self.attn
.forward_t(&x.apply(&self.ln_1), layer_past, attention_mask, train);
let x = x + output;
let m = self.mlp.forward_t(&x.apply(&self.ln_2), train);
let x = x + m;
(x, present, attentions)
}
}
}

View File

@ -16,14 +16,14 @@
//!
//! More information on these can be found in the [`pipelines` module](./pipelines/index.html)
//! ```no_run
//! use rust_bert::pipelines::question_answering::{QuestionAnsweringModel, QaInput};
//! use rust_bert::pipelines::question_answering::{QaInput, QuestionAnsweringModel};
//!
//!# fn main() -> failure::Fallible<()> {
//! # fn main() -> failure::Fallible<()> {
//! let qa_model = QuestionAnsweringModel::new(Default::default())?;
//!
//! let question = String::from("Where does Amy live ?");
//! let context = String::from("Amy lives in Amsterdam");
//! let answers = qa_model.predict(&vec!(QaInput { question, context }), 1, 32);
//! let answers = qa_model.predict(&vec![QaInput { question, context }], 1, 32);
//! # Ok(())
//! # }
//! ```
@ -55,19 +55,18 @@
//! - Set-up a virtual environment and install dependencies
//! - run the conversion script python /utils/download-dependencies_{MODEL_TO_DOWNLOAD}.py. The dependencies will be downloaded to the user's home directory, under ~/rustbert/{}
//! 3. Run the example cargo run --release
//!
pub mod distilbert;
pub mod bert;
pub mod roberta;
pub mod openai_gpt;
pub mod gpt2;
pub mod bart;
pub mod electra;
pub mod marian;
pub mod albert;
pub mod bart;
pub mod bert;
mod common;
pub mod distilbert;
pub mod electra;
pub mod gpt2;
pub mod marian;
pub mod openai_gpt;
pub mod pipelines;
pub mod roberta;
pub use common::Config;
pub use common::resources;
pub use common::Config;

View File

@ -11,10 +11,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use crate::bart::{BartModel, BartConfig, LayerState};
use tch::{Tensor, nn};
use crate::pipelines::generation::{LMHeadModel, Cache};
use crate::bart::{BartConfig, BartModel, LayerState};
use crate::pipelines::generation::{Cache, LMHeadModel};
use tch::nn::Init;
use tch::{nn, Tensor};
/// # Marian Pretrained model weight files
pub struct MarianModelResources;
@ -33,78 +33,174 @@ pub struct MarianPrefix;
impl MarianModelResources {
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format.
pub const ENGLISH2ROMANCE: (&'static str, &'static str) = ("marian-mt-en-ROMANCE/model.ot", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-ROMANCE/rust_model.ot");
pub const ENGLISH2ROMANCE: (&'static str, &'static str) = (
"marian-mt-en-ROMANCE/model.ot",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-ROMANCE/rust_model.ot",
);
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format.
pub const ROMANCE2ENGLISH: (&'static str, &'static str) = ("marian-mt-ROMANCE-en/model.ot", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-ROMANCE-en/rust_model.ot");
pub const ROMANCE2ENGLISH: (&'static str, &'static str) = (
"marian-mt-ROMANCE-en/model.ot",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-ROMANCE-en/rust_model.ot",
);
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format.
pub const ENGLISH2GERMAN: (&'static str, &'static str) = ("marian-mt-en-de/model.ot", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-de/rust_model.ot");
pub const ENGLISH2GERMAN: (&'static str, &'static str) = (
"marian-mt-en-de/model.ot",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-de/rust_model.ot",
);
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format.
pub const GERMAN2ENGLISH: (&'static str, &'static str) = ("marian-mt-de-en/model.ot", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-de-en/rust_model.ot");
pub const GERMAN2ENGLISH: (&'static str, &'static str) = (
"marian-mt-de-en/model.ot",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-de-en/rust_model.ot",
);
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format.
pub const ENGLISH2RUSSIAN: (&'static str, &'static str) = ("marian-mt-en-ru/model.ot", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-ru/rust_model.ot");
pub const ENGLISH2RUSSIAN: (&'static str, &'static str) = (
"marian-mt-en-ru/model.ot",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-ru/rust_model.ot",
);
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format.
pub const RUSSIAN2ENGLISH: (&'static str, &'static str) = ("marian-mt-ru-en/model.ot", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-ru-en/rust_model.ot");
pub const RUSSIAN2ENGLISH: (&'static str, &'static str) = (
"marian-mt-ru-en/model.ot",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-ru-en/rust_model.ot",
);
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format.
pub const FRENCH2GERMAN: (&'static str, &'static str) = ("marian-mt-fr-de/model.ot", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-fr-de/rust_model.ot");
pub const FRENCH2GERMAN: (&'static str, &'static str) = (
"marian-mt-fr-de/model.ot",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-fr-de/rust_model.ot",
);
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format.
pub const GERMAN2FRENCH: (&'static str, &'static str) = ("marian-mt-de-fr/model.ot", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-de-fr/rust_model.ot");
pub const GERMAN2FRENCH: (&'static str, &'static str) = (
"marian-mt-de-fr/model.ot",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-de-fr/rust_model.ot",
);
}
impl MarianConfigResources {
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const ENGLISH2ROMANCE: (&'static str, &'static str) = ("marian-mt-en-ROMANCE/config.json", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-ROMANCE/config.json");
pub const ENGLISH2ROMANCE: (&'static str, &'static str) = (
"marian-mt-en-ROMANCE/config.json",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-ROMANCE/config.json",
);
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const ROMANCE2ENGLISH: (&'static str, &'static str) = ("marian-mt-ROMANCE-en/config.json", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-ROMANCE-en/config.json");
pub const ROMANCE2ENGLISH: (&'static str, &'static str) = (
"marian-mt-ROMANCE-en/config.json",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-ROMANCE-en/config.json",
);
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const ENGLISH2GERMAN: (&'static str, &'static str) = ("marian-mt-en-de/config.json", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-de/config.json");
pub const ENGLISH2GERMAN: (&'static str, &'static str) = (
"marian-mt-en-de/config.json",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-de/config.json",
);
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const GERMAN2ENGLISH: (&'static str, &'static str) = ("marian-mt-de-en/config.json", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-de-en/config.json");
pub const GERMAN2ENGLISH: (&'static str, &'static str) = (
"marian-mt-de-en/config.json",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-de-en/config.json",
);
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const ENGLISH2RUSSIAN: (&'static str, &'static str) = ("marian-mt-en-ru/config.json", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-ru/config.json");
pub const ENGLISH2RUSSIAN: (&'static str, &'static str) = (
"marian-mt-en-ru/config.json",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-ru/config.json",
);
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const RUSSIAN2ENGLISH: (&'static str, &'static str) = ("marian-mt-ru-en/config.json", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-ru-en/config.json");
pub const RUSSIAN2ENGLISH: (&'static str, &'static str) = (
"marian-mt-ru-en/config.json",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-ru-en/config.json",
);
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const FRENCH2GERMAN: (&'static str, &'static str) = ("marian-mt-fr-de/config.json", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-fr-de/config.json");
pub const FRENCH2GERMAN: (&'static str, &'static str) = (
"marian-mt-fr-de/config.json",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-fr-de/config.json",
);
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const GERMAN2FRENCH: (&'static str, &'static str) = ("marian-mt-de-fr/config.json", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-de-fr/config.json");
pub const GERMAN2FRENCH: (&'static str, &'static str) = (
"marian-mt-de-fr/config.json",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-de-fr/config.json",
);
}
impl MarianVocabResources {
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const ENGLISH2ROMANCE: (&'static str, &'static str) = ("marian-mt-en-ROMANCE/vocab.json", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-ROMANCE/vocab.json");
pub const ENGLISH2ROMANCE: (&'static str, &'static str) = (
"marian-mt-en-ROMANCE/vocab.json",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-ROMANCE/vocab.json",
);
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const ROMANCE2ENGLISH: (&'static str, &'static str) = ("marian-mt-ROMANCE-en/vocab.json", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-ROMANCE-en/vocab.json");
pub const ROMANCE2ENGLISH: (&'static str, &'static str) = (
"marian-mt-ROMANCE-en/vocab.json",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-ROMANCE-en/vocab.json",
);
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const ENGLISH2GERMAN: (&'static str, &'static str) = ("marian-mt-en-de/vocab.json", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-de/vocab.json");
pub const ENGLISH2GERMAN: (&'static str, &'static str) = (
"marian-mt-en-de/vocab.json",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-de/vocab.json",
);
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const GERMAN2ENGLISH: (&'static str, &'static str) = ("marian-mt-de-en/vocab.json", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-de-en/vocab.json");
pub const GERMAN2ENGLISH: (&'static str, &'static str) = (
"marian-mt-de-en/vocab.json",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-de-en/vocab.json",
);
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const ENGLISH2RUSSIAN: (&'static str, &'static str) = ("marian-mt-en-ru/vocab.json", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-ru/vocab.json");
pub const ENGLISH2RUSSIAN: (&'static str, &'static str) = (
"marian-mt-en-ru/vocab.json",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-ru/vocab.json",
);
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const RUSSIAN2ENGLISH: (&'static str, &'static str) = ("marian-mt-ru-en/vocab.json", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-ru-en/vocab.json");
pub const RUSSIAN2ENGLISH: (&'static str, &'static str) = (
"marian-mt-ru-en/vocab.json",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-ru-en/vocab.json",
);
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const FRENCH2GERMAN: (&'static str, &'static str) = ("marian-mt-fr-de/vocab.json", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-fr-de/vocab.json");
pub const FRENCH2GERMAN: (&'static str, &'static str) = (
"marian-mt-fr-de/vocab.json",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-fr-de/vocab.json",
);
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const GERMAN2FRENCH: (&'static str, &'static str) = ("marian-mt-de-fr/vocab.json", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-de-fr/vocab.json");
pub const GERMAN2FRENCH: (&'static str, &'static str) = (
"marian-mt-de-fr/vocab.json",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-de-fr/vocab.json",
);
}
impl MarianSpmResources {
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const ENGLISH2ROMANCE: (&'static str, &'static str) = ("marian-mt-en-ROMANCE/spiece.model", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-ROMANCE/source.spm");
pub const ENGLISH2ROMANCE: (&'static str, &'static str) = (
"marian-mt-en-ROMANCE/spiece.model",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-ROMANCE/source.spm",
);
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const ROMANCE2ENGLISH: (&'static str, &'static str) = ("marian-mt-ROMANCE-en/spiece.model", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-ROMANCE-en/source.spm");
pub const ROMANCE2ENGLISH: (&'static str, &'static str) = (
"marian-mt-ROMANCE-en/spiece.model",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-ROMANCE-en/source.spm",
);
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const ENGLISH2GERMAN: (&'static str, &'static str) = ("marian-mt-en-de/spiece.model", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-de/source.spm");
pub const ENGLISH2GERMAN: (&'static str, &'static str) = (
"marian-mt-en-de/spiece.model",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-de/source.spm",
);
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const GERMAN2ENGLISH: (&'static str, &'static str) = ("marian-mt-de-en/spiece.model", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-de-en/source.spm");
pub const GERMAN2ENGLISH: (&'static str, &'static str) = (
"marian-mt-de-en/spiece.model",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-de-en/source.spm",
);
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const ENGLISH2RUSSIAN: (&'static str, &'static str) = ("marian-mt-en-ru/spiece.model", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-ru/source.spm");
pub const ENGLISH2RUSSIAN: (&'static str, &'static str) = (
"marian-mt-en-ru/spiece.model",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-ru/source.spm",
);
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const RUSSIAN2ENGLISH: (&'static str, &'static str) = ("marian-mt-ru-en/spiece.model", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-ru-en/source.spm");
pub const RUSSIAN2ENGLISH: (&'static str, &'static str) = (
"marian-mt-ru-en/spiece.model",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-ru-en/source.spm",
);
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const FRENCH2GERMAN: (&'static str, &'static str) = ("marian-mt-fr-de/spiece.model", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-fr-de/source.spm");
pub const FRENCH2GERMAN: (&'static str, &'static str) = (
"marian-mt-fr-de/spiece.model",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-fr-de/source.spm",
);
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const GERMAN2FRENCH: (&'static str, &'static str) = ("marian-mt-de-fr/spiece.model", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-de-fr/source.spm");
pub const GERMAN2FRENCH: (&'static str, &'static str) = (
"marian-mt-de-fr/spiece.model",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-de-fr/source.spm",
);
}
impl MarianPrefix {
@ -150,23 +246,34 @@ impl MarianForConditionalGeneration {
/// # Example
///
/// ```no_run
/// use tch::{nn, Device};
/// use rust_bert::bart::{BartConfig, BartForConditionalGeneration};
/// use rust_bert::Config;
/// use std::path::Path;
/// use rust_bert::bart::{BartConfig, BartForConditionalGeneration};
/// use tch::{nn, Device};
///
/// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu;
/// let p = nn::VarStore::new(device);
/// let config = BartConfig::from_file(config_path);
/// let generation_mode = true;
/// let bart: BartForConditionalGeneration = BartForConditionalGeneration::new(&(&p.root() / "bart"), &config, generation_mode);
/// let bart: BartForConditionalGeneration =
/// BartForConditionalGeneration::new(&(&p.root() / "bart"), &config, generation_mode);
/// ```
///
pub fn new(p: &nn::Path, config: &BartConfig, generation_mode: bool) -> MarianForConditionalGeneration {
pub fn new(
p: &nn::Path,
config: &BartConfig,
generation_mode: bool,
) -> MarianForConditionalGeneration {
let base_model = BartModel::new(&(p / "model"), config, generation_mode);
let final_logits_bias = p.var("final_logits_bias", &[1, config.vocab_size], Init::Const(0.));
MarianForConditionalGeneration { base_model, final_logits_bias }
let final_logits_bias = p.var(
"final_logits_bias",
&[1, config.vocab_size],
Init::Const(0.),
);
MarianForConditionalGeneration {
base_model,
final_logits_bias,
}
}
/// Forward pass through the model
@ -193,64 +300,101 @@ impl MarianForConditionalGeneration {
/// # Example
///
/// ```no_run
///# use tch::{nn, Device, Tensor, no_grad};
///# use rust_bert::Config;
///# use std::path::Path;
///# use tch::kind::Kind::{Int64, Double};
/// use rust_bert::bart::{BartConfig};
/// # use tch::{nn, Device, Tensor, no_grad};
/// # use rust_bert::Config;
/// # use std::path::Path;
/// # use tch::kind::Kind::{Int64, Double};
/// use rust_bert::bart::BartConfig;
/// use rust_bert::marian::MarianForConditionalGeneration;
///# let config_path = Path::new("path/to/config.json");
///# let vocab_path = Path::new("path/to/vocab.txt");
///# let device = Device::Cpu;
///# let vs = nn::VarStore::new(device);
///# let config = BartConfig::from_file(config_path);
///# let mut marian_model = MarianForConditionalGeneration::new(&vs.root(), &config, false);
/// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56);
/// let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device));
/// let target_tensor = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device));
/// let encoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
/// let decoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
///
/// let (decoder_output, encoder_hidden_states, cache,
/// all_encoder_hidden_states, all_encoder_attentions,
/// all_decoder_hidden_states, all_decoder_attentions) = no_grad(|| {
/// marian_model
/// .forward_t(Some(&input_tensor),
/// Some(&encoder_attention_mask),
/// None,
/// Some(&target_tensor),
/// Some(&decoder_attention_mask),
/// None,
/// false)
/// });
/// # let config_path = Path::new("path/to/config.json");
/// # let vocab_path = Path::new("path/to/vocab.txt");
/// # let device = Device::Cpu;
/// # let vs = nn::VarStore::new(device);
/// # let config = BartConfig::from_file(config_path);
/// # let mut marian_model = MarianForConditionalGeneration::new(&vs.root(), &config, false);
/// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56);
/// let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device));
/// let target_tensor = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device));
/// let encoder_attention_mask =
/// Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
/// let decoder_attention_mask =
/// Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
///
/// let (
/// decoder_output,
/// encoder_hidden_states,
/// cache,
/// all_encoder_hidden_states,
/// all_encoder_attentions,
/// all_decoder_hidden_states,
/// all_decoder_attentions,
/// ) = no_grad(|| {
/// marian_model.forward_t(
/// Some(&input_tensor),
/// Some(&encoder_attention_mask),
/// None,
/// Some(&target_tensor),
/// Some(&decoder_attention_mask),
/// None,
/// false,
/// )
/// });
/// ```
///
pub fn forward_t(&self,
input_ids: Option<&Tensor>,
attention_mask: Option<&Tensor>,
encoder_outputs: Option<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>)>,
decoder_input_ids: Option<&Tensor>,
decoder_attention_mask: Option<&Tensor>,
old_layer_states: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
train: bool)
-> (Tensor, Tensor, Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
Option<Vec<Tensor>>, Option<Vec<Tensor>>,
Option<Vec<Tensor>>, Option<Vec<Tensor>>)
{
let (decoder_outputs, encoder_hidden_states, decoder_cache,
all_decoder_hidden_states, all_decoder_attentions,
all_encoder_hidden_states, all_encoder_attentions) =
self.base_model.forward_t(input_ids, attention_mask, decoder_input_ids, encoder_outputs, decoder_attention_mask, old_layer_states, train);
pub fn forward_t(
&self,
input_ids: Option<&Tensor>,
attention_mask: Option<&Tensor>,
encoder_outputs: Option<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>)>,
decoder_input_ids: Option<&Tensor>,
decoder_attention_mask: Option<&Tensor>,
old_layer_states: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
train: bool,
) -> (
Tensor,
Tensor,
Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
) {
let (
decoder_outputs,
encoder_hidden_states,
decoder_cache,
all_decoder_hidden_states,
all_decoder_attentions,
all_encoder_hidden_states,
all_encoder_attentions,
) = self.base_model.forward_t(
input_ids,
attention_mask,
decoder_input_ids,
encoder_outputs,
decoder_attention_mask,
old_layer_states,
train,
);
let lm_logits = decoder_outputs.linear::<Tensor>(&self.base_model.embeddings.ws, None);
(lm_logits, encoder_hidden_states, decoder_cache,
all_decoder_hidden_states, all_decoder_attentions,
all_encoder_hidden_states, all_encoder_attentions)
(
lm_logits,
encoder_hidden_states,
decoder_cache,
all_decoder_hidden_states,
all_decoder_attentions,
all_encoder_hidden_states,
all_encoder_attentions,
)
}
pub fn encode(&self, input_ids: &Tensor, attention_mask: Option<&Tensor>) -> Tensor {
let (encoder_hidden_states, _, _) = self.base_model.encoder.forward_t(input_ids, attention_mask, &self.base_model.embeddings, false);
let (encoder_hidden_states, _, _) = self.base_model.encoder.forward_t(
input_ids,
attention_mask,
&self.base_model.embeddings,
false,
);
encoder_hidden_states
}
}
@ -283,68 +427,97 @@ impl LMHeadModel for MarianForConditionalGeneration {
/// # Example
///
/// ```no_run
///# use tch::{nn, Device, Tensor, no_grad};
///# use rust_bert::Config;
///# use std::path::Path;
///# use tch::kind::Kind::{Int64, Double};
/// use rust_bert::bart::{BartConfig};
/// # use tch::{nn, Device, Tensor, no_grad};
/// # use rust_bert::Config;
/// # use std::path::Path;
/// # use tch::kind::Kind::{Int64, Double};
/// use rust_bert::bart::BartConfig;
/// use rust_bert::marian::MarianForConditionalGeneration;
///# let config_path = Path::new("path/to/config.json");
///# let vocab_path = Path::new("path/to/vocab.txt");
///# let device = Device::Cpu;
///# let vs = nn::VarStore::new(device);
///# let config = BartConfig::from_file(config_path);
///# let marian_model = MarianForConditionalGeneration::new(&vs.root(), &config, false);
/// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56);
/// let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device));
/// let target_tensor = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device));
/// let encoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
/// let decoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
///
/// let (decoder_output, encoder_hidden_states, cache,
/// all_encoder_hidden_states, all_encoder_attentions,
/// all_decoder_hidden_states, all_decoder_attentions) = no_grad(|| {
/// marian_model
/// .forward_t(Some(&input_tensor),
/// Some(&encoder_attention_mask),
/// None,
/// Some(&target_tensor),
/// Some(&decoder_attention_mask),
/// None,
/// false)
/// });
/// # let config_path = Path::new("path/to/config.json");
/// # let vocab_path = Path::new("path/to/vocab.txt");
/// # let device = Device::Cpu;
/// # let vs = nn::VarStore::new(device);
/// # let config = BartConfig::from_file(config_path);
/// # let marian_model = MarianForConditionalGeneration::new(&vs.root(), &config, false);
/// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56);
/// let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device));
/// let target_tensor = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device));
/// let encoder_attention_mask =
/// Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
/// let decoder_attention_mask =
/// Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
///
/// let (
/// decoder_output,
/// encoder_hidden_states,
/// cache,
/// all_encoder_hidden_states,
/// all_encoder_attentions,
/// all_decoder_hidden_states,
/// all_decoder_attentions,
/// ) = no_grad(|| {
/// marian_model.forward_t(
/// Some(&input_tensor),
/// Some(&encoder_attention_mask),
/// None,
/// Some(&target_tensor),
/// Some(&decoder_attention_mask),
/// None,
/// false,
/// )
/// });
/// ```
///
fn forward_t(&self,
input_ids: &Option<Tensor>,
cache: Cache,
attention_mask: &Option<Tensor>,
_token_type_ids: &Option<Tensor>,
_position_ids: &Option<Tensor>,
_input_embeds: &Option<Tensor>,
encoder_outputs: Option<&Tensor>,
decoder_input_ids: &Option<Tensor>,
train: bool) -> Result<(Tensor, Option<Tensor>, Cache, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
fn forward_t(
&self,
input_ids: &Option<Tensor>,
cache: Cache,
attention_mask: &Option<Tensor>,
_token_type_ids: &Option<Tensor>,
_position_ids: &Option<Tensor>,
_input_embeds: &Option<Tensor>,
encoder_outputs: Option<&Tensor>,
decoder_input_ids: &Option<Tensor>,
train: bool,
) -> Result<
(
Tensor,
Option<Tensor>,
Cache,
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
),
&'static str,
> {
let (decoder_output, encoder_hidden_states, new_cache, _, _, _, _) = match cache {
Cache::BARTCache(cached_layer_states) => self.base_model.forward_t(input_ids.as_ref(),
attention_mask.as_ref(),
decoder_input_ids.as_ref(),
Some((encoder_outputs.as_ref().unwrap().copy(), None, None)),
None,
cached_layer_states,
train),
Cache::None => self.base_model.forward_t(input_ids.as_ref(),
attention_mask.as_ref(),
decoder_input_ids.as_ref(),
Some((encoder_outputs.as_ref().unwrap().copy(), None, None)),
None,
None,
train),
_ => Err("Cache not compatible with Marian Model")?
Cache::BARTCache(cached_layer_states) => self.base_model.forward_t(
input_ids.as_ref(),
attention_mask.as_ref(),
decoder_input_ids.as_ref(),
Some((encoder_outputs.as_ref().unwrap().copy(), None, None)),
None,
cached_layer_states,
train,
),
Cache::None => self.base_model.forward_t(
input_ids.as_ref(),
attention_mask.as_ref(),
decoder_input_ids.as_ref(),
Some((encoder_outputs.as_ref().unwrap().copy(), None, None)),
None,
None,
train,
),
_ => Err("Cache not compatible with Marian Model")?,
};
let lm_logits = decoder_output.linear::<Tensor>(&self.base_model.embeddings.ws, None) + &self.final_logits_bias;
Ok((lm_logits, Some(encoder_hidden_states), Cache::BARTCache(new_cache), None, None))
let lm_logits = decoder_output.linear::<Tensor>(&self.base_model.embeddings.ws, None)
+ &self.final_logits_bias;
Ok((
lm_logits,
Some(encoder_hidden_states),
Cache::BARTCache(new_cache),
None,
None,
))
}
}
}

View File

@ -15,20 +15,28 @@
//! Pretrained models for a number of language pairs are available and can be downloaded using RemoteResources. These are shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
//!
//! ```no_run
//!# fn main() -> failure::Fallible<()> {
//!#
//! # fn main() -> failure::Fallible<()> {
//! #
//! use tch::{nn, Device};
//!# use std::path::PathBuf;
//! use rust_bert::Config;
//! # use std::path::PathBuf;
//! use rust_bert::bart::{BartConfig, BartModel};
//! use rust_bert::resources::{Resource, download_resource, LocalResource};
//! use rust_tokenizers::preprocessing::tokenizer::marian_tokenizer::MarianTokenizer;
//! use rust_bert::marian::MarianForConditionalGeneration;
//! use rust_bert::resources::{download_resource, LocalResource, Resource};
//! use rust_bert::Config;
//! use rust_tokenizers::preprocessing::tokenizer::marian_tokenizer::MarianTokenizer;
//!
//! let config_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/config.json")});
//! let vocab_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/vocab.json")});
//! let sentence_piece_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/spiece.model")});
//! let weights_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/model.ot")});
//! let config_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/config.json"),
//! });
//! let vocab_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/vocab.json"),
//! });
//! let sentence_piece_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/spiece.model"),
//! });
//! let weights_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/model.ot"),
//! });
//! let config_path = download_resource(&config_resource)?;
//! let vocab_path = download_resource(&vocab_resource)?;
//! let spiece_path = download_resource(&sentence_piece_resource)?;
@ -36,15 +44,22 @@
//!
//! let device = Device::cuda_if_available();
//! let mut vs = nn::VarStore::new(device);
//! let tokenizer = MarianTokenizer::from_files(vocab_path.to_str().unwrap(), spiece_path.to_str().unwrap(), true);
//! let tokenizer = MarianTokenizer::from_files(
//! vocab_path.to_str().unwrap(),
//! spiece_path.to_str().unwrap(),
//! true,
//! );
//! let config = BartConfig::from_file(config_path);
//! let marian_model = MarianForConditionalGeneration::new(&vs.root(), &config, false);
//! vs.load(weights_path)?;
//!
//!# Ok(())
//!# }
//! # Ok(())
//! # }
//! ```
mod marian;
pub use marian::{MarianForConditionalGeneration, MarianModelResources, MarianConfigResources, MarianVocabResources, MarianSpmResources, MarianPrefix};
pub use marian::{
MarianConfigResources, MarianForConditionalGeneration, MarianModelResources, MarianPrefix,
MarianSpmResources, MarianVocabResources,
};

View File

@ -14,19 +14,27 @@
//! Pretrained models are available and can be downloaded using RemoteResources.
//!
//! ```no_run
//!# fn main() -> failure::Fallible<()> {
//! # fn main() -> failure::Fallible<()> {
//! use rust_tokenizers::OpenAiGptTokenizer;
//! use tch::{nn, Device};
//!# use std::path::PathBuf;
//! use rust_bert::Config;
//! # use std::path::PathBuf;
//! use rust_bert::gpt2::Gpt2Config;
//! use rust_bert::openai_gpt::OpenAiGptModel;
//! use rust_bert::resources::{Resource, download_resource, LocalResource};
//! use rust_bert::resources::{download_resource, LocalResource, Resource};
//! use rust_bert::Config;
//!
//! let config_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/config.json")});
//! let vocab_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/vocab.txt")});
//! let merges_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/vocab.txt")});
//! let weights_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/model.ot")});
//! let config_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/config.json"),
//! });
//! let vocab_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/vocab.txt"),
//! });
//! let merges_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/vocab.txt"),
//! });
//! let weights_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/model.ot"),
//! });
//! let config_path = download_resource(&config_resource)?;
//! let vocab_path = download_resource(&vocab_resource)?;
//! let merges_path = download_resource(&merges_resource)?;
@ -34,17 +42,23 @@
//!
//! let device = Device::cuda_if_available();
//! let mut vs = nn::VarStore::new(device);
//! let tokenizer: OpenAiGptTokenizer = OpenAiGptTokenizer::from_file(vocab_path.to_str().unwrap(), merges_path.to_str().unwrap(), true);
//! let tokenizer: OpenAiGptTokenizer = OpenAiGptTokenizer::from_file(
//! vocab_path.to_str().unwrap(),
//! merges_path.to_str().unwrap(),
//! true,
//! );
//! let config = Gpt2Config::from_file(config_path);
//! let gpt_model = OpenAiGptModel::new(&vs.root(), &config);
//! vs.load(weights_path)?;
//!
//!# Ok(())
//!# }
//! # Ok(())
//! # }
//! ```
mod openai_gpt;
mod transformer;
pub use openai_gpt::{OpenAiGptModelResources, OpenAiGptConfigResources, OpenAiGptVocabResources, OpenAiGptMergesResources,
OpenAiGptModel, OpenAIGPTLMHeadModel};
pub use openai_gpt::{
OpenAIGPTLMHeadModel, OpenAiGptConfigResources, OpenAiGptMergesResources, OpenAiGptModel,
OpenAiGptModelResources, OpenAiGptVocabResources,
};

View File

@ -12,15 +12,15 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use tch::{nn, Tensor};
use crate::common::dropout::Dropout;
use tch::nn::embedding;
use tch::kind::Kind::Int64;
use std::borrow::BorrowMut;
use crate::common::linear::{LinearNoBias, linear_no_bias};
use crate::openai_gpt::transformer::Block;
use crate::common::linear::{linear_no_bias, LinearNoBias};
use crate::gpt2::Gpt2Config;
use crate::pipelines::generation::{LMHeadModel, Cache};
use crate::openai_gpt::transformer::Block;
use crate::pipelines::generation::{Cache, LMHeadModel};
use std::borrow::BorrowMut;
use tch::kind::Kind::Int64;
use tch::nn::embedding;
use tch::{nn, Tensor};
/// # GPT Pretrained model weight files
pub struct OpenAiGptModelResources;
@ -36,22 +36,34 @@ pub struct OpenAiGptMergesResources;
impl OpenAiGptModelResources {
/// Shared under MIT license by the OpenAI team at https://github.com/openai/finetune-transformer-lm. Modified with conversion to C-array format.
pub const GPT: (&'static str, &'static str) = ("openai-gpt/model.ot", "https://cdn.huggingface.co/openai-gpt-rust_model.ot");
pub const GPT: (&'static str, &'static str) = (
"openai-gpt/model.ot",
"https://cdn.huggingface.co/openai-gpt-rust_model.ot",
);
}
impl OpenAiGptConfigResources {
/// Shared under MIT license by the OpenAI team at https://github.com/openai/finetune-transformer-lm. Modified with conversion to C-array format.
pub const GPT: (&'static str, &'static str) = ("openai-gpt/config.json", "https://cdn.huggingface.co/openai-gpt-config.json");
pub const GPT: (&'static str, &'static str) = (
"openai-gpt/config.json",
"https://cdn.huggingface.co/openai-gpt-config.json",
);
}
impl OpenAiGptVocabResources {
/// Shared under MIT license by the OpenAI team at https://github.com/openai/finetune-transformer-lm. Modified with conversion to C-array format.
pub const GPT: (&'static str, &'static str) = ("openai-gpt/vocab.txt", "https://cdn.huggingface.co/openai-gpt-vocab.json");
pub const GPT: (&'static str, &'static str) = (
"openai-gpt/vocab.txt",
"https://cdn.huggingface.co/openai-gpt-vocab.json",
);
}
impl OpenAiGptMergesResources {
/// Shared under MIT license by the OpenAI team at https://github.com/openai/finetune-transformer-lm. Modified with conversion to C-array format.
pub const GPT: (&'static str, &'static str) = ("openai-gpt/merges.txt", "https://cdn.huggingface.co/openai-gpt-merges.txt");
pub const GPT: (&'static str, &'static str) = (
"openai-gpt/merges.txt",
"https://cdn.huggingface.co/openai-gpt-merges.txt",
);
}
/// # GPT Base model
@ -82,11 +94,11 @@ impl OpenAiGptModel {
/// # Example
///
/// ```no_run
/// use tch::{nn, Device};
/// use rust_bert::Config;
/// use std::path::Path;
/// use rust_bert::gpt2::Gpt2Config;
/// use rust_bert::openai_gpt::OpenAiGptModel;
/// use rust_bert::Config;
/// use std::path::Path;
/// use tch::{nn, Device};
///
/// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu;
@ -94,30 +106,46 @@ impl OpenAiGptModel {
/// let config = Gpt2Config::from_file(config_path);
/// let gpt2: OpenAiGptModel = OpenAiGptModel::new(&(&p.root() / "gpt"), &config);
/// ```
///
pub fn new(p: &nn::Path, config: &Gpt2Config) -> OpenAiGptModel {
let tokens_embed = embedding(&(p / "tokens_embed"), config.vocab_size, config.n_embd, Default::default());
let positions_embed = embedding(&(p / "positions_embed"), config.n_positions, config.n_embd, Default::default());
let tokens_embed = embedding(
&(p / "tokens_embed"),
config.vocab_size,
config.n_embd,
Default::default(),
);
let positions_embed = embedding(
&(p / "positions_embed"),
config.n_positions,
config.n_embd,
Default::default(),
);
let embd_pdrop = match config.embd_pdrop {
Some(value) => value,
None => 0.1
None => 0.1,
};
let drop = Dropout::new(embd_pdrop);
let mut h: Vec<Block> = vec!();
let mut h: Vec<Block> = vec![];
let h_path = &(p / "h");
for layer_index in 0..config.n_layer {
h.push(Block::new(&(h_path / layer_index), config, true));
};
}
let output_attentions = match config.output_attentions {
Some(value) => value,
None => false
None => false,
};
let output_hidden_states = match config.output_hidden_states {
Some(value) => value,
None => false
None => false,
};
OpenAiGptModel { tokens_embed, positions_embed, drop, h, output_hidden_states, output_attentions }
OpenAiGptModel {
tokens_embed,
positions_embed,
drop,
h,
output_hidden_states,
output_attentions,
}
}
/// Forward pass through the model
@ -140,80 +168,99 @@ impl OpenAiGptModel {
/// # Example
///
/// ```no_run
///# use tch::{nn, Device, Tensor, no_grad};
///# use rust_bert::Config;
///# use std::path::Path;
///# use tch::kind::Kind::{Int64, Double};
/// # use tch::{nn, Device, Tensor, no_grad};
/// # use rust_bert::Config;
/// # use std::path::Path;
/// # use tch::kind::Kind::{Int64, Double};
/// use rust_bert::gpt2::Gpt2Config;
/// use rust_bert::openai_gpt::OpenAiGptModel;
///# let config_path = Path::new("path/to/config.json");
///# let vocab_path = Path::new("path/to/vocab.txt");
///# let device = Device::Cpu;
///# let vs = nn::VarStore::new(device);
///# let config = Gpt2Config::from_file(config_path);
///# let gpt_model: OpenAiGptModel = OpenAiGptModel::new(&vs.root(), &config);
/// let (batch_size, sequence_length, past_sequence_length) = (64, 128, 56);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let attention_mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let token_type_ids = Tensor::ones(&[batch_size, sequence_length], (Int64, device));
/// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true);
///
/// let (output, hidden_states, attentions) = no_grad(|| {
/// gpt_model
/// .forward_t(&Some(input_tensor),
/// &Some(attention_mask),
/// &Some(token_type_ids),
/// &Some(position_ids),
/// &None,
/// false).unwrap()
/// });
/// # let config_path = Path::new("path/to/config.json");
/// # let vocab_path = Path::new("path/to/vocab.txt");
/// # let device = Device::Cpu;
/// # let vs = nn::VarStore::new(device);
/// # let config = Gpt2Config::from_file(config_path);
/// # let gpt_model: OpenAiGptModel = OpenAiGptModel::new(&vs.root(), &config);
/// let (batch_size, sequence_length, past_sequence_length) = (64, 128, 56);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let attention_mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let token_type_ids = Tensor::ones(&[batch_size, sequence_length], (Int64, device));
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
/// .expand(&[batch_size, sequence_length], true);
///
/// let (output, hidden_states, attentions) = no_grad(|| {
/// gpt_model
/// .forward_t(
/// &Some(input_tensor),
/// &Some(attention_mask),
/// &Some(token_type_ids),
/// &Some(position_ids),
/// &None,
/// false,
/// )
/// .unwrap()
/// });
/// ```
///
pub fn forward_t(&self,
input_ids: &Option<Tensor>,
attention_mask: &Option<Tensor>,
token_type_ids: &Option<Tensor>,
position_ids: &Option<Tensor>,
input_embeds: &Option<Tensor>,
train: bool) -> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
pub fn forward_t(
&self,
input_ids: &Option<Tensor>,
attention_mask: &Option<Tensor>,
token_type_ids: &Option<Tensor>,
position_ids: &Option<Tensor>,
input_embeds: &Option<Tensor>,
train: bool,
) -> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
let (input_embeddings, seq_length) = match input_ids {
Some(input_value) => match input_embeds {
Some(_) => { return Err("Only one of input ids or input embeddings may be set"); }
None => (input_value.apply(&self.tokens_embed), *input_value.size().last().unwrap())
}
Some(_) => {
return Err("Only one of input ids or input embeddings may be set");
}
None => (
input_value.apply(&self.tokens_embed),
*input_value.size().last().unwrap(),
),
},
None => match input_embeds {
Some(embeds) => (embeds.copy(), embeds.size()[1]),
None => { return Err("At least one of input ids or input embeddings must be set"); }
}
None => {
return Err("At least one of input ids or input embeddings must be set");
}
},
};
let position_ids = match position_ids {
Some(value) => value.copy(),
None => Tensor::arange(seq_length, (Int64, input_embeddings.device())).unsqueeze(0)
None => Tensor::arange(seq_length, (Int64, input_embeddings.device())).unsqueeze(0),
};
let attention_mask: Option<Tensor> = match attention_mask {
Some(value) => {
Some(
(value
.view((input_embeddings.size()[0], -1))
.unsqueeze(1)
.unsqueeze(2)
- 1.0
) * 10000.0)
}
None => None
Some(value) => Some(
(value
.view((input_embeddings.size()[0], -1))
.unsqueeze(1)
.unsqueeze(2)
- 1.0)
* 10000.0,
),
None => None,
};
let position_embeds = position_ids.apply(&self.positions_embed);
let token_type_embeds = match token_type_ids {
Some(value) => value.apply(&self.tokens_embed),
None => Tensor::zeros_like(&position_embeds)
None => Tensor::zeros_like(&position_embeds),
};
let mut hidden_state: Tensor =
(input_embeddings + position_embeds + token_type_embeds).apply_t(&self.drop, train);
let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states {
Some(vec![])
} else {
None
};
let mut all_attentions: Option<Vec<Tensor>> = if self.output_attentions {
Some(vec![])
} else {
None
};
let mut hidden_state: Tensor = (input_embeddings + position_embeds + token_type_embeds).apply_t(&self.drop, train);
let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states { Some(vec!()) } else { None };
let mut all_attentions: Option<Vec<Tensor>> = if self.output_attentions { Some(vec!()) } else { None };
let mut layers = self.h.iter();
loop {
@ -229,9 +276,9 @@ impl OpenAiGptModel {
attentions.push(temp.1.as_ref().unwrap().copy());
};
}
None => break
None => break,
};
};
}
Ok((hidden_state, all_hidden_states, all_attentions))
}
@ -258,11 +305,11 @@ impl OpenAIGPTLMHeadModel {
/// # Example
///
/// ```no_run
/// use tch::{nn, Device};
/// use rust_bert::Config;
/// use std::path::Path;
/// use rust_bert::gpt2::Gpt2Config;
/// use rust_bert::openai_gpt::OpenAIGPTLMHeadModel;
/// use rust_bert::Config;
/// use std::path::Path;
/// use tch::{nn, Device};
///
/// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu;
@ -270,11 +317,18 @@ impl OpenAIGPTLMHeadModel {
/// let config = Gpt2Config::from_file(config_path);
/// let gpt2: OpenAIGPTLMHeadModel = OpenAIGPTLMHeadModel::new(&(&p.root() / "gpt"), &config);
/// ```
///
pub fn new(p: &nn::Path, config: &Gpt2Config) -> OpenAIGPTLMHeadModel {
let transformer = OpenAiGptModel::new(&p, config);
let lm_head = linear_no_bias(&(p / "lm_head"), config.n_embd, config.vocab_size, Default::default());
OpenAIGPTLMHeadModel { transformer, lm_head }
let lm_head = linear_no_bias(
&(p / "lm_head"),
config.n_embd,
config.vocab_size,
Default::default(),
);
OpenAIGPTLMHeadModel {
transformer,
lm_head,
}
}
}
@ -305,19 +359,19 @@ impl LMHeadModel for OpenAIGPTLMHeadModel {
/// # Example
///
/// ```no_run
///# use tch::{nn, Device, Tensor, no_grad};
///# use rust_bert::Config;
///# use std::path::Path;
///# use tch::kind::Kind::{Int64, Double};
/// # use tch::{nn, Device, Tensor, no_grad};
/// # use rust_bert::Config;
/// # use std::path::Path;
/// # use tch::kind::Kind::{Int64, Double};
/// use rust_bert::gpt2::Gpt2Config;
/// use rust_bert::openai_gpt::OpenAIGPTLMHeadModel;
/// use rust_bert::pipelines::generation::{LMHeadModel, Cache};
///# let config_path = Path::new("path/to/config.json");
///# let vocab_path = Path::new("path/to/vocab.txt");
///# let device = Device::Cpu;
///# let vs = nn::VarStore::new(device);
///# let config = Gpt2Config::from_file(config_path);
///# let mut gpt_model: OpenAIGPTLMHeadModel = OpenAIGPTLMHeadModel::new(&vs.root(), &config);
/// # let config_path = Path::new("path/to/config.json");
/// # let vocab_path = Path::new("path/to/vocab.txt");
/// # let device = Device::Cpu;
/// # let vs = nn::VarStore::new(device);
/// # let config = Gpt2Config::from_file(config_path);
/// # let mut gpt_model: OpenAIGPTLMHeadModel = OpenAIGPTLMHeadModel::new(&vs.root(), &config);
/// let (batch_size, sequence_length, past_sequence_length) = (64, 128, 56);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let attention_mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
@ -336,29 +390,44 @@ impl LMHeadModel for OpenAIGPTLMHeadModel {
/// &None,
/// false).unwrap()
/// });
///
/// ```
///
fn forward_t(&self,
input_ids: &Option<Tensor>,
_layer_past: Cache,
attention_mask: &Option<Tensor>,
token_type_ids: &Option<Tensor>,
position_ids: &Option<Tensor>,
input_embeds: &Option<Tensor>,
_encoder_outputs: Option<&Tensor>,
_decoder_input_ids: &Option<Tensor>,
train: bool) -> Result<(Tensor, Option<Tensor>, Cache, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
let (output,
all_hidden_states,
all_attentions) = self.transformer.forward_t(input_ids,
attention_mask,
token_type_ids,
position_ids,
input_embeds,
train)?;
fn forward_t(
&self,
input_ids: &Option<Tensor>,
_layer_past: Cache,
attention_mask: &Option<Tensor>,
token_type_ids: &Option<Tensor>,
position_ids: &Option<Tensor>,
input_embeds: &Option<Tensor>,
_encoder_outputs: Option<&Tensor>,
_decoder_input_ids: &Option<Tensor>,
train: bool,
) -> Result<
(
Tensor,
Option<Tensor>,
Cache,
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
),
&'static str,
> {
let (output, all_hidden_states, all_attentions) = self.transformer.forward_t(
input_ids,
attention_mask,
token_type_ids,
position_ids,
input_embeds,
train,
)?;
let lm_logits = output.apply(&self.lm_head);
Ok((lm_logits, None, Cache::None, all_hidden_states, all_attentions))
Ok((
lm_logits,
None,
Cache::None,
all_hidden_states,
all_attentions,
))
}
}
}

View File

@ -13,9 +13,9 @@
// limitations under the License.
use crate::gpt2::attention::Attention;
use tch::{Tensor, nn};
use crate::gpt2::transformer::MLP;
use crate::gpt2::Gpt2Config;
use tch::{nn, Tensor};
pub struct Block {
ln_1: nn::LayerNorm,
@ -26,21 +26,33 @@ pub struct Block {
impl Block {
pub fn new(p: &nn::Path, config: &Gpt2Config, scale: bool) -> Block {
let layer_norm_config = nn::LayerNormConfig { eps: config.layer_norm_epsilon, ..Default::default() };
let layer_norm_config = nn::LayerNormConfig {
eps: config.layer_norm_epsilon,
..Default::default()
};
let ln_1 = nn::layer_norm(p / "ln_1", vec![config.n_embd], layer_norm_config);
let ln_2 = nn::layer_norm(p / "ln_2", vec![config.n_embd], layer_norm_config);
let attn = Attention::new(&(p / "attn"), config, scale);
let mlp = MLP::new(&(p / "mlp"), config);
Block { ln_1, attn, ln_2, mlp }
Block {
ln_1,
attn,
ln_2,
mlp,
}
}
pub fn forward_t(&self, x: &Tensor, attention_mask: &Option<Tensor>, train: bool)
-> (Tensor, Option<Tensor>) {
pub fn forward_t(
&self,
x: &Tensor,
attention_mask: &Option<Tensor>,
train: bool,
) -> (Tensor, Option<Tensor>) {
let (output, _, attentions) = self.attn.forward_t(x, &None, attention_mask, train);
let x = (x + output).apply(&self.ln_1);
let m = self.mlp.forward_t(&x, train);
let x = (x + m).apply(&self.ln_2);
(x, attentions)
}
}
}

View File

@ -19,13 +19,13 @@
//!
use crate::bert::BertConfig;
use crate::distilbert::DistilBertConfig;
use rust_tokenizers::{BertTokenizer, RobertaTokenizer, TokenizedInput, TruncationStrategy};
use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::Tokenizer;
use std::path::Path;
use crate::Config;
use std::collections::HashMap;
use serde::{Serialize, Deserialize};
use crate::electra::ElectraConfig;
use crate::Config;
use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::Tokenizer;
use rust_tokenizers::{BertTokenizer, RobertaTokenizer, TokenizedInput, TruncationStrategy};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::Path;
#[derive(Clone, Copy, Serialize, Deserialize)]
/// # Identifies the type of model
@ -60,25 +60,42 @@ impl ConfigOption {
match model_type {
ModelType::Bert | ModelType::Roberta => ConfigOption::Bert(BertConfig::from_file(path)),
ModelType::DistilBert => ConfigOption::DistilBert(DistilBertConfig::from_file(path)),
ModelType::Electra => ConfigOption::Electra(ElectraConfig::from_file(path))
ModelType::Electra => ConfigOption::Electra(ElectraConfig::from_file(path)),
}
}
pub fn get_label_mapping(self) -> HashMap<i64, String> {
match self {
Self::Bert(config) => config.id2label.expect("No label dictionary (id2label) provided in configuration file"),
Self::DistilBert(config) => config.id2label.expect("No label dictionary (id2label) provided in configuration file"),
Self::Electra(config) => config.id2label.expect("No label dictionary (id2label) provided in configuration file"),
Self::Bert(config) => config
.id2label
.expect("No label dictionary (id2label) provided in configuration file"),
Self::DistilBert(config) => config
.id2label
.expect("No label dictionary (id2label) provided in configuration file"),
Self::Electra(config) => config
.id2label
.expect("No label dictionary (id2label) provided in configuration file"),
}
}
}
impl TokenizerOption {
/// Interface method to load a tokenizer from file
pub fn from_file(model_type: ModelType, vocab_path: &str, merges_path: Option<&str>, lower_case: bool) -> Self {
pub fn from_file(
model_type: ModelType,
vocab_path: &str,
merges_path: Option<&str>,
lower_case: bool,
) -> Self {
match model_type {
ModelType::Bert | ModelType::DistilBert | ModelType::Electra => TokenizerOption::Bert(BertTokenizer::from_file(vocab_path, lower_case)),
ModelType::Roberta => TokenizerOption::Roberta(RobertaTokenizer::from_file(vocab_path, merges_path.expect("No merges specified!"), lower_case)),
ModelType::Bert | ModelType::DistilBert | ModelType::Electra => {
TokenizerOption::Bert(BertTokenizer::from_file(vocab_path, lower_case))
}
ModelType::Roberta => TokenizerOption::Roberta(RobertaTokenizer::from_file(
vocab_path,
merges_path.expect("No merges specified!"),
lower_case,
)),
}
}
@ -86,15 +103,25 @@ impl TokenizerOption {
pub fn model_type(&self) -> ModelType {
match *self {
Self::Bert(_) => ModelType::Bert,
Self::Roberta(_) => ModelType::Roberta
Self::Roberta(_) => ModelType::Roberta,
}
}
/// Interface method
pub fn encode_list(&self, text_list: Vec<&str>, max_len: usize, truncation_strategy: &TruncationStrategy, stride: usize) -> Vec<TokenizedInput> {
pub fn encode_list(
&self,
text_list: Vec<&str>,
max_len: usize,
truncation_strategy: &TruncationStrategy,
stride: usize,
) -> Vec<TokenizedInput> {
match *self {
Self::Bert(ref tokenizer) => tokenizer.encode_list(text_list, max_len, truncation_strategy, stride),
Self::Roberta(ref tokenizer) => tokenizer.encode_list(text_list, max_len, truncation_strategy, stride)
Self::Bert(ref tokenizer) => {
tokenizer.encode_list(text_list, max_len, truncation_strategy, stride)
}
Self::Roberta(ref tokenizer) => {
tokenizer.encode_list(text_list, max_len, truncation_strategy, stride)
}
}
}
}

File diff suppressed because it is too large Load Diff

View File

@ -6,32 +6,29 @@
//! Extractive question answering from a given question and context. DistilBERT model finetuned on SQuAD (Stanford Question Answering Dataset)
//!
//! ```no_run
//! use rust_bert::pipelines::question_answering::{QuestionAnsweringModel, QaInput};
//!# fn main() -> failure::Fallible<()> {
//! use rust_bert::pipelines::question_answering::{QaInput, QuestionAnsweringModel};
//! # fn main() -> failure::Fallible<()> {
//! let qa_model = QuestionAnsweringModel::new(Default::default())?;
//!
//! let question = String::from("Where does Amy live ?");
//! let context = String::from("Amy lives in Amsterdam");
//!
//! let answers = qa_model.predict(&vec!(QaInput { question, context }), 1, 32);
//!# Ok(())
//!# }
//! let answers = qa_model.predict(&vec![QaInput { question, context }], 1, 32);
//! # Ok(())
//! # }
//! ```
//!
//! Output: \
//! ```no_run
//!# use rust_bert::pipelines::question_answering::Answer;
//!# let output =
//! [
//! Answer {
//! score: 0.9976,
//! start: 13,
//! end: 21,
//! answer: "Amsterdam"
//!# .to_owned()
//! }
//! ]
//!# ;
//! # use rust_bert::pipelines::question_answering::Answer;
//! # let output =
//! [Answer {
//! score: 0.9976,
//! start: 13,
//! end: 21,
//! answer: "Amsterdam", //#### # .to_owned()
//! }]
//! # ;
//! ```
//!
//! #### 2. Translation
@ -46,74 +43,75 @@
//! - English <-> Russian
//! - French <-> German
//! ```no_run
//!# fn main() -> failure::Fallible<()> {
//!# use rust_bert::pipelines::generation::LanguageGenerator;
//! use rust_bert::pipelines::translation::{TranslationModel, TranslationConfig, Language};
//! # fn main() -> failure::Fallible<()> {
//! # use rust_bert::pipelines::generation::LanguageGenerator;
//! use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
//! use tch::Device;
//! let translation_config = TranslationConfig::new(Language::EnglishToFrench, Device::cuda_if_available());
//! let translation_config =
//! TranslationConfig::new(Language::EnglishToFrench, Device::cuda_if_available());
//! let mut model = TranslationModel::new(translation_config)?;
//!
//! let input = ["This is a sentence to be translated"];
//!
//! let output = model.translate(&input);
//!# Ok(())
//!# }
//! # Ok(())
//! # }
//! ```
//!
//! Output: \
//! ```no_run
//!# let output =
//! # let output =
//! "Il s'agit d'une phrase à traduire"
//!# ;
//!```
//! # ;
//! ```
//!
//! #### 3. Summarization
//! Abstractive summarization of texts based on the BART encoder-decoder architecture
//! Include techniques such as beam search, top-k and nucleus sampling, temperature setting and repetition penalty.
//!
//! ```no_run
//!# fn main() -> failure::Fallible<()> {
//!# use rust_bert::pipelines::generation::LanguageGenerator;
//! # fn main() -> failure::Fallible<()> {
//! # use rust_bert::pipelines::generation::LanguageGenerator;
//! use rust_bert::pipelines::summarization::SummarizationModel;
//!
//! let mut model = SummarizationModel::new(Default::default())?;
//!
//! let input = ["In findings published Tuesday in Cornell University's arXiv by a team of scientists
//!from the University of Montreal and a separate report published Wednesday in Nature Astronomy by a team
//!from University College London (UCL), the presence of water vapour was confirmed in the atmosphere of K2-18b,
//!a planet circling a star in the constellation Leo. This is the first such discovery in a planet in its star's
//!habitable zone — not too hot and not too cold for liquid water to exist. The Montreal team, led by Björn Benneke,
//!used data from the NASA's Hubble telescope to assess changes in the light coming from K2-18b's star as the planet
//!passed between it and Earth. They found that certain wavelengths of light, which are usually absorbed by water,
//!weakened when the planet was in the way, indicating not only does K2-18b have an atmosphere, but the atmosphere
//!contains water in vapour form. The team from UCL then analyzed the Montreal team's data using their own software
//!and confirmed their conclusion. This was not the first time scientists have found signs of water on an exoplanet,
//!but previous discoveries were made on planets with high temperatures or other pronounced differences from Earth.
//!\"This is the first potentially habitable planet where the temperature is right and where we now know there is water,\"
//!said UCL astronomer Angelos Tsiaras. \"It's the best candidate for habitability right now.\" \"It's a good sign\",
//!said Ryan Cloutier of the HarvardSmithsonian Center for Astrophysics, who was not one of either study's authors.
//!\"Overall,\" he continued, \"the presence of water in its atmosphere certainly improves the prospect of K2-18b being
//!a potentially habitable planet, but further observations will be required to say for sure. \"
//!K2-18b was first identified in 2015 by the Kepler space telescope. It is about 110 light-years from Earth and larger
//!but less dense. Its star, a red dwarf, is cooler than the Sun, but the planet's orbit is much closer, such that a year
//!on K2-18b lasts 33 Earth days. According to The Guardian, astronomers were optimistic that NASA's James Webb space
//!telescope — scheduled for launch in 2021 — and the European Space Agency's 2028 ARIEL program, could reveal more
//!about exoplanets like K2-18b."];
//! from the University of Montreal and a separate report published Wednesday in Nature Astronomy by a team
//! from University College London (UCL), the presence of water vapour was confirmed in the atmosphere of K2-18b,
//! a planet circling a star in the constellation Leo. This is the first such discovery in a planet in its star's
//! habitable zone — not too hot and not too cold for liquid water to exist. The Montreal team, led by Björn Benneke,
//! used data from the NASA's Hubble telescope to assess changes in the light coming from K2-18b's star as the planet
//! passed between it and Earth. They found that certain wavelengths of light, which are usually absorbed by water,
//! weakened when the planet was in the way, indicating not only does K2-18b have an atmosphere, but the atmosphere
//! contains water in vapour form. The team from UCL then analyzed the Montreal team's data using their own software
//! and confirmed their conclusion. This was not the first time scientists have found signs of water on an exoplanet,
//! but previous discoveries were made on planets with high temperatures or other pronounced differences from Earth.
//! \"This is the first potentially habitable planet where the temperature is right and where we now know there is water,\"
//! said UCL astronomer Angelos Tsiaras. \"It's the best candidate for habitability right now.\" \"It's a good sign\",
//! said Ryan Cloutier of the HarvardSmithsonian Center for Astrophysics, who was not one of either study's authors.
//! \"Overall,\" he continued, \"the presence of water in its atmosphere certainly improves the prospect of K2-18b being
//! a potentially habitable planet, but further observations will be required to say for sure. \"
//! K2-18b was first identified in 2015 by the Kepler space telescope. It is about 110 light-years from Earth and larger
//! but less dense. Its star, a red dwarf, is cooler than the Sun, but the planet's orbit is much closer, such that a year
//! on K2-18b lasts 33 Earth days. According to The Guardian, astronomers were optimistic that NASA's James Webb space
//! telescope — scheduled for launch in 2021 — and the European Space Agency's 2028 ARIEL program, could reveal more
//! about exoplanets like K2-18b."];
//!
//! let output = model.summarize(&input);
//!# Ok(())
//!# }
//! # Ok(())
//! # }
//! ```
//! (example from: [WikiNews](https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b))
//!
//! Example output: \
//! ```no_run
//!# let output =
//! # let output =
//! "Scientists have found water vapour on K2-18b, a planet 110 light-years from Earth.
//! This is the first such discovery in a planet in its star's habitable zone.
//! The planet is not too hot and not too cold for liquid water to exist."
//!# ;
//!```
//! # ;
//! ```
//!
//!
//! #### 4. Natural Language Generation
@ -124,18 +122,18 @@
//!
//! ```no_run
//! use rust_bert::pipelines::generation::GPT2Generator;
//!# fn main() -> failure::Fallible<()> {
//!# use rust_bert::pipelines::generation::LanguageGenerator;
//! # fn main() -> failure::Fallible<()> {
//! # use rust_bert::pipelines::generation::LanguageGenerator;
//! let mut model = GPT2Generator::new(Default::default())?;
//! let input_context_1 = "The dog";
//! let input_context_2 = "The cat was";
//! let output = model.generate(Some(vec!(input_context_1, input_context_2)), None);
//!# Ok(())
//!# }
//! let output = model.generate(Some(vec![input_context_1, input_context_2]), None);
//! # Ok(())
//! # }
//! ```
//! Example output: \
//! ```no_run
//!# let output =
//! # let output =
//! [
//! "The dog's owners, however, did not want to be named. According to the lawsuit, the animal's owner, a 29-year",
//! "The dog has always been part of the family. \"He was always going to be my dog and he was always looking out for me",
@ -144,14 +142,14 @@
//! "The cat was pulled from the street by two-year-old Jazmine.\"I didn't know what to do,\" she said",
//! "The cat was attacked by two stray dogs and was taken to a hospital. Two other cats were also injured in the attack and are being treated."
//! ]
//!# ;
//!```
//! # ;
//! ```
//!
//! #### 5. Sentiment analysis
//! Predicts the binary sentiment for a sentence. DistilBERT model finetuned on SST-2.
//! ```no_run
//! use rust_bert::pipelines::sentiment::SentimentModel;
//!# fn main() -> failure::Fallible<()> {
//! # fn main() -> failure::Fallible<()> {
//! let sentiment_model = SentimentModel::new(Default::default())?;
//! let input = [
//! "Probably my all-time favorite movie, a story of selflessness, sacrifice and dedication to a noble cause, but it's not preachy or boring.",
@ -159,59 +157,84 @@
//! "If you like original gut wrenching laughter you will like this movie. If you are young or old then you will love this movie, hell even my mom liked it.",
//! ];
//! let output = sentiment_model.predict(&input);
//!# Ok(())
//!# }
//! # Ok(())
//! # }
//! ```
//! (Example courtesy of [IMDb](http://www.imdb.com))
//!
//! Output: \
//! ```no_run
//!# use rust_bert::pipelines::sentiment::Sentiment;
//!# use rust_bert::pipelines::sentiment::SentimentPolarity::{Positive, Negative};
//!# let output =
//! # use rust_bert::pipelines::sentiment::Sentiment;
//! # use rust_bert::pipelines::sentiment::SentimentPolarity::{Positive, Negative};
//! # let output =
//! [
//! Sentiment { polarity: Positive, score: 0.998 },
//! Sentiment { polarity: Negative, score: 0.992 },
//! Sentiment { polarity: Positive, score: 0.999 }
//! Sentiment {
//! polarity: Positive,
//! score: 0.998,
//! },
//! Sentiment {
//! polarity: Negative,
//! score: 0.992,
//! },
//! Sentiment {
//! polarity: Positive,
//! score: 0.999,
//! },
//! ]
//!# ;
//! # ;
//! ```
//!
//! #### 6. Named Entity Recognition
//! Extracts entities (Person, Location, Organization, Miscellaneous) from text. BERT cased large model finetuned on CoNNL03, contributed by the [MDZ Digital Library team at the Bavarian State Library](https://github.com/dbmdz)
//! ```no_run
//! use rust_bert::pipelines::ner::NERModel;
//!# fn main() -> failure::Fallible<()> {
//! # fn main() -> failure::Fallible<()> {
//! let ner_model = NERModel::new(Default::default())?;
//! let input = [
//! "My name is Amy. I live in Paris.",
//! "Paris is a city in France."
//! "Paris is a city in France.",
//! ];
//! let output = ner_model.predict(&input);
//!# Ok(())
//!# }
//! # Ok(())
//! # }
//! ```
//! Output: \
//! ```no_run
//!# use rust_bert::pipelines::question_answering::Answer;
//!# use rust_bert::pipelines::ner::Entity;
//!# let output =
//! # use rust_bert::pipelines::question_answering::Answer;
//! # use rust_bert::pipelines::ner::Entity;
//! # let output =
//! [
//! Entity { word: String::from("Amy"), score: 0.9986, label: String::from("I-PER") },
//! Entity { word: String::from("Paris"), score: 0.9985, label: String::from("I-LOC") },
//! Entity { word: String::from("Paris"), score: 0.9988, label: String::from("I-LOC") },
//! Entity { word: String::from("France"), score: 0.9993, label: String::from("I-LOC") },
//! Entity {
//! word: String::from("Amy"),
//! score: 0.9986,
//! label: String::from("I-PER"),
//! },
//! Entity {
//! word: String::from("Paris"),
//! score: 0.9985,
//! label: String::from("I-LOC"),
//! },
//! Entity {
//! word: String::from("Paris"),
//! score: 0.9988,
//! label: String::from("I-LOC"),
//! },
//! Entity {
//! word: String::from("France"),
//! score: 0.9993,
//! label: String::from("I-LOC"),
//! },
//! ]
//!# ;
//! # ;
//! ```
//!
pub mod sentiment;
pub mod common;
pub mod token_classification;
pub mod sequence_classification;
pub mod generation;
pub mod ner;
pub mod question_answering;
pub mod generation;
pub mod sentiment;
pub mod sequence_classification;
pub mod summarization;
pub mod token_classification;
pub mod translation;

View File

@ -20,33 +20,48 @@
//!
//! ```no_run
//! use rust_bert::pipelines::ner::NERModel;
//!# fn main() -> failure::Fallible<()> {
//! # fn main() -> failure::Fallible<()> {
//! let ner_model = NERModel::new(Default::default())?;
//!
//! let input = [
//! "My name is Amy. I live in Paris.",
//! "Paris is a city in France."
//! "Paris is a city in France.",
//! ];
//! let output = ner_model.predict(&input);
//!# Ok(())
//!# }
//! # Ok(())
//! # }
//! ```
//! Output: \
//! ```no_run
//!# use rust_bert::pipelines::question_answering::Answer;
//!# use rust_bert::pipelines::ner::Entity;
//!# let output =
//! # use rust_bert::pipelines::question_answering::Answer;
//! # use rust_bert::pipelines::ner::Entity;
//! # let output =
//! [
//! Entity { word: String::from("Amy"), score: 0.9986, label: String::from("I-PER") },
//! Entity { word: String::from("Paris"), score: 0.9985, label: String::from("I-LOC") },
//! Entity { word: String::from("Paris"), score: 0.9988, label: String::from("I-LOC") },
//! Entity { word: String::from("France"), score: 0.9993, label: String::from("I-LOC") },
//! Entity {
//! word: String::from("Amy"),
//! score: 0.9986,
//! label: String::from("I-PER"),
//! },
//! Entity {
//! word: String::from("Paris"),
//! score: 0.9985,
//! label: String::from("I-LOC"),
//! },
//! Entity {
//! word: String::from("Paris"),
//! score: 0.9988,
//! label: String::from("I-LOC"),
//! },
//! Entity {
//! word: String::from("France"),
//! score: 0.9993,
//! label: String::from("I-LOC"),
//! },
//! ]
//!# ;
//! # ;
//! ```
use crate::pipelines::token_classification::{TokenClassificationModel, TokenClassificationConfig};
use crate::pipelines::token_classification::{TokenClassificationConfig, TokenClassificationModel};
#[derive(Debug)]
/// # Entity generated by a `NERModel`
@ -64,7 +79,7 @@ type NERConfig = TokenClassificationConfig;
/// # NERModel to extract named entities
pub struct NERModel {
token_classification_model: TokenClassificationModel
token_classification_model: TokenClassificationModel,
}
impl NERModel {
@ -77,17 +92,18 @@ impl NERModel {
/// # Example
///
/// ```no_run
///# fn main() -> failure::Fallible<()> {
/// # fn main() -> failure::Fallible<()> {
/// use rust_bert::pipelines::ner::NERModel;
///
/// let ner_model = NERModel::new(Default::default())?;
///# Ok(())
///# }
/// let ner_model = NERModel::new(Default::default())?;
/// # Ok(())
/// # }
/// ```
///
pub fn new(ner_config: NERConfig) -> failure::Fallible<NERModel> {
let model = TokenClassificationModel::new(ner_config)?;
Ok(NERModel { token_classification_model: model })
Ok(NERModel {
token_classification_model: model,
})
}
/// Extract entities from a text
@ -103,30 +119,28 @@ impl NERModel {
/// # Example
///
/// ```no_run
///# fn main() -> failure::Fallible<()> {
///# use rust_bert::pipelines::ner::NERModel;
/// # fn main() -> failure::Fallible<()> {
/// # use rust_bert::pipelines::ner::NERModel;
///
/// let ner_model = NERModel::new(Default::default())?;
/// let ner_model = NERModel::new(Default::default())?;
/// let input = [
/// "My name is Amy. I live in Paris.",
/// "Paris is a city in France."
/// "Paris is a city in France.",
/// ];
/// let output = ner_model.predict(&input);
///# Ok(())
///# }
/// # Ok(())
/// # }
/// ```
///
pub fn predict(&self, input: &[&str]) -> Vec<Entity> {
self.token_classification_model
.predict(input, true, false)
.into_iter()
.filter(|token| token.label != "O")
.map(|token| {
Entity {
word: token.text,
score: token.score,
label: token.label,
}
}).collect()
.map(|token| Entity {
word: token.text,
score: token.score,
label: token.label,
})
.collect()
}
}

View File

@ -17,48 +17,48 @@
//! The dependencies will be downloaded to the user's home directory, under ~/.cache/.rustbert/distilbert-qa
//!
//! ```no_run
//! use rust_bert::pipelines::question_answering::{QuestionAnsweringModel, QaInput};
//! use rust_bert::pipelines::question_answering::{QaInput, QuestionAnsweringModel};
//!
//!# fn main() -> failure::Fallible<()> {
//! # fn main() -> failure::Fallible<()> {
//! let qa_model = QuestionAnsweringModel::new(Default::default())?;
//!
//! let question = String::from("Where does Amy live ?");
//! let context = String::from("Amy lives in Amsterdam");
//!
//! let answers = qa_model.predict(&vec!(QaInput { question, context }), 1, 32);
//!# Ok(())
//!# }
//! let answers = qa_model.predict(&vec![QaInput { question, context }], 1, 32);
//! # Ok(())
//! # }
//! ```
//!
//! Output: \
//! ```no_run
//!# use rust_bert::pipelines::question_answering::Answer;
//!# let output =
//! [
//! Answer {
//! score: 0.9976,
//! start: 13,
//! end: 21,
//! answer: "Amsterdam"
//!# .to_owned()
//! }
//! ]
//!# ;
//! # use rust_bert::pipelines::question_answering::Answer;
//! # let output =
//! [Answer {
//! score: 0.9976,
//! start: 13,
//! end: 21,
//! answer: "Amsterdam", //#### # .to_owned()
//! }]
//! # ;
//! ```
use rust_tokenizers::{BertTokenizer, Tokenizer, TruncationStrategy, TokenizedInput};
use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::Mask;
use tch::{Device, Tensor, no_grad};
use std::path::PathBuf;
use rust_tokenizers::tokenization_utils::truncate_sequences;
use std::collections::HashMap;
use std::cmp::min;
use tch::nn::VarStore;
use tch::kind::Kind::Float;
use std::fs;
use crate::common::resources::{download_resource, RemoteResource, Resource};
use crate::distilbert::{
DistilBertConfig, DistilBertConfigResources, DistilBertForQuestionAnswering,
DistilBertModelResources, DistilBertVocabResources,
};
use crate::Config;
use crate::distilbert::{DistilBertForQuestionAnswering, DistilBertConfig, DistilBertConfigResources, DistilBertModelResources, DistilBertVocabResources};
use crate::common::resources::{Resource, RemoteResource, download_resource};
use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::Mask;
use rust_tokenizers::tokenization_utils::truncate_sequences;
use rust_tokenizers::{BertTokenizer, TokenizedInput, Tokenizer, TruncationStrategy};
use std::cmp::min;
use std::collections::HashMap;
use std::fs;
use std::path::PathBuf;
use tch::kind::Kind::Float;
use tch::nn::VarStore;
use tch::{no_grad, Device, Tensor};
/// # Input for Question Answering
/// Includes a context (containing the answer) and question strings
@ -84,7 +84,6 @@ struct QaFeature {
pub token_to_orig_map: HashMap<i64, i64>,
pub p_mask: Vec<i8>,
pub example_index: i64,
}
#[derive(Debug, Clone)]
@ -102,34 +101,38 @@ pub struct Answer {
impl PartialEq for Answer {
fn eq(&self, other: &Self) -> bool {
(self.start == other.start) &&
(self.end == other.end) &&
(self.answer == other.answer)
(self.start == other.start) && (self.end == other.end) && (self.answer == other.answer)
}
}
fn remove_duplicates<T: PartialEq + Clone>(vector: &mut Vec<T>) -> &mut Vec<T> {
let mut potential_duplicates = vec!();
vector.retain(|item| if potential_duplicates.contains(item) {
false
} else {
potential_duplicates.push(item.clone());
true
let mut potential_duplicates = vec![];
vector.retain(|item| {
if potential_duplicates.contains(item) {
false
} else {
potential_duplicates.push(item.clone());
true
}
});
vector
}
impl QaExample {
pub fn new(question: &str, context: &str) -> QaExample {
let question = question.to_owned();
let (doc_tokens, char_to_word_offset) = QaExample::split_context(context);
QaExample { question, context: context.to_owned(), doc_tokens, char_to_word_offset }
QaExample {
question,
context: context.to_owned(),
doc_tokens,
char_to_word_offset,
}
}
fn split_context(context: &str) -> (Vec<String>, Vec<i64>) {
let mut doc_tokens: Vec<String> = vec!();
let mut char_to_word_offset: Vec<i64> = vec!();
let mut doc_tokens: Vec<String> = vec![];
let mut char_to_word_offset: Vec<i64> = vec![];
let max_length = context.len();
let mut current_word = String::with_capacity(max_length);
let mut previous_whitespace = false;
@ -158,11 +161,11 @@ impl QaExample {
}
fn is_whitespace(character: &char) -> bool {
(character == &' ') |
(character == &'\t') |
(character == &'\r') |
(character == &'\n') |
(*character as u32 == 0x202F)
(character == &' ')
| (character == &'\t')
| (character == &'\r')
| (character == &'\n')
| (*character as u32 == 0x202F)
}
}
@ -182,9 +185,15 @@ pub struct QuestionAnsweringConfig {
impl Default for QuestionAnsweringConfig {
fn default() -> QuestionAnsweringConfig {
QuestionAnsweringConfig {
model_resource: Resource::Remote(RemoteResource::from_pretrained(DistilBertModelResources::DISTIL_BERT_SQUAD)),
config_resource: Resource::Remote(RemoteResource::from_pretrained(DistilBertConfigResources::DISTIL_BERT_SQUAD)),
vocab_resource: Resource::Remote(RemoteResource::from_pretrained(DistilBertVocabResources::DISTIL_BERT_SQUAD)),
model_resource: Resource::Remote(RemoteResource::from_pretrained(
DistilBertModelResources::DISTIL_BERT_SQUAD,
)),
config_resource: Resource::Remote(RemoteResource::from_pretrained(
DistilBertConfigResources::DISTIL_BERT_SQUAD,
)),
vocab_resource: Resource::Remote(RemoteResource::from_pretrained(
DistilBertVocabResources::DISTIL_BERT_SQUAD,
)),
device: Device::cuda_if_available(),
}
}
@ -213,26 +222,33 @@ impl QuestionAnsweringModel {
/// # Example
///
/// ```no_run
///# fn main() -> failure::Fallible<()> {
/// # fn main() -> failure::Fallible<()> {
/// use rust_bert::pipelines::question_answering::QuestionAnsweringModel;
///
/// let qa_model = QuestionAnsweringModel::new(Default::default())?;
///# Ok(())
///# }
/// let qa_model = QuestionAnsweringModel::new(Default::default())?;
/// # Ok(())
/// # }
/// ```
///
pub fn new(question_answering_config: QuestionAnsweringConfig) -> failure::Fallible<QuestionAnsweringModel> {
pub fn new(
question_answering_config: QuestionAnsweringConfig,
) -> failure::Fallible<QuestionAnsweringModel> {
let config_path = download_resource(&question_answering_config.config_resource)?;
let vocab_path = download_resource(&question_answering_config.vocab_resource)?;
let weights_path = download_resource(&question_answering_config.model_resource)?;
let device = question_answering_config.device;
let tokenizer = BertTokenizer::from_file(vocab_path.to_str().unwrap(), false);
let pad_idx = *Tokenizer::vocab(&tokenizer).special_values.get("[PAD]").expect("[PAD] token not found in vocabulary");
let sep_idx = *Tokenizer::vocab(&tokenizer).special_values.get("[SEP]").expect("[SEP] token not found in vocabulary");
let pad_idx = *Tokenizer::vocab(&tokenizer)
.special_values
.get("[PAD]")
.expect("[PAD] token not found in vocabulary");
let sep_idx = *Tokenizer::vocab(&tokenizer)
.special_values
.get("[SEP]")
.expect("[SEP] token not found in vocabulary");
let mut var_store = VarStore::new(device);
let mut config = DistilBertConfig::from_file(config_path);
// The config for the current pre-trained question answering model indicates position embeddings which does not seem accurate
// The config for the current pre-trained question answering model indicates position embeddings which does not seem accurate
config.sinusoidal_pos_embds = false;
let distilbert_qa = DistilBertForQuestionAnswering::new(&var_store.root(), &config);
var_store.load(weights_path)?;
@ -249,7 +265,6 @@ impl QuestionAnsweringModel {
})
}
/// Perform extractive question answering given a list of `QaInputs`
///
/// # Arguments
@ -264,25 +279,35 @@ impl QuestionAnsweringModel {
/// # Example
///
/// ```no_run
///# fn main() -> failure::Fallible<()> {
/// use rust_bert::pipelines::question_answering::{QuestionAnsweringModel, QaInput};
/// # fn main() -> failure::Fallible<()> {
/// use rust_bert::pipelines::question_answering::{QaInput, QuestionAnsweringModel};
///
/// let qa_model = QuestionAnsweringModel::new(Default::default())?;
/// let qa_model = QuestionAnsweringModel::new(Default::default())?;
///
/// let question_1 = String::from("Where does Amy live ?");
/// let context_1 = String::from("Amy lives in Amsterdam");
/// let question_2 = String::from("Where does Eric live");
/// let context_2 = String::from("While Amy lives in Amsterdam, Eric is in The Hague.");
///
/// let qa_input_1 = QaInput { question: question_1, context: context_1 };
/// let qa_input_2 = QaInput { question: question_2, context: context_2 };
/// let qa_input_1 = QaInput {
/// question: question_1,
/// context: context_1,
/// };
/// let qa_input_2 = QaInput {
/// question: question_2,
/// context: context_2,
/// };
/// let answers = qa_model.predict(&[qa_input_1, qa_input_2], 1, 32);
///
///# Ok(())
///# }
/// # Ok(())
/// # }
/// ```
///
pub fn predict(&self, qa_inputs: &[QaInput], top_k: i64, batch_size: usize) -> Vec<Vec<Answer>> {
pub fn predict(
&self,
qa_inputs: &[QaInput],
top_k: i64,
batch_size: usize,
) -> Vec<Vec<Answer>> {
let examples: Vec<QaExample> = qa_inputs
.iter()
.map(|qa_input| QaExample::new(&qa_input.question, &qa_input.context))
@ -290,7 +315,15 @@ impl QuestionAnsweringModel {
let features: Vec<QaFeature> = examples
.iter()
.enumerate()
.map(|(example_index, qa_example)| self.generate_features(&qa_example, self.max_seq_len, self.doc_stride, self.max_query_length, example_index as i64))
.map(|(example_index, qa_example)| {
self.generate_features(
&qa_example,
self.max_seq_len,
self.doc_stride,
self.max_query_length,
example_index as i64,
)
})
.flatten()
.collect();
@ -310,28 +343,36 @@ impl QuestionAnsweringModel {
}
let input_ids = Tensor::stack(&input_ids, 0).to(self.var_store.device());
let attention_masks = Tensor::stack(&attention_masks, 0).to(self.var_store.device());
let attention_masks =
Tensor::stack(&attention_masks, 0).to(self.var_store.device());
let (start_logits, end_logits, _, _) = self.distilbert_qa.forward_t(Some(input_ids), Some(attention_masks), None, false).unwrap();
let (start_logits, end_logits, _, _) = self
.distilbert_qa
.forward_t(Some(input_ids), Some(attention_masks), None, false)
.unwrap();
let start_logits = start_logits.detach();
let end_logits = end_logits.detach();
let example_index_to_feature_end_position: Vec<(usize, i64)> = batch_features
.iter()
.enumerate()
.map(|(feature_index, feature)| (feature.example_index as usize, feature_index as i64 + 1))
.map(|(feature_index, feature)| {
(feature.example_index as usize, feature_index as i64 + 1)
})
.collect();
let mut feature_id_start = 0;
for (example_id, max_feature_id) in example_index_to_feature_end_position {
let mut answers: Vec<Answer> = vec!();
let mut answers: Vec<Answer> = vec![];
let example = &examples[example_id];
for feature_idx in feature_id_start..max_feature_id {
let feature = &batch_features[feature_idx as usize];
let start = start_logits.get(feature_idx);
let end = end_logits.get(feature_idx);
let p_mask = (Tensor::of_slice(&feature.p_mask) - 1).abs().to_device(start.device());
let p_mask = (Tensor::of_slice(&feature.p_mask) - 1)
.abs()
.to_device(start.device());
let start: Tensor = start.exp() / start.exp().sum(Float) * &p_mask;
let end: Tensor = end.exp() / end.exp().sum(Float) * &p_mask;
@ -343,33 +384,42 @@ impl QuestionAnsweringModel {
let end_pos = feature.token_to_orig_map[&ends[idx]] as usize;
let answer = example.doc_tokens[start_pos..end_pos + 1].join(" ");
let start = example.char_to_word_offset
let start = example
.char_to_word_offset
.iter()
.position(|&v| v as usize == start_pos)
.unwrap();
let end = example.char_to_word_offset
let end = example
.char_to_word_offset
.iter()
.rposition(|&v| v as usize == end_pos)
.unwrap();
answers.push(Answer { score: scores[idx], start, end, answer });
answers.push(Answer {
score: scores[idx],
start,
end,
answer,
});
}
}
feature_id_start = max_feature_id;
let example_answers = example_top_k_answers_map.entry(example_id).or_insert(vec!());
let example_answers = example_top_k_answers_map
.entry(example_id)
.or_insert(vec![]);
example_answers.extend(answers);
}
});
start = end;
}
let mut all_answers = vec!();
let mut all_answers = vec![];
for example_id in 0..examples.len() {
if let Some(answers) = example_top_k_answers_map.get_mut(&example_id) {
remove_duplicates(answers).sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
all_answers.push(answers[..min(answers.len(), top_k as usize)].to_vec());
} else {
all_answers.push(vec!());
all_answers.push(vec![]);
}
}
all_answers
@ -379,7 +429,10 @@ impl QuestionAnsweringModel {
let outer = start.unsqueeze(-1).matmul(&end.unsqueeze(0));
let start_dim = start.size()[0];
let end_dim = end.size()[0];
let candidates = outer.triu(0).tril(self.max_answer_len as i64 - 1).flatten(0, -1);
let candidates = outer
.triu(0)
.tril(self.max_answer_len as i64 - 1)
.flatten(0, -1);
let idx_sort = if top_k == 1 {
candidates.argmax(0, true)
} else if candidates.size()[0] < top_k {
@ -387,9 +440,9 @@ impl QuestionAnsweringModel {
} else {
candidates.argsort(0, true).slice(0, 0, top_k, 1)
};
let mut start: Vec<i64> = vec!();
let mut end: Vec<i64> = vec!();
let mut scores: Vec<f64> = vec!();
let mut start: Vec<i64> = vec![];
let mut end: Vec<i64> = vec![];
let mut scores: Vec<f64> = vec![];
for flat_index_position in 0..idx_sort.size()[0] {
let flat_index = idx_sort.int64_value(&[flat_index_position]);
scores.push(candidates.double_value(&[flat_index]));
@ -399,10 +452,16 @@ impl QuestionAnsweringModel {
(start, end, scores)
}
fn generate_features(&self, qa_example: &QaExample, max_seq_length: usize, doc_stride: usize, max_query_length: usize, example_index: i64) -> Vec<QaFeature> {
let mut tok_to_orig_index: Vec<i64> = vec!();
let mut all_doc_tokens: Vec<String> = vec!();
fn generate_features(
&self,
qa_example: &QaExample,
max_seq_length: usize,
doc_stride: usize,
max_query_length: usize,
example_index: i64,
) -> Vec<QaFeature> {
let mut tok_to_orig_index: Vec<i64> = vec![];
let mut all_doc_tokens: Vec<String> = vec![];
for (idx, token) in qa_example.doc_tokens.iter().enumerate() {
let sub_tokens = self.tokenizer.tokenize(token);
@ -414,28 +473,61 @@ impl QuestionAnsweringModel {
let truncated_query = self.prepare_query(&qa_example.question, max_query_length);
let sequence_added_tokens = self.tokenizer.build_input_with_special_tokens(vec!(), None, vec!(), None, vec!(), None, vec!(), None).0.len();
let sequence_pair_added_tokens = self.tokenizer.build_input_with_special_tokens(vec!(), Some(vec!()), vec!(), Some(vec!()), vec!(), Some(vec!()), vec!(), Some(vec!())).0.len();
let sequence_added_tokens = self
.tokenizer
.build_input_with_special_tokens(vec![], None, vec![], None, vec![], None, vec![], None)
.0
.len();
let sequence_pair_added_tokens = self
.tokenizer
.build_input_with_special_tokens(
vec![],
Some(vec![]),
vec![],
Some(vec![]),
vec![],
Some(vec![]),
vec![],
Some(vec![]),
)
.0
.len();
let mut spans: Vec<QaFeature> = vec!();
let mut spans: Vec<QaFeature> = vec![];
let mut remaining_tokens = self.tokenizer.convert_tokens_to_ids(&all_doc_tokens);
while (spans.len() * doc_stride as usize) < all_doc_tokens.len() {
let (encoded_span, attention_mask) = self.encode_qa_pair(&truncated_query, &remaining_tokens, max_seq_length, doc_stride, sequence_pair_added_tokens);
let (encoded_span, attention_mask) = self.encode_qa_pair(
&truncated_query,
&remaining_tokens,
max_seq_length,
doc_stride,
sequence_pair_added_tokens,
);
let paragraph_len = min(
all_doc_tokens.len() - spans.len() * doc_stride,
max_seq_length - truncated_query.len() - sequence_pair_added_tokens);
max_seq_length - truncated_query.len() - sequence_pair_added_tokens,
);
let mut token_to_orig_map = HashMap::new();
for i in 0..paragraph_len {
let index = truncated_query.len() + sequence_added_tokens + i;
token_to_orig_map.insert(index as i64, tok_to_orig_index[spans.len() * doc_stride + i] as i64);
token_to_orig_map.insert(
index as i64,
tok_to_orig_index[spans.len() * doc_stride + i] as i64,
);
}
let p_mask = self.get_mask(&encoded_span);
let qa_feature = QaFeature { input_ids: encoded_span.token_ids, attention_mask, token_to_orig_map, p_mask, example_index };
let qa_feature = QaFeature {
input_ids: encoded_span.token_ids,
attention_mask,
token_to_orig_map,
p_mask,
example_index,
};
spans.push(qa_feature);
if encoded_span.num_truncated_tokens == 0 {
@ -447,59 +539,81 @@ impl QuestionAnsweringModel {
}
fn prepare_query(&self, query: &str, max_query_length: usize) -> Vec<i64> {
let truncated_query = self.tokenizer.convert_tokens_to_ids(&self.tokenizer.tokenize(&query));
let num_query_tokens_to_remove = if truncated_query.len() > max_query_length as usize { truncated_query.len() - max_query_length } else { 0 };
let (truncated_query, _, _, _, _, _, _, _, _, _) = truncate_sequences(truncated_query,
None,
vec!(),
None,
vec!(),
None,
vec!(),
None,
num_query_tokens_to_remove,
&TruncationStrategy::OnlyFirst,
0).unwrap();
let truncated_query = self
.tokenizer
.convert_tokens_to_ids(&self.tokenizer.tokenize(&query));
let num_query_tokens_to_remove = if truncated_query.len() > max_query_length as usize {
truncated_query.len() - max_query_length
} else {
0
};
let (truncated_query, _, _, _, _, _, _, _, _, _) = truncate_sequences(
truncated_query,
None,
vec![],
None,
vec![],
None,
vec![],
None,
num_query_tokens_to_remove,
&TruncationStrategy::OnlyFirst,
0,
)
.unwrap();
truncated_query
}
fn encode_qa_pair(&self,
truncated_query: &Vec<i64>,
spans_token_ids: &Vec<i64>,
max_seq_length: usize,
doc_stride: usize,
sequence_pair_added_tokens: usize) -> (TokenizedInput, Vec<i64>) {
fn encode_qa_pair(
&self,
truncated_query: &Vec<i64>,
spans_token_ids: &Vec<i64>,
max_seq_length: usize,
doc_stride: usize,
sequence_pair_added_tokens: usize,
) -> (TokenizedInput, Vec<i64>) {
let len_1 = truncated_query.len();
let len_2 = spans_token_ids.len();
let total_len = len_1 + len_2 + sequence_pair_added_tokens;
let num_truncated_tokens = if total_len > max_seq_length { total_len - max_seq_length } else { 0 };
let num_truncated_tokens = if total_len > max_seq_length {
total_len - max_seq_length
} else {
0
};
let (truncated_query, truncated_context, _, _, _, _, _, _, overflowing_tokens, _)
= truncate_sequences(truncated_query.clone(),
Some(spans_token_ids.clone()),
vec!(),
None,
vec!(),
None,
vec!(),
None,
num_truncated_tokens,
&TruncationStrategy::OnlySecond,
max_seq_length - doc_stride - len_1 - sequence_pair_added_tokens).unwrap();
let (truncated_query, truncated_context, _, _, _, _, _, _, overflowing_tokens, _) =
truncate_sequences(
truncated_query.clone(),
Some(spans_token_ids.clone()),
vec![],
None,
vec![],
None,
vec![],
None,
num_truncated_tokens,
&TruncationStrategy::OnlySecond,
max_seq_length - doc_stride - len_1 - sequence_pair_added_tokens,
)
.unwrap();
let (mut token_ids,
let (
mut token_ids,
mut segment_ids,
special_tokens_mask,
mut token_offsets,
mut reference_offsets,
mut mask) = self.tokenizer.build_input_with_special_tokens(truncated_query,
truncated_context,
vec!(),
None,
vec!(),
None,
vec!(),
None);
mut mask,
) = self.tokenizer.build_input_with_special_tokens(
truncated_query,
truncated_context,
vec![],
None,
vec![],
None,
vec![],
None,
);
let mut attention_mask = vec![1; token_ids.len()];
if token_ids.len() < max_seq_length {
token_ids.append(&mut vec![self.pad_idx; max_seq_length - token_ids.len()]);
@ -509,18 +623,32 @@ impl QuestionAnsweringModel {
reference_offsets.append(&mut vec![vec!(); max_seq_length - token_offsets.len()]);
mask.append(&mut vec![Mask::Special; max_seq_length - mask.len()]);
}
(TokenizedInput { token_ids, segment_ids, special_tokens_mask, overflowing_tokens, num_truncated_tokens, token_offsets, reference_offsets, mask }, attention_mask)
(
TokenizedInput {
token_ids,
segment_ids,
special_tokens_mask,
overflowing_tokens,
num_truncated_tokens,
token_offsets,
reference_offsets,
mask,
},
attention_mask,
)
}
fn get_mask(&self, encoded_span: &TokenizedInput) -> Vec<i8> {
let sep_indices: Vec<usize> = encoded_span.token_ids
let sep_indices: Vec<usize> = encoded_span
.token_ids
.iter()
.enumerate()
.filter(|(_, &value)| value == self.sep_idx)
.map(|(position, _)| position)
.collect();
let mut p_mask: Vec<i8> = encoded_span.segment_ids
let mut p_mask: Vec<i8> = encoded_span
.segment_ids
.iter()
.map(|v| min(v, &1i8))
.map(|&v| 1i8 - v)
@ -534,10 +662,13 @@ impl QuestionAnsweringModel {
pub fn squad_processor(file_path: PathBuf) -> Vec<QaInput> {
let file = fs::File::open(file_path).expect("unable to open file");
let json: serde_json::Value = serde_json::from_reader(file).expect("JSON not properly formatted");
let json: serde_json::Value =
serde_json::from_reader(file).expect("JSON not properly formatted");
let data = json
.get("data").expect("SQuAD file does not contain data field")
.as_array().expect("Data array not properly formatted");
.get("data")
.expect("SQuAD file does not contain data field")
.as_array()
.expect("Data array not properly formatted");
let mut qa_inputs: Vec<QaInput> = Vec::with_capacity(data.len());
for qa_input in data.iter() {
@ -548,8 +679,17 @@ pub fn squad_processor(file_path: PathBuf) -> Vec<QaInput> {
let context = paragraph.get("context").unwrap().as_str().unwrap();
let qas = paragraph.get("qas").unwrap().as_array().unwrap();
for qa in qas.iter() {
let question = qa.as_object().unwrap().get("question").unwrap().as_str().unwrap();
qa_inputs.push(QaInput { question: question.to_owned(), context: context.to_owned() });
let question = qa
.as_object()
.unwrap()
.get("question")
.unwrap()
.as_str()
.unwrap();
qa_inputs.push(QaInput {
question: question.to_owned(),
context: context.to_owned(),
});
}
}
}

View File

@ -19,7 +19,7 @@
//! ```no_run
//! use rust_bert::pipelines::sentiment::SentimentModel;
//!
//!# fn main() -> failure::Fallible<()> {
//! # fn main() -> failure::Fallible<()> {
//! let sentiment_classifier = SentimentModel::new(Default::default())?;
//! let input = [
//! "Probably my all-time favorite movie, a story of selflessness, sacrifice and dedication to a noble cause, but it's not preachy or boring.",
@ -27,29 +27,40 @@
//! "If you like original gut wrenching laughter you will like this movie. If you are young or old then you will love this movie, hell even my mom liked it.",
//! ];
//! let output = sentiment_classifier.predict(&input);
//!# Ok(())
//!# }
//! # Ok(())
//! # }
//! ```
//! (Example courtesy of [IMDb](http://www.imdb.com))
//!
//! Output: \
//! ```no_run
//!# use rust_bert::pipelines::sentiment::Sentiment;
//!# use rust_bert::pipelines::sentiment::SentimentPolarity::{Positive, Negative};
//!# let output =
//! # use rust_bert::pipelines::sentiment::Sentiment;
//! # use rust_bert::pipelines::sentiment::SentimentPolarity::{Positive, Negative};
//! # let output =
//! [
//! Sentiment { polarity: Positive, score: 0.998 },
//! Sentiment { polarity: Negative, score: 0.992 },
//! Sentiment { polarity: Positive, score: 0.999 }
//! Sentiment {
//! polarity: Positive,
//! score: 0.998,
//! },
//! Sentiment {
//! polarity: Negative,
//! score: 0.992,
//! },
//! Sentiment {
//! polarity: Positive,
//! score: 0.999,
//! },
//! ]
//!# ;
//! # ;
//! ```
use std::path::PathBuf;
use std::fs;
use crate::pipelines::sequence_classification::{
SequenceClassificationConfig, SequenceClassificationModel,
};
use serde::Deserialize;
use std::error::Error;
use crate::pipelines::sequence_classification::{SequenceClassificationConfig, SequenceClassificationModel};
use std::fs;
use std::path::PathBuf;
#[derive(Debug, PartialEq)]
/// Enum with the possible sentiment polarities. Note that the pre-trained SST2 model does not include neutral sentiment.
@ -71,7 +82,7 @@ type SentimentConfig = SequenceClassificationConfig;
/// # SentimentClassifier to perform sentiment analysis
pub struct SentimentModel {
sequence_classification_model: SequenceClassificationModel
sequence_classification_model: SequenceClassificationModel,
}
impl SentimentModel {
@ -84,17 +95,18 @@ impl SentimentModel {
/// # Example
///
/// ```no_run
///# fn main() -> failure::Fallible<()> {
/// # fn main() -> failure::Fallible<()> {
/// use rust_bert::pipelines::sentiment::SentimentModel;
///
/// let sentiment_model = SentimentModel::new(Default::default())?;
///# Ok(())
///# }
/// let sentiment_model = SentimentModel::new(Default::default())?;
/// # Ok(())
/// # }
/// ```
///
pub fn new(sentiment_config: SentimentConfig) -> failure::Fallible<SentimentModel> {
let sequence_classification_model = SequenceClassificationModel::new(sentiment_config)?;
Ok(SentimentModel { sequence_classification_model })
Ok(SentimentModel {
sequence_classification_model,
})
}
/// Extract sentiment form an array of text inputs
@ -109,7 +121,7 @@ impl SentimentModel {
/// # Example
///
/// ```no_run
///# fn main() -> failure::Fallible<()> {
/// # fn main() -> failure::Fallible<()> {
/// use rust_bert::pipelines::sentiment::SentimentModel;
///
/// let sentiment_classifier = SentimentModel::new(Default::default())?;
@ -121,17 +133,23 @@ impl SentimentModel {
/// ];
///
/// let output = sentiment_classifier.predict(&input);
///# Ok(())
///# }
/// # Ok(())
/// # }
/// ```
///
pub fn predict(&self, input: &[&str]) -> Vec<Sentiment> {
let labels = self.sequence_classification_model.predict(input);
let mut sentiments = Vec::with_capacity(labels.len());
for label in labels {
let polarity = if label.id == 1 { SentimentPolarity::Positive } else { SentimentPolarity::Negative };
sentiments.push(Sentiment { polarity, score: label.score })
};
let polarity = if label.id == 1 {
SentimentPolarity::Positive
} else {
SentimentPolarity::Negative
};
sentiments.push(Sentiment {
polarity,
score: label.score,
})
}
sentiments
}
}
@ -154,4 +172,4 @@ pub fn ss2_processor(file_path: PathBuf) -> Result<Vec<String>, Box<dyn Error>>
records.push(record.sentence);
}
Ok(records)
}
}

View File

@ -19,7 +19,7 @@
//! use rust_bert::distilbert::{DistilBertModelResources, DistilBertVocabResources, DistilBertConfigResources};
//! use rust_bert::pipelines::sequence_classification::SequenceClassificationModel;
//! use rust_bert::pipelines::common::ModelType;
//!# fn main() -> failure::Fallible<()> {
//! # fn main() -> failure::Fallible<()> {
//!
//! //Load a configuration
//! let config = SequenceClassificationConfig::new(ModelType::DistilBert,
@ -39,34 +39,38 @@
//! "If you like original gut wrenching laughter you will like this movie. If you are young or old then you will love this movie, hell even my mom liked it.",
//! ];
//! let output = sequence_classification_model.predict(&input);
//!# Ok(())
//!# }
//! # Ok(())
//! # }
//! ```
//! (Example courtesy of [IMDb](http://www.imdb.com))
//!
//! Output: \
//! ```no_run
//!# use rust_bert::pipelines::sequence_classification::Label;
//! # use rust_bert::pipelines::sequence_classification::Label;
//! let output =
//! [
//! Label { text: String::from("POSITIVE"), score: 0.9986, id: 1, sentence: 0},
//! Label { text: String::from("NEGATIVE"), score: 0.9985, id: 0, sentence: 1},
//! Label { text: String::from("POSITIVE"), score: 0.9988, id: 1, sentence: 12},
//! ]
//!# ;
//! # ;
//! ```
//!
use tch::nn::VarStore;
use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::{TokenizedInput, TruncationStrategy};
use std::collections::HashMap;
use tch::{Tensor, no_grad, Device, Kind};
use crate::bert::BertForSequenceClassification;
use crate::common::resources::{download_resource, RemoteResource, Resource};
use crate::distilbert::{
DistilBertConfigResources, DistilBertModelClassifier, DistilBertModelResources,
DistilBertVocabResources,
};
use crate::pipelines::common::{ConfigOption, ModelType, TokenizerOption};
use crate::roberta::RobertaForSequenceClassification;
use crate::distilbert::{DistilBertModelResources, DistilBertConfigResources, DistilBertVocabResources, DistilBertModelClassifier};
use crate::common::resources::{Resource, RemoteResource, download_resource};
use serde::{Serialize, Deserialize};
use crate::pipelines::common::{ModelType, ConfigOption, TokenizerOption};
use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::{
TokenizedInput, TruncationStrategy,
};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use tch::nn::VarStore;
use tch::{no_grad, Device, Kind, Tensor};
#[derive(Debug, Serialize, Deserialize)]
/// # Label generated by a `SequenceClassificationModel`
@ -112,8 +116,14 @@ impl SequenceClassificationConfig {
/// * vocab - The `Resource' pointing to the tokenizer's vocabulary to load (e.g. vocab.txt/vocab.json)
/// * vocab - An optional `Resource` tuple (`Option<Resource>`) pointing to the tokenizer's merge file to load (e.g. merges.txt), needed only for Roberta.
/// * lower_case - A `bool' indicating whether the tokeniser should lower case all input (in case of a lower-cased model)
///
pub fn new(model_type: ModelType, model_resource: Resource, config_resource: Resource, vocab_resource: Resource, merges_resource: Option<Resource>, lower_case: bool) -> SequenceClassificationConfig {
pub fn new(
model_type: ModelType,
model_resource: Resource,
config_resource: Resource,
vocab_resource: Resource,
merges_resource: Option<Resource>,
lower_case: bool,
) -> SequenceClassificationConfig {
SequenceClassificationConfig {
model_type,
model_resource,
@ -131,9 +141,15 @@ impl Default for SequenceClassificationConfig {
fn default() -> SequenceClassificationConfig {
SequenceClassificationConfig {
model_type: ModelType::DistilBert,
model_resource: Resource::Remote(RemoteResource::from_pretrained(DistilBertModelResources::DISTIL_BERT_SST2)),
config_resource: Resource::Remote(RemoteResource::from_pretrained(DistilBertConfigResources::DISTIL_BERT_SST2)),
vocab_resource: Resource::Remote(RemoteResource::from_pretrained(DistilBertVocabResources::DISTIL_BERT_SST2)),
model_resource: Resource::Remote(RemoteResource::from_pretrained(
DistilBertModelResources::DISTIL_BERT_SST2,
)),
config_resource: Resource::Remote(RemoteResource::from_pretrained(
DistilBertConfigResources::DISTIL_BERT_SST2,
)),
vocab_resource: Resource::Remote(RemoteResource::from_pretrained(
DistilBertVocabResources::DISTIL_BERT_SST2,
)),
merges_resource: None,
lower_case: true,
device: Device::cuda_if_available(),
@ -160,31 +176,38 @@ impl SequenceClassificationOption {
/// * `p` - `tch::nn::Path` path to the model file to load (e.g. model.ot)
/// * `config` - A configuration (the model type of the configuration must be compatible with the value for
/// `model_type`)
///
pub fn new(model_type: ModelType, p: &tch::nn::Path, config: &ConfigOption) -> Self {
match model_type {
ModelType::Bert => {
if let ConfigOption::Bert(config) = config {
SequenceClassificationOption::Bert(BertForSequenceClassification::new(p, config))
SequenceClassificationOption::Bert(BertForSequenceClassification::new(
p, config,
))
} else {
panic!("You can only supply a BertConfig for Bert!");
}
}
ModelType::DistilBert => {
if let ConfigOption::DistilBert(config) = config {
SequenceClassificationOption::DistilBert(DistilBertModelClassifier::new(p, config))
SequenceClassificationOption::DistilBert(DistilBertModelClassifier::new(
p, config,
))
} else {
panic!("You can only supply a DistilBertConfig for DistilBert!");
}
}
ModelType::Roberta => {
if let ConfigOption::Bert(config) = config {
SequenceClassificationOption::Roberta(RobertaForSequenceClassification::new(p, config))
SequenceClassificationOption::Roberta(RobertaForSequenceClassification::new(
p, config,
))
} else {
panic!("You can only supply a BertConfig for Roberta!");
}
}
ModelType::Electra => { panic!("SequenceClassification not implemented for Electra!"); }
ModelType::Electra => {
panic!("SequenceClassification not implemented for Electra!");
}
}
}
@ -193,27 +216,44 @@ impl SequenceClassificationOption {
match *self {
Self::Bert(_) => ModelType::Bert,
Self::Roberta(_) => ModelType::Roberta,
Self::DistilBert(_) => ModelType::DistilBert
Self::DistilBert(_) => ModelType::DistilBert,
}
}
/// Interface method to forward_t() of the particular models.
pub fn forward_t(&self,
input_ids: Option<Tensor>,
mask: Option<Tensor>,
token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
pub fn forward_t(
&self,
input_ids: Option<Tensor>,
mask: Option<Tensor>,
token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool,
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
match *self {
Self::Bert(ref model) => model.forward_t(input_ids, mask, token_type_ids, position_ids, input_embeds, train),
Self::DistilBert(ref model) => model.forward_t(input_ids, mask, input_embeds, train).expect("Error in distilbert forward_t"),
Self::Roberta(ref model) => model.forward_t(input_ids, mask, token_type_ids, position_ids, input_embeds, train),
Self::Bert(ref model) => model.forward_t(
input_ids,
mask,
token_type_ids,
position_ids,
input_embeds,
train,
),
Self::DistilBert(ref model) => model
.forward_t(input_ids, mask, input_embeds, train)
.expect("Error in distilbert forward_t"),
Self::Roberta(ref model) => model.forward_t(
input_ids,
mask,
token_type_ids,
position_ids,
input_embeds,
train,
),
}
}
}
/// # SequenceClassificationModel for Classification (e.g. Sentiment Analysis)
pub struct SequenceClassificationModel {
tokenizer: TokenizerOption,
@ -232,15 +272,16 @@ impl SequenceClassificationModel {
/// # Example
///
/// ```no_run
///# fn main() -> failure::Fallible<()> {
/// # fn main() -> failure::Fallible<()> {
/// use rust_bert::pipelines::sequence_classification::SequenceClassificationModel;
///
/// let model = SequenceClassificationModel::new(Default::default())?;
///# Ok(())
///# }
/// # Ok(())
/// # }
/// ```
///
pub fn new(config: SequenceClassificationConfig) -> failure::Fallible<SequenceClassificationModel> {
pub fn new(
config: SequenceClassificationConfig,
) -> failure::Fallible<SequenceClassificationModel> {
let config_path = download_resource(&config.config_resource)?;
let vocab_path = download_resource(&config.vocab_resource)?;
let weights_path = download_resource(&config.model_resource)?;
@ -251,31 +292,44 @@ impl SequenceClassificationModel {
};
let device = config.device;
let tokenizer = TokenizerOption::from_file(config.model_type, vocab_path.to_str().unwrap(), merges_path.map(|path| path.to_str().unwrap()), config.lower_case);
let tokenizer = TokenizerOption::from_file(
config.model_type,
vocab_path.to_str().unwrap(),
merges_path.map(|path| path.to_str().unwrap()),
config.lower_case,
);
let mut var_store = VarStore::new(device);
let model_config = ConfigOption::from_file(config.model_type, config_path);
let sequence_classifier = SequenceClassificationOption::new(config.model_type, &var_store.root(), &model_config);
let sequence_classifier =
SequenceClassificationOption::new(config.model_type, &var_store.root(), &model_config);
let label_mapping = model_config.get_label_mapping();
var_store.load(weights_path)?;
Ok(SequenceClassificationModel { tokenizer, sequence_classifier, label_mapping, var_store })
Ok(SequenceClassificationModel {
tokenizer,
sequence_classifier,
label_mapping,
var_store,
})
}
fn prepare_for_model(&self, input: Vec<&str>) -> Tensor {
let tokenized_input: Vec<TokenizedInput> = self.tokenizer.encode_list(input.to_vec(),
128,
&TruncationStrategy::LongestFirst,
0);
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
let tokenized_input_tensors: Vec<tch::Tensor> = tokenized_input.
iter().
map(|input| input.token_ids.clone()).
map(|mut input| {
let tokenized_input: Vec<TokenizedInput> =
self.tokenizer
.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let tokenized_input_tensors: Vec<tch::Tensor> = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
}).
map(|input|
Tensor::of_slice(&(input))).
collect::<Vec<_>>();
})
.map(|input| Tensor::of_slice(&(input)))
.collect::<Vec<_>>();
Tensor::stack(tokenized_input_tensors.as_slice(), 0).to(self.var_store.device())
}
@ -292,8 +346,8 @@ impl SequenceClassificationModel {
/// # Example
///
/// ```no_run
///# fn main() -> failure::Fallible<()> {
///# use rust_bert::pipelines::sequence_classification::SequenceClassificationModel;
/// # fn main() -> failure::Fallible<()> {
/// # use rust_bert::pipelines::sequence_classification::SequenceClassificationModel;
///
/// let sequence_classification_model = SequenceClassificationModel::new(Default::default())?;
/// let input = [
@ -302,29 +356,36 @@ impl SequenceClassificationModel {
/// "If you like original gut wrenching laughter you will like this movie. If you are young or old then you will love this movie, hell even my mom liked it.",
/// ];
/// let output = sequence_classification_model.predict(&input);
///# Ok(())
///# }
/// # Ok(())
/// # }
/// ```
pub fn predict(&self, input: &[&str]) -> Vec<Label> {
let input_tensor = self.prepare_for_model(input.to_vec());
let output = no_grad(|| {
let (output, _, _) = self.sequence_classifier
.forward_t(Some(input_tensor.copy()),
None,
None,
None,
None,
false);
let (output, _, _) = self.sequence_classifier.forward_t(
Some(input_tensor.copy()),
None,
None,
None,
None,
false,
);
output.softmax(-1, Kind::Float).detach().to(Device::Cpu)
});
let label_indices = output.as_ref().argmax(-1, true).squeeze1(1);
let scores = output.gather(1, &label_indices.unsqueeze(-1), false).squeeze1(1);
let scores = output
.gather(1, &label_indices.unsqueeze(-1), false)
.squeeze1(1);
let label_indices = label_indices.iter::<i64>().unwrap().collect::<Vec<i64>>();
let scores = scores.iter::<f64>().unwrap().collect::<Vec<f64>>();
let mut labels: Vec<Label> = vec!();
let mut labels: Vec<Label> = vec![];
for sentence_idx in 0..label_indices.len() {
let label_string = self.label_mapping.get(&label_indices[sentence_idx]).unwrap().clone();
let label_string = self
.label_mapping
.get(&label_indices[sentence_idx])
.unwrap()
.clone();
let label = Label {
text: label_string,
score: scores[sentence_idx],
@ -350,8 +411,8 @@ impl SequenceClassificationModel {
/// # Example
///
/// ```no_run
///# fn main() -> failure::Fallible<()> {
///# use rust_bert::pipelines::sequence_classification::SequenceClassificationModel;
/// # fn main() -> failure::Fallible<()> {
/// # use rust_bert::pipelines::sequence_classification::SequenceClassificationModel;
///
/// let sequence_classification_model = SequenceClassificationModel::new(Default::default())?;
/// let input = [
@ -360,34 +421,37 @@ impl SequenceClassificationModel {
/// "If you like original gut wrenching laughter you will like this movie. If you are young or old then you will love this movie, hell even my mom liked it.",
/// ];
/// let output = sequence_classification_model.predict_multilabel(&input, 0.5);
///# Ok(())
///# }
/// # Ok(())
/// # }
/// ```
pub fn predict_multilabel(&self, input: &[&str], threshold: f64) -> Vec<Vec<Label>> {
let input_tensor = self.prepare_for_model(input.to_vec());
let output = no_grad(|| {
let (output, _, _) = self.sequence_classifier
.forward_t(Some(input_tensor.copy()),
None,
None,
None,
None,
false);
let (output, _, _) = self.sequence_classifier.forward_t(
Some(input_tensor.copy()),
None,
None,
None,
None,
false,
);
output.sigmoid().detach().to(Device::Cpu)
});
let label_indices = output.as_ref().ge(threshold).nonzero();
let mut labels: Vec<Vec<Label>> = vec!();
let mut sequence_labels: Vec<Label> = vec!();
let mut labels: Vec<Vec<Label>> = vec![];
let mut sequence_labels: Vec<Label> = vec![];
for sentence_idx in 0..label_indices.size()[0] {
let label_index_tensor = label_indices.get(sentence_idx);
let sentence_label = label_index_tensor.iter::<i64>().unwrap().collect::<Vec<i64>>();
let sentence_label = label_index_tensor
.iter::<i64>()
.unwrap()
.collect::<Vec<i64>>();
let (sentence, id) = (sentence_label[0], sentence_label[1]);
if sentence as usize > labels.len() {
labels.push(sequence_labels);
sequence_labels = vec!();
sequence_labels = vec![];
}
let score = output.double_value(sentence_label.as_slice());
let label_string = self.label_mapping.get(&id).unwrap().to_owned();
@ -398,7 +462,6 @@ impl SequenceClassificationModel {
sentence: sentence as usize,
};
sequence_labels.push(label);
}
if sequence_labels.len() > 0 {
labels.push(sequence_labels);

View File

@ -11,7 +11,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
//! # Summarization pipeline
//! Abstractive summarization of texts based on the BART encoder-decoder architecture
//! Include techniques such as beam search, top-k and nucleus sampling, temperature setting and repetition penalty.
@ -21,52 +20,54 @@
//!
//!
//! ```no_run
//!# fn main() -> failure::Fallible<()> {
//!# use rust_bert::pipelines::generation::LanguageGenerator;
//! # fn main() -> failure::Fallible<()> {
//! # use rust_bert::pipelines::generation::LanguageGenerator;
//! use rust_bert::pipelines::summarization::SummarizationModel;
//! let mut model = SummarizationModel::new(Default::default())?;
//!
//! let input = ["In findings published Tuesday in Cornell University's arXiv by a team of scientists
//!from the University of Montreal and a separate report published Wednesday in Nature Astronomy by a team
//!from University College London (UCL), the presence of water vapour was confirmed in the atmosphere of K2-18b,
//!a planet circling a star in the constellation Leo. This is the first such discovery in a planet in its star's
//!habitable zone — not too hot and not too cold for liquid water to exist. The Montreal team, led by Björn Benneke,
//!used data from the NASA's Hubble telescope to assess changes in the light coming from K2-18b's star as the planet
//!passed between it and Earth. They found that certain wavelengths of light, which are usually absorbed by water,
//!weakened when the planet was in the way, indicating not only does K2-18b have an atmosphere, but the atmosphere
//!contains water in vapour form. The team from UCL then analyzed the Montreal team's data using their own software
//!and confirmed their conclusion. This was not the first time scientists have found signs of water on an exoplanet,
//!but previous discoveries were made on planets with high temperatures or other pronounced differences from Earth.
//!\"This is the first potentially habitable planet where the temperature is right and where we now know there is water,\"
//!said UCL astronomer Angelos Tsiaras. \"It's the best candidate for habitability right now.\" \"It's a good sign\",
//!said Ryan Cloutier of the HarvardSmithsonian Center for Astrophysics, who was not one of either study's authors.
//!\"Overall,\" he continued, \"the presence of water in its atmosphere certainly improves the prospect of K2-18b being
//!a potentially habitable planet, but further observations will be required to say for sure. \"
//!K2-18b was first identified in 2015 by the Kepler space telescope. It is about 110 light-years from Earth and larger
//!but less dense. Its star, a red dwarf, is cooler than the Sun, but the planet's orbit is much closer, such that a year
//!on K2-18b lasts 33 Earth days. According to The Guardian, astronomers were optimistic that NASA's James Webb space
//!telescope — scheduled for launch in 2021 — and the European Space Agency's 2028 ARIEL program, could reveal more
//!about exoplanets like K2-18b."];
//! from the University of Montreal and a separate report published Wednesday in Nature Astronomy by a team
//! from University College London (UCL), the presence of water vapour was confirmed in the atmosphere of K2-18b,
//! a planet circling a star in the constellation Leo. This is the first such discovery in a planet in its star's
//! habitable zone — not too hot and not too cold for liquid water to exist. The Montreal team, led by Björn Benneke,
//! used data from the NASA's Hubble telescope to assess changes in the light coming from K2-18b's star as the planet
//! passed between it and Earth. They found that certain wavelengths of light, which are usually absorbed by water,
//! weakened when the planet was in the way, indicating not only does K2-18b have an atmosphere, but the atmosphere
//! contains water in vapour form. The team from UCL then analyzed the Montreal team's data using their own software
//! and confirmed their conclusion. This was not the first time scientists have found signs of water on an exoplanet,
//! but previous discoveries were made on planets with high temperatures or other pronounced differences from Earth.
//! \"This is the first potentially habitable planet where the temperature is right and where we now know there is water,\"
//! said UCL astronomer Angelos Tsiaras. \"It's the best candidate for habitability right now.\" \"It's a good sign\",
//! said Ryan Cloutier of the HarvardSmithsonian Center for Astrophysics, who was not one of either study's authors.
//! \"Overall,\" he continued, \"the presence of water in its atmosphere certainly improves the prospect of K2-18b being
//! a potentially habitable planet, but further observations will be required to say for sure. \"
//! K2-18b was first identified in 2015 by the Kepler space telescope. It is about 110 light-years from Earth and larger
//! but less dense. Its star, a red dwarf, is cooler than the Sun, but the planet's orbit is much closer, such that a year
//! on K2-18b lasts 33 Earth days. According to The Guardian, astronomers were optimistic that NASA's James Webb space
//! telescope — scheduled for launch in 2021 — and the European Space Agency's 2028 ARIEL program, could reveal more
//! about exoplanets like K2-18b."];
//!
//! let output = model.summarize(&input);
//!# Ok(())
//!# }
//! # Ok(())
//! # }
//! ```
//! (New sample credits: [WikiNews](https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b))
//!
//! Example output: \
//! ```no_run
//!# let output =
//! # let output =
//! "Scientists have found water vapour on K2-18b, a planet 110 light-years from Earth.
//! This is the first such discovery in a planet in its star's habitable zone.
//! The planet is not too hot and not too cold for liquid water to exist."
//!# ;
//!```
//! # ;
//! ```
use crate::bart::{
BartConfigResources, BartMergesResources, BartModelResources, BartVocabResources,
};
use crate::common::resources::{RemoteResource, Resource};
use crate::pipelines::generation::{BartGenerator, GenerateConfig, LanguageGenerator};
use tch::Device;
use crate::common::resources::{Resource, RemoteResource};
use crate::bart::{BartModelResources, BartConfigResources, BartVocabResources, BartMergesResources};
/// # Configuration for text summarization
/// Contains information regarding the model to load, mirrors the GenerationConfig, with a
@ -111,10 +112,18 @@ pub struct SummarizationConfig {
impl Default for SummarizationConfig {
fn default() -> SummarizationConfig {
SummarizationConfig {
model_resource: Resource::Remote(RemoteResource::from_pretrained(BartModelResources::BART_CNN)),
config_resource: Resource::Remote(RemoteResource::from_pretrained(BartConfigResources::BART_CNN)),
vocab_resource: Resource::Remote(RemoteResource::from_pretrained(BartVocabResources::BART_CNN)),
merges_resource: Resource::Remote(RemoteResource::from_pretrained(BartMergesResources::BART_CNN)),
model_resource: Resource::Remote(RemoteResource::from_pretrained(
BartModelResources::BART_CNN,
)),
config_resource: Resource::Remote(RemoteResource::from_pretrained(
BartConfigResources::BART_CNN,
)),
vocab_resource: Resource::Remote(RemoteResource::from_pretrained(
BartVocabResources::BART_CNN,
)),
merges_resource: Resource::Remote(RemoteResource::from_pretrained(
BartMergesResources::BART_CNN,
)),
min_length: 56,
max_length: 142,
do_sample: false,
@ -134,7 +143,7 @@ impl Default for SummarizationConfig {
/// # SummarizationModel to perform summarization
pub struct SummarizationModel {
model: BartGenerator
model: BartGenerator,
}
impl SummarizationModel {
@ -147,16 +156,14 @@ impl SummarizationModel {
/// # Example
///
/// ```no_run
///# fn main() -> failure::Fallible<()> {
/// # fn main() -> failure::Fallible<()> {
/// use rust_bert::pipelines::summarization::SummarizationModel;
///
/// let mut summarization_model = SummarizationModel::new(Default::default())?;
///# Ok(())
///# }
/// let mut summarization_model = SummarizationModel::new(Default::default())?;
/// # Ok(())
/// # }
/// ```
///
pub fn new(summarization_config: SummarizationConfig)
-> failure::Fallible<SummarizationModel> {
pub fn new(summarization_config: SummarizationConfig) -> failure::Fallible<SummarizationModel> {
let generate_config = GenerateConfig {
model_resource: summarization_config.model_resource,
config_resource: summarization_config.config_resource,
@ -194,40 +201,39 @@ impl SummarizationModel {
/// # Example
///
/// ```no_run
///# fn main() -> failure::Fallible<()> {
/// # fn main() -> failure::Fallible<()> {
/// use rust_bert::pipelines::generation::LanguageGenerator;
/// use rust_bert::pipelines::summarization::SummarizationModel;
/// let model = SummarizationModel::new(Default::default())?;
///
/// let input = ["In findings published Tuesday in Cornell University's arXiv by a team of scientists
///from the University of Montreal and a separate report published Wednesday in Nature Astronomy by a team
///from University College London (UCL), the presence of water vapour was confirmed in the atmosphere of K2-18b,
///a planet circling a star in the constellation Leo. This is the first such discovery in a planet in its star's
///habitable zone — not too hot and not too cold for liquid water to exist. The Montreal team, led by Björn Benneke,
///used data from the NASA's Hubble telescope to assess changes in the light coming from K2-18b's star as the planet
///passed between it and Earth. They found that certain wavelengths of light, which are usually absorbed by water,
///weakened when the planet was in the way, indicating not only does K2-18b have an atmosphere, but the atmosphere
///contains water in vapour form. The team from UCL then analyzed the Montreal team's data using their own software
///and confirmed their conclusion. This was not the first time scientists have found signs of water on an exoplanet,
///but previous discoveries were made on planets with high temperatures or other pronounced differences from Earth.
///\"This is the first potentially habitable planet where the temperature is right and where we now know there is water,\"
///said UCL astronomer Angelos Tsiaras. \"It's the best candidate for habitability right now.\" \"It's a good sign\",
///said Ryan Cloutier of the HarvardSmithsonian Center for Astrophysics, who was not one of either study's authors.
///\"Overall,\" he continued, \"the presence of water in its atmosphere certainly improves the prospect of K2-18b being
///a potentially habitable planet, but further observations will be required to say for sure. \"
///K2-18b was first identified in 2015 by the Kepler space telescope. It is about 110 light-years from Earth and larger
///but less dense. Its star, a red dwarf, is cooler than the Sun, but the planet's orbit is much closer, such that a year
///on K2-18b lasts 33 Earth days. According to The Guardian, astronomers were optimistic that NASA's James Webb space
///telescope — scheduled for launch in 2021 — and the European Space Agency's 2028 ARIEL program, could reveal more
///about exoplanets like K2-18b."];
/// from the University of Montreal and a separate report published Wednesday in Nature Astronomy by a team
/// from University College London (UCL), the presence of water vapour was confirmed in the atmosphere of K2-18b,
/// a planet circling a star in the constellation Leo. This is the first such discovery in a planet in its star's
/// habitable zone — not too hot and not too cold for liquid water to exist. The Montreal team, led by Björn Benneke,
/// used data from the NASA's Hubble telescope to assess changes in the light coming from K2-18b's star as the planet
/// passed between it and Earth. They found that certain wavelengths of light, which are usually absorbed by water,
/// weakened when the planet was in the way, indicating not only does K2-18b have an atmosphere, but the atmosphere
/// contains water in vapour form. The team from UCL then analyzed the Montreal team's data using their own software
/// and confirmed their conclusion. This was not the first time scientists have found signs of water on an exoplanet,
/// but previous discoveries were made on planets with high temperatures or other pronounced differences from Earth.
/// \"This is the first potentially habitable planet where the temperature is right and where we now know there is water,\"
/// said UCL astronomer Angelos Tsiaras. \"It's the best candidate for habitability right now.\" \"It's a good sign\",
/// said Ryan Cloutier of the HarvardSmithsonian Center for Astrophysics, who was not one of either study's authors.
/// \"Overall,\" he continued, \"the presence of water in its atmosphere certainly improves the prospect of K2-18b being
/// a potentially habitable planet, but further observations will be required to say for sure. \"
/// K2-18b was first identified in 2015 by the Kepler space telescope. It is about 110 light-years from Earth and larger
/// but less dense. Its star, a red dwarf, is cooler than the Sun, but the planet's orbit is much closer, such that a year
/// on K2-18b lasts 33 Earth days. According to The Guardian, astronomers were optimistic that NASA's James Webb space
/// telescope — scheduled for launch in 2021 — and the European Space Agency's 2028 ARIEL program, could reveal more
/// about exoplanets like K2-18b."];
///
/// let output = model.summarize(&input);
///# Ok(())
///# }
/// # Ok(())
/// # }
/// ```
/// (New sample credits: [WikiNews](https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b))
///
pub fn summarize(&self, texts: &[&str]) -> Vec<String> {
self.model.generate(Some(texts.to_vec()), None)
}
}
}

View File

@ -19,7 +19,7 @@
//! use rust_bert::resources::{Resource,RemoteResource};
//! use rust_bert::bert::{BertModelResources, BertVocabResources, BertConfigResources};
//! use rust_bert::pipelines::common::ModelType;
//!# fn main() -> failure::Fallible<()> {
//! # fn main() -> failure::Fallible<()> {
//!
//! //Load a configuration
//! use rust_bert::pipelines::token_classification::LabelAggregationOption;
@ -40,41 +40,94 @@
//! "Paris is a city in France."
//! ];
//! let output = token_classification_model.predict(&input, true, true); //ignore_first_label = true (only returns the NER parts, ignoring first label O)
//!# Ok(())
//!# }
//! # Ok(())
//! # }
//! ```
//! Output: \
//! ```no_run
//!# use rust_bert::pipelines::token_classification::Token;
//! # use rust_bert::pipelines::token_classification::Token;
//! use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::Mask::Special;
//! use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::{Offset, Mask};
//!# let output =
//! use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::{Mask, Offset};
//! # let output =
//! [
//! Token { text: String::from("[CLS]"), score: 0.9995001554489136, label: String::from("O"), label_index: 0, sentence: 0, index: 0, word_index: 0, offset: None, mask: Special },
//! Token { text: String::from("My"), score: 0.9980450868606567, label: String::from("O"), label_index: 0, sentence: 0, index: 1, word_index: 1, offset: Some(Offset { begin: 0, end: 2 }), mask: Mask::None },
//! Token { text: String::from("name"), score: 0.9995062351226807, label: String::from("O"), label_index: 0, sentence: 0, index: 2, word_index: 2, offset: Some(Offset { begin: 3, end: 7 }), mask: Mask::None },
//! Token { text: String::from("is"), score: 0.9997343420982361, label: String::from("O"), label_index: 0, sentence: 0, index: 3, word_index: 3, offset: Some(Offset { begin: 8, end: 10 }), mask: Mask::None },
//! Token { text: String::from("Amélie"), score: 0.9913727683112525, label: String::from("I-PER"), label_index: 4, sentence: 0, index: 4, word_index: 4, offset: Some(Offset { begin: 11, end: 17 }), mask: Mask::None }
//! // ...
//! Token {
//! text: String::from("[CLS]"),
//! score: 0.9995001554489136,
//! label: String::from("O"),
//! label_index: 0,
//! sentence: 0,
//! index: 0,
//! word_index: 0,
//! offset: None,
//! mask: Special,
//! },
//! Token {
//! text: String::from("My"),
//! score: 0.9980450868606567,
//! label: String::from("O"),
//! label_index: 0,
//! sentence: 0,
//! index: 1,
//! word_index: 1,
//! offset: Some(Offset { begin: 0, end: 2 }),
//! mask: Mask::None,
//! },
//! Token {
//! text: String::from("name"),
//! score: 0.9995062351226807,
//! label: String::from("O"),
//! label_index: 0,
//! sentence: 0,
//! index: 2,
//! word_index: 2,
//! offset: Some(Offset { begin: 3, end: 7 }),
//! mask: Mask::None,
//! },
//! Token {
//! text: String::from("is"),
//! score: 0.9997343420982361,
//! label: String::from("O"),
//! label_index: 0,
//! sentence: 0,
//! index: 3,
//! word_index: 3,
//! offset: Some(Offset { begin: 8, end: 10 }),
//! mask: Mask::None,
//! },
//! Token {
//! text: String::from("Amélie"),
//! score: 0.9913727683112525,
//! label: String::from("I-PER"),
//! label_index: 4,
//! sentence: 0,
//! index: 4,
//! word_index: 4,
//! offset: Some(Offset { begin: 11, end: 17 }),
//! mask: Mask::None,
//! }, // ...
//! ]
//!# ;
//! # ;
//! ```
use tch::nn::VarStore;
use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::{Tokenizer, TokenizedInput, TruncationStrategy, Mask, Offset, ConsolidatableTokens, ConsolidatedTokenIterator, TokenTrait};
use std::collections::HashMap;
use tch::{Tensor, no_grad, Device};
use tch::kind::Kind::Float;
use crate::bert::{BertForTokenClassification, BertModelResources, BertConfigResources, BertVocabResources};
use crate::roberta::RobertaForTokenClassification;
use crate::bert::{
BertConfigResources, BertForTokenClassification, BertModelResources, BertVocabResources,
};
use crate::common::resources::{download_resource, RemoteResource, Resource};
use crate::distilbert::DistilBertForTokenClassification;
use crate::common::resources::{Resource, RemoteResource, download_resource};
use crate::pipelines::common::{ModelType, ConfigOption, TokenizerOption};
use crate::electra::ElectraForTokenClassification;
use crate::pipelines::common::{ConfigOption, ModelType, TokenizerOption};
use crate::roberta::RobertaForTokenClassification;
use itertools::Itertools;
use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::{
ConsolidatableTokens, ConsolidatedTokenIterator, Mask, Offset, TokenTrait, TokenizedInput,
Tokenizer, TruncationStrategy,
};
use serde::{Deserialize, Serialize};
use std::cmp::min;
use serde::{Serialize, Deserialize};
use std::collections::HashMap;
use tch::kind::Kind::Float;
use tch::nn::VarStore;
use tch::{no_grad, Device, Tensor};
#[derive(Debug, Clone, Serialize, Deserialize)]
/// # Token generated by a `TokenClassificationModel`
@ -140,7 +193,6 @@ pub enum LabelAggregationOption {
Custom(Box<dyn Fn(&[Token]) -> (i64, String)>),
}
/// # Configuration for TokenClassificationModel
/// Contains information regarding the model to load and device to place the model on.
pub struct TokenClassificationConfig {
@ -173,14 +225,15 @@ impl TokenClassificationConfig {
/// * vocab - The `Resource' pointing to the tokenizer's vocabulary to load (e.g. vocab.txt/vocab.json)
/// * vocab - An optional `Resource` tuple (`Option<Resource>`) pointing to the tokenizer's merge file to load (e.g. merges.txt), needed only for Roberta.
/// * lower_case - A `bool' indicating whether the tokenizer should lower case all input (in case of a lower-cased model)
///
pub fn new(model_type: ModelType,
model_resource: Resource,
config_resource: Resource,
vocab_resource: Resource,
merges_resource: Option<Resource>,
lower_case: bool,
label_aggregation_function: LabelAggregationOption) -> TokenClassificationConfig {
pub fn new(
model_type: ModelType,
model_resource: Resource,
config_resource: Resource,
vocab_resource: Resource,
merges_resource: Option<Resource>,
lower_case: bool,
label_aggregation_function: LabelAggregationOption,
) -> TokenClassificationConfig {
TokenClassificationConfig {
model_type,
model_resource,
@ -199,9 +252,15 @@ impl Default for TokenClassificationConfig {
fn default() -> TokenClassificationConfig {
TokenClassificationConfig {
model_type: ModelType::Bert,
model_resource: Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT_NER)),
config_resource: Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT_NER)),
vocab_resource: Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT_NER)),
model_resource: Resource::Remote(RemoteResource::from_pretrained(
BertModelResources::BERT_NER,
)),
config_resource: Resource::Remote(RemoteResource::from_pretrained(
BertConfigResources::BERT_NER,
)),
vocab_resource: Resource::Remote(RemoteResource::from_pretrained(
BertVocabResources::BERT_NER,
)),
merges_resource: None,
lower_case: false,
device: Device::cuda_if_available(),
@ -231,7 +290,6 @@ impl TokenClassificationOption {
/// * `p` - `tch::nn::Path` path to the model file to load (e.g. model.ot)
/// * `config` - A configuration (the model type of the configuration must be compatible with the value for
/// `model_type`)
///
pub fn new(model_type: ModelType, p: &tch::nn::Path, config: &ConfigOption) -> Self {
match model_type {
ModelType::Bert => {
@ -243,21 +301,27 @@ impl TokenClassificationOption {
}
ModelType::DistilBert => {
if let ConfigOption::DistilBert(config) = config {
TokenClassificationOption::DistilBert(DistilBertForTokenClassification::new(p, config))
TokenClassificationOption::DistilBert(DistilBertForTokenClassification::new(
p, config,
))
} else {
panic!("You can only supply a DistilBertConfig for DistilBert!");
}
}
ModelType::Roberta => {
if let ConfigOption::Bert(config) = config {
TokenClassificationOption::Roberta(RobertaForTokenClassification::new(p, config))
TokenClassificationOption::Roberta(RobertaForTokenClassification::new(
p, config,
))
} else {
panic!("You can only supply a BertConfig for Roberta!");
}
}
ModelType::Electra => {
if let ConfigOption::Electra(config) = config {
TokenClassificationOption::Electra(ElectraForTokenClassification::new(p, config))
TokenClassificationOption::Electra(ElectraForTokenClassification::new(
p, config,
))
} else {
panic!("You can only supply a BertConfig for Roberta!");
}
@ -271,27 +335,51 @@ impl TokenClassificationOption {
Self::Bert(_) => ModelType::Bert,
Self::Roberta(_) => ModelType::Roberta,
Self::DistilBert(_) => ModelType::DistilBert,
Self::Electra(_) => ModelType::Electra
Self::Electra(_) => ModelType::Electra,
}
}
fn forward_t(&self,
input_ids: Option<Tensor>,
mask: Option<Tensor>,
token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
fn forward_t(
&self,
input_ids: Option<Tensor>,
mask: Option<Tensor>,
token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool,
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
match *self {
Self::Bert(ref model) => model.forward_t(input_ids, mask, token_type_ids, position_ids, input_embeds, train),
Self::DistilBert(ref model) => model.forward_t(input_ids, mask, input_embeds, train).expect("Error in distilbert forward_t"),
Self::Roberta(ref model) => model.forward_t(input_ids, mask, token_type_ids, position_ids, input_embeds, train),
Self::Electra(ref model) => model.forward_t(input_ids, mask, token_type_ids, position_ids, input_embeds, train),
Self::Bert(ref model) => model.forward_t(
input_ids,
mask,
token_type_ids,
position_ids,
input_embeds,
train,
),
Self::DistilBert(ref model) => model
.forward_t(input_ids, mask, input_embeds, train)
.expect("Error in distilbert forward_t"),
Self::Roberta(ref model) => model.forward_t(
input_ids,
mask,
token_type_ids,
position_ids,
input_embeds,
train,
),
Self::Electra(ref model) => model.forward_t(
input_ids,
mask,
token_type_ids,
position_ids,
input_embeds,
train,
),
}
}
}
/// # TokenClassificationModel for Named Entity Recognition or Part-of-Speech tagging
pub struct TokenClassificationModel {
tokenizer: TokenizerOption,
@ -311,14 +399,13 @@ impl TokenClassificationModel {
/// # Example
///
/// ```no_run
///# fn main() -> failure::Fallible<()> {
/// # fn main() -> failure::Fallible<()> {
/// use rust_bert::pipelines::token_classification::TokenClassificationModel;
///
/// let model = TokenClassificationModel::new(Default::default())?;
///# Ok(())
///# }
/// # Ok(())
/// # }
/// ```
///
pub fn new(config: TokenClassificationConfig) -> failure::Fallible<TokenClassificationModel> {
let config_path = download_resource(&config.config_resource)?;
let vocab_path = download_resource(&config.vocab_resource)?;
@ -331,32 +418,49 @@ impl TokenClassificationModel {
let device = config.device;
let label_aggregation_function = config.label_aggregation_function;
let tokenizer = TokenizerOption::from_file(config.model_type, vocab_path.to_str().unwrap(), merges_path.map(|path| path.to_str().unwrap()), config.lower_case);
let tokenizer = TokenizerOption::from_file(
config.model_type,
vocab_path.to_str().unwrap(),
merges_path.map(|path| path.to_str().unwrap()),
config.lower_case,
);
let mut var_store = VarStore::new(device);
let model_config = ConfigOption::from_file(config.model_type, config_path);
let token_sequence_classifier = TokenClassificationOption::new(config.model_type, &var_store.root(), &model_config);
let token_sequence_classifier =
TokenClassificationOption::new(config.model_type, &var_store.root(), &model_config);
let label_mapping = model_config.get_label_mapping();
var_store.load(weights_path)?;
Ok(TokenClassificationModel { tokenizer, token_sequence_classifier, label_mapping, var_store, label_aggregation_function })
Ok(TokenClassificationModel {
tokenizer,
token_sequence_classifier,
label_mapping,
var_store,
label_aggregation_function,
})
}
fn prepare_for_model(&self, input: Vec<&str>) -> (Vec<TokenizedInput>, Tensor) {
let tokenized_input: Vec<TokenizedInput> = self.tokenizer.encode_list(input.to_vec(),
128,
&TruncationStrategy::LongestFirst,
0);
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
let tokenized_input_tensors: Vec<tch::Tensor> = tokenized_input.
iter().
map(|input| input.token_ids.clone()).
map(|mut input| {
let tokenized_input: Vec<TokenizedInput> =
self.tokenizer
.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let tokenized_input_tensors: Vec<tch::Tensor> = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
}).
map(|input|
Tensor::of_slice(&(input))).
collect::<Vec<_>>();
(tokenized_input, Tensor::stack(tokenized_input_tensors.as_slice(), 0).to(self.var_store.device()))
})
.map(|input| Tensor::of_slice(&(input)))
.collect::<Vec<_>>();
(
tokenized_input,
Tensor::stack(tokenized_input_tensors.as_slice(), 0).to(self.var_store.device()),
)
}
/// Classify tokens in a text sequence
@ -374,33 +478,39 @@ impl TokenClassificationModel {
/// # Example
///
/// ```no_run
///# fn main() -> failure::Fallible<()> {
///# use rust_bert::pipelines::token_classification::TokenClassificationModel;
/// # fn main() -> failure::Fallible<()> {
/// # use rust_bert::pipelines::token_classification::TokenClassificationModel;
///
/// let ner_model = TokenClassificationModel::new(Default::default())?;
/// let ner_model = TokenClassificationModel::new(Default::default())?;
/// let input = [
/// "My name is Amy. I live in Paris.",
/// "Paris is a city in France."
/// "Paris is a city in France.",
/// ];
/// let output = ner_model.predict(&input, true, true);
///# Ok(())
///# }
/// # Ok(())
/// # }
/// ```
pub fn predict(&self, input: &[&str], consolidate_sub_tokens: bool, return_special: bool) -> Vec<Token> {
pub fn predict(
&self,
input: &[&str],
consolidate_sub_tokens: bool,
return_special: bool,
) -> Vec<Token> {
let (tokenized_input, input_tensor) = self.prepare_for_model(input.to_vec());
let (output, _, _) = no_grad(|| {
self.token_sequence_classifier
.forward_t(Some(input_tensor.copy()),
None,
None,
None,
None,
false)
self.token_sequence_classifier.forward_t(
Some(input_tensor.copy()),
None,
None,
None,
None,
false,
)
});
let output = output.detach().to(Device::Cpu);
let score: Tensor = output.exp() / output.exp().sum1(&[-1], true, Float);
let labels_idx = &score.argmax(-1, true);
let mut tokens: Vec<Token> = vec!();
let mut tokens: Vec<Token> = vec![];
for sentence_idx in 0..labels_idx.size()[0] {
let labels = labels_idx.get(sentence_idx);
let sentence_tokens = &tokenized_input[sentence_idx as usize];
@ -415,7 +525,16 @@ impl TokenClassificationModel {
word_idx += 1;
}
let token = {
self.decode_token(&original_chars, sentence_tokens, &input_tensor, &labels, &score, sentence_idx, position_idx as i64, word_idx - 1)
self.decode_token(
&original_chars,
sentence_tokens,
&input_tensor,
&labels,
&score,
sentence_idx,
position_idx as i64,
word_idx - 1,
)
};
tokens.push(token);
}
@ -426,8 +545,17 @@ impl TokenClassificationModel {
tokens
}
fn decode_token(&self, original_sentence_chars: &Vec<char>, sentence_tokens: &TokenizedInput, input_tensor: &Tensor,
labels: &Tensor, score: &Tensor, sentence_idx: i64, position_idx: i64, word_index: u16) -> Token {
fn decode_token(
&self,
original_sentence_chars: &Vec<char>,
sentence_tokens: &TokenizedInput,
input_tensor: &Tensor,
labels: &Tensor,
score: &Tensor,
sentence_idx: i64,
position_idx: i64,
word_index: u16,
) -> Token {
let label_id = labels.int64_value(&[position_idx as i64]);
let token_id = input_tensor.int64_value(&[sentence_idx, position_idx as i64]);
@ -435,13 +563,19 @@ impl TokenClassificationModel {
let text = match offsets {
None => match self.tokenizer {
TokenizerOption::Bert(ref tokenizer) => Tokenizer::decode(tokenizer, vec!(token_id), false, false),
TokenizerOption::Roberta(ref tokenizer) => Tokenizer::decode(tokenizer, vec!(token_id), false, false),
TokenizerOption::Bert(ref tokenizer) => {
Tokenizer::decode(tokenizer, vec![token_id], false, false)
}
TokenizerOption::Roberta(ref tokenizer) => {
Tokenizer::decode(tokenizer, vec![token_id], false, false)
}
},
Some(offsets) => {
let (start_char, end_char) = (offsets.begin as usize, offsets.end as usize);
let end_char = min(end_char, original_sentence_chars.len());
let text = original_sentence_chars[start_char..end_char].iter().collect();
let text = original_sentence_chars[start_char..end_char]
.iter()
.collect();
text
}
};
@ -449,7 +583,11 @@ impl TokenClassificationModel {
Token {
text,
score: score.double_value(&[sentence_idx, position_idx, label_id]),
label: self.label_mapping.get(&label_id).expect("Index out of vocabulary bounds.").to_owned(),
label: self
.label_mapping
.get(&label_id)
.expect("Index out of vocabulary bounds.")
.to_owned(),
label_index: label_id,
sentence: sentence_idx as usize,
index: position_idx as u16,
@ -459,24 +597,29 @@ impl TokenClassificationModel {
}
}
fn consolidate_tokens(&self, tokens: &mut Vec<Token>, label_aggregation_function: &LabelAggregationOption) {
let mut tokens_to_replace = vec!();
fn consolidate_tokens(
&self,
tokens: &mut Vec<Token>,
label_aggregation_function: &LabelAggregationOption,
) {
let mut tokens_to_replace = vec![];
let mut token_iter = tokens.iter_consolidate_tokens();
let mut cursor = 0;
while let Some(sub_tokens) = token_iter.next() {
if sub_tokens.len() > 1 {
let (label_index, label) = self.consolidate_labels(sub_tokens, label_aggregation_function);
let (label_index, label) =
self.consolidate_labels(sub_tokens, label_aggregation_function);
let sentence = (&sub_tokens[0]).sentence;
let index = (&sub_tokens[0]).index;
let word_index = (&sub_tokens[0]).word_index;
let offset_start = match &sub_tokens.first().unwrap().offset {
Some(offset) => Some(offset.begin),
None => None
None => None,
};
let offset_end = match &sub_tokens.last().unwrap().offset {
Some(offset) => Some(offset.end),
None => None
None => None,
};
let offset = if offset_start.is_some() & offset_end.is_some() {
Some(Offset::new(offset_start.unwrap(), offset_end.unwrap()))
@ -513,7 +656,11 @@ impl TokenClassificationModel {
}
}
fn consolidate_labels(&self, tokens: &[Token], aggregation: &LabelAggregationOption) -> (i64, String) {
fn consolidate_labels(
&self,
tokens: &[Token],
aggregation: &LabelAggregationOption,
) -> (i64, String) {
match aggregation {
LabelAggregationOption::First => {
let token = tokens.first().unwrap();
@ -524,22 +671,17 @@ impl TokenClassificationModel {
(token.label_index, token.label.clone())
}
LabelAggregationOption::Mode => {
let counts = tokens
.iter()
.fold(
HashMap::new(),
|mut m, c| {
*m.entry((c.label_index, c.label.as_str())).or_insert(0) += 1;
m
},
);
let counts = tokens.iter().fold(HashMap::new(), |mut m, c| {
*m.entry((c.label_index, c.label.as_str())).or_insert(0) += 1;
m
});
counts
.into_iter()
.max_by(|a, b| a.1.cmp(&b.1))
.map(|((label_index, label), _)| (label_index, label.to_owned()))
.unwrap()
}
LabelAggregationOption::Custom(function) => function(tokens)
LabelAggregationOption::Custom(function) => function(tokens),
}
}
}

View File

@ -11,7 +11,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
//! # Translation pipeline
//! Translation based on the Marian encoder-decoder architecture
//! Include techniques such as beam search, top-k and nucleus sampling, temperature setting and repetition penalty.
@ -33,31 +32,35 @@
//!
//!
//! ```no_run
//!# fn main() -> failure::Fallible<()> {
//!# use rust_bert::pipelines::generation::LanguageGenerator;
//! use rust_bert::pipelines::translation::{TranslationModel, TranslationConfig, Language};
//! # fn main() -> failure::Fallible<()> {
//! # use rust_bert::pipelines::generation::LanguageGenerator;
//! use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
//! use tch::Device;
//! let translation_config = TranslationConfig::new(Language::EnglishToFrench, Device::cuda_if_available());
//! let translation_config =
//! TranslationConfig::new(Language::EnglishToFrench, Device::cuda_if_available());
//! let mut model = TranslationModel::new(translation_config)?;
//!
//! let input = ["This is a sentence to be translated"];
//!
//! let output = model.translate(&input);
//!# Ok(())
//!# }
//! # Ok(())
//! # }
//! ```
//!
//! Output: \
//! ```no_run
//!# let output =
//! # let output =
//! "Il s'agit d'une phrase à traduire"
//!# ;
//!```
//! # ;
//! ```
use crate::pipelines::generation::{MarianGenerator, GenerateConfig, LanguageGenerator};
use crate::common::resources::{RemoteResource, Resource};
use crate::marian::{
MarianConfigResources, MarianModelResources, MarianPrefix, MarianSpmResources,
MarianVocabResources,
};
use crate::pipelines::generation::{GenerateConfig, LanguageGenerator, MarianGenerator};
use tch::Device;
use crate::common::resources::{Resource, RemoteResource};
use crate::marian::{MarianModelResources, MarianConfigResources, MarianVocabResources, MarianSpmResources, MarianPrefix};
/// Pretrained languages available for direct use
pub enum Language {
@ -84,47 +87,244 @@ pub enum Language {
struct RemoteTranslationResources;
impl RemoteTranslationResources {
pub const ENGLISH2FRENCH: ((&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), Option<&'static str>) =
(MarianModelResources::ENGLISH2ROMANCE, MarianConfigResources::ENGLISH2ROMANCE, MarianVocabResources::ENGLISH2ROMANCE, MarianSpmResources::ENGLISH2ROMANCE, MarianPrefix::ENGLISH2FRENCH);
pub const ENGLISH2CATALAN: ((&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), Option<&'static str>) =
(MarianModelResources::ENGLISH2ROMANCE, MarianConfigResources::ENGLISH2ROMANCE, MarianVocabResources::ENGLISH2ROMANCE, MarianSpmResources::ENGLISH2ROMANCE, MarianPrefix::ENGLISH2CATALAN);
pub const ENGLISH2SPANISH: ((&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), Option<&'static str>) =
(MarianModelResources::ENGLISH2ROMANCE, MarianConfigResources::ENGLISH2ROMANCE, MarianVocabResources::ENGLISH2ROMANCE, MarianSpmResources::ENGLISH2ROMANCE, MarianPrefix::ENGLISH2SPANISH);
pub const ENGLISH2PORTUGUESE: ((&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), Option<&'static str>) =
(MarianModelResources::ENGLISH2ROMANCE, MarianConfigResources::ENGLISH2ROMANCE, MarianVocabResources::ENGLISH2ROMANCE, MarianSpmResources::ENGLISH2ROMANCE, MarianPrefix::ENGLISH2PORTUGUESE);
pub const ENGLISH2ITALIAN: ((&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), Option<&'static str>) =
(MarianModelResources::ENGLISH2ROMANCE, MarianConfigResources::ENGLISH2ROMANCE, MarianVocabResources::ENGLISH2ROMANCE, MarianSpmResources::ENGLISH2ROMANCE, MarianPrefix::ENGLISH2ITALIAN);
pub const ENGLISH2ROMANIAN: ((&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), Option<&'static str>) =
(MarianModelResources::ENGLISH2ROMANCE, MarianConfigResources::ENGLISH2ROMANCE, MarianVocabResources::ENGLISH2ROMANCE, MarianSpmResources::ENGLISH2ROMANCE, MarianPrefix::ENGLISH2ROMANIAN);
pub const ENGLISH2GERMAN: ((&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), Option<&'static str>) =
(MarianModelResources::ENGLISH2GERMAN, MarianConfigResources::ENGLISH2GERMAN, MarianVocabResources::ENGLISH2GERMAN, MarianSpmResources::ENGLISH2GERMAN, MarianPrefix::ENGLISH2GERMAN);
pub const ENGLISH2RUSSIAN: ((&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), Option<&'static str>) =
(MarianModelResources::ENGLISH2RUSSIAN, MarianConfigResources::ENGLISH2RUSSIAN, MarianVocabResources::ENGLISH2RUSSIAN, MarianSpmResources::ENGLISH2RUSSIAN, MarianPrefix::ENGLISH2RUSSIAN);
pub const ENGLISH2FRENCH: (
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
) = (
MarianModelResources::ENGLISH2ROMANCE,
MarianConfigResources::ENGLISH2ROMANCE,
MarianVocabResources::ENGLISH2ROMANCE,
MarianSpmResources::ENGLISH2ROMANCE,
MarianPrefix::ENGLISH2FRENCH,
);
pub const ENGLISH2CATALAN: (
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
) = (
MarianModelResources::ENGLISH2ROMANCE,
MarianConfigResources::ENGLISH2ROMANCE,
MarianVocabResources::ENGLISH2ROMANCE,
MarianSpmResources::ENGLISH2ROMANCE,
MarianPrefix::ENGLISH2CATALAN,
);
pub const ENGLISH2SPANISH: (
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
) = (
MarianModelResources::ENGLISH2ROMANCE,
MarianConfigResources::ENGLISH2ROMANCE,
MarianVocabResources::ENGLISH2ROMANCE,
MarianSpmResources::ENGLISH2ROMANCE,
MarianPrefix::ENGLISH2SPANISH,
);
pub const ENGLISH2PORTUGUESE: (
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
) = (
MarianModelResources::ENGLISH2ROMANCE,
MarianConfigResources::ENGLISH2ROMANCE,
MarianVocabResources::ENGLISH2ROMANCE,
MarianSpmResources::ENGLISH2ROMANCE,
MarianPrefix::ENGLISH2PORTUGUESE,
);
pub const ENGLISH2ITALIAN: (
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
) = (
MarianModelResources::ENGLISH2ROMANCE,
MarianConfigResources::ENGLISH2ROMANCE,
MarianVocabResources::ENGLISH2ROMANCE,
MarianSpmResources::ENGLISH2ROMANCE,
MarianPrefix::ENGLISH2ITALIAN,
);
pub const ENGLISH2ROMANIAN: (
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
) = (
MarianModelResources::ENGLISH2ROMANCE,
MarianConfigResources::ENGLISH2ROMANCE,
MarianVocabResources::ENGLISH2ROMANCE,
MarianSpmResources::ENGLISH2ROMANCE,
MarianPrefix::ENGLISH2ROMANIAN,
);
pub const ENGLISH2GERMAN: (
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
) = (
MarianModelResources::ENGLISH2GERMAN,
MarianConfigResources::ENGLISH2GERMAN,
MarianVocabResources::ENGLISH2GERMAN,
MarianSpmResources::ENGLISH2GERMAN,
MarianPrefix::ENGLISH2GERMAN,
);
pub const ENGLISH2RUSSIAN: (
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
) = (
MarianModelResources::ENGLISH2RUSSIAN,
MarianConfigResources::ENGLISH2RUSSIAN,
MarianVocabResources::ENGLISH2RUSSIAN,
MarianSpmResources::ENGLISH2RUSSIAN,
MarianPrefix::ENGLISH2RUSSIAN,
);
pub const FRENCH2ENGLISH: ((&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), Option<&'static str>) =
(MarianModelResources::ROMANCE2ENGLISH, MarianConfigResources::ROMANCE2ENGLISH, MarianVocabResources::ROMANCE2ENGLISH, MarianSpmResources::ROMANCE2ENGLISH, MarianPrefix::FRENCH2ENGLISH);
pub const CATALAN2ENGLISH: ((&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), Option<&'static str>) =
(MarianModelResources::ROMANCE2ENGLISH, MarianConfigResources::ROMANCE2ENGLISH, MarianVocabResources::ROMANCE2ENGLISH, MarianSpmResources::ROMANCE2ENGLISH, MarianPrefix::CATALAN2ENGLISH);
pub const SPANISH2ENGLISH: ((&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), Option<&'static str>) =
(MarianModelResources::ROMANCE2ENGLISH, MarianConfigResources::ROMANCE2ENGLISH, MarianVocabResources::ROMANCE2ENGLISH, MarianSpmResources::ROMANCE2ENGLISH, MarianPrefix::SPANISH2ENGLISH);
pub const PORTUGUESE2ENGLISH: ((&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), Option<&'static str>) =
(MarianModelResources::ROMANCE2ENGLISH, MarianConfigResources::ROMANCE2ENGLISH, MarianVocabResources::ROMANCE2ENGLISH, MarianSpmResources::ROMANCE2ENGLISH, MarianPrefix::PORTUGUESE2ENGLISH);
pub const ITALIAN2ENGLISH: ((&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), Option<&'static str>) =
(MarianModelResources::ROMANCE2ENGLISH, MarianConfigResources::ROMANCE2ENGLISH, MarianVocabResources::ROMANCE2ENGLISH, MarianSpmResources::ROMANCE2ENGLISH, MarianPrefix::ITALIAN2ENGLISH);
pub const ROMANIAN2ENGLISH: ((&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), Option<&'static str>) =
(MarianModelResources::ROMANCE2ENGLISH, MarianConfigResources::ROMANCE2ENGLISH, MarianVocabResources::ROMANCE2ENGLISH, MarianSpmResources::ROMANCE2ENGLISH, MarianPrefix::ROMANIAN2ENGLISH);
pub const GERMAN2ENGLISH: ((&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), Option<&'static str>) =
(MarianModelResources::GERMAN2ENGLISH, MarianConfigResources::GERMAN2ENGLISH, MarianVocabResources::GERMAN2ENGLISH, MarianSpmResources::GERMAN2ENGLISH, MarianPrefix::GERMAN2ENGLISH);
pub const RUSSIAN2ENGLISH: ((&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), Option<&'static str>) =
(MarianModelResources::RUSSIAN2ENGLISH, MarianConfigResources::RUSSIAN2ENGLISH, MarianVocabResources::RUSSIAN2ENGLISH, MarianSpmResources::RUSSIAN2ENGLISH, MarianPrefix::RUSSIAN2ENGLISH);
pub const FRENCH2ENGLISH: (
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
) = (
MarianModelResources::ROMANCE2ENGLISH,
MarianConfigResources::ROMANCE2ENGLISH,
MarianVocabResources::ROMANCE2ENGLISH,
MarianSpmResources::ROMANCE2ENGLISH,
MarianPrefix::FRENCH2ENGLISH,
);
pub const CATALAN2ENGLISH: (
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
) = (
MarianModelResources::ROMANCE2ENGLISH,
MarianConfigResources::ROMANCE2ENGLISH,
MarianVocabResources::ROMANCE2ENGLISH,
MarianSpmResources::ROMANCE2ENGLISH,
MarianPrefix::CATALAN2ENGLISH,
);
pub const SPANISH2ENGLISH: (
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
) = (
MarianModelResources::ROMANCE2ENGLISH,
MarianConfigResources::ROMANCE2ENGLISH,
MarianVocabResources::ROMANCE2ENGLISH,
MarianSpmResources::ROMANCE2ENGLISH,
MarianPrefix::SPANISH2ENGLISH,
);
pub const PORTUGUESE2ENGLISH: (
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
) = (
MarianModelResources::ROMANCE2ENGLISH,
MarianConfigResources::ROMANCE2ENGLISH,
MarianVocabResources::ROMANCE2ENGLISH,
MarianSpmResources::ROMANCE2ENGLISH,
MarianPrefix::PORTUGUESE2ENGLISH,
);
pub const ITALIAN2ENGLISH: (
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
) = (
MarianModelResources::ROMANCE2ENGLISH,
MarianConfigResources::ROMANCE2ENGLISH,
MarianVocabResources::ROMANCE2ENGLISH,
MarianSpmResources::ROMANCE2ENGLISH,
MarianPrefix::ITALIAN2ENGLISH,
);
pub const ROMANIAN2ENGLISH: (
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
) = (
MarianModelResources::ROMANCE2ENGLISH,
MarianConfigResources::ROMANCE2ENGLISH,
MarianVocabResources::ROMANCE2ENGLISH,
MarianSpmResources::ROMANCE2ENGLISH,
MarianPrefix::ROMANIAN2ENGLISH,
);
pub const GERMAN2ENGLISH: (
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
) = (
MarianModelResources::GERMAN2ENGLISH,
MarianConfigResources::GERMAN2ENGLISH,
MarianVocabResources::GERMAN2ENGLISH,
MarianSpmResources::GERMAN2ENGLISH,
MarianPrefix::GERMAN2ENGLISH,
);
pub const RUSSIAN2ENGLISH: (
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
) = (
MarianModelResources::RUSSIAN2ENGLISH,
MarianConfigResources::RUSSIAN2ENGLISH,
MarianVocabResources::RUSSIAN2ENGLISH,
MarianSpmResources::RUSSIAN2ENGLISH,
MarianPrefix::RUSSIAN2ENGLISH,
);
pub const FRENCH2GERMAN: ((&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), Option<&'static str>) =
(MarianModelResources::FRENCH2GERMAN, MarianConfigResources::FRENCH2GERMAN, MarianVocabResources::FRENCH2GERMAN, MarianSpmResources::FRENCH2GERMAN, MarianPrefix::FRENCH2GERMAN);
pub const GERMAN2FRENCH: ((&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), Option<&'static str>) =
(MarianModelResources::GERMAN2FRENCH, MarianConfigResources::GERMAN2FRENCH, MarianVocabResources::GERMAN2FRENCH, MarianSpmResources::GERMAN2FRENCH, MarianPrefix::GERMAN2FRENCH);
pub const FRENCH2GERMAN: (
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
) = (
MarianModelResources::FRENCH2GERMAN,
MarianConfigResources::FRENCH2GERMAN,
MarianVocabResources::FRENCH2GERMAN,
MarianSpmResources::FRENCH2GERMAN,
MarianPrefix::FRENCH2GERMAN,
);
pub const GERMAN2FRENCH: (
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
) = (
MarianModelResources::GERMAN2FRENCH,
MarianConfigResources::GERMAN2FRENCH,
MarianVocabResources::GERMAN2FRENCH,
MarianSpmResources::GERMAN2FRENCH,
MarianPrefix::GERMAN2FRENCH,
);
}
/// # Configuration for text translation
/// Contains information regarding the model to load, mirrors the GenerationConfig, with a
/// different set of default parameters and sets the device to place the model on.
@ -178,45 +378,46 @@ impl TranslationConfig {
/// # Example
///
/// ```no_run
///# fn main() -> failure::Fallible<()> {
/// use rust_bert::pipelines::translation::{TranslationConfig, Language};
/// # fn main() -> failure::Fallible<()> {
/// use rust_bert::pipelines::translation::{Language, TranslationConfig};
/// use tch::Device;
///
/// let translation_config = TranslationConfig::new(Language::FrenchToEnglish, Device::cuda_if_available());
///# Ok(())
///# }
/// let translation_config =
/// TranslationConfig::new(Language::FrenchToEnglish, Device::cuda_if_available());
/// # Ok(())
/// # }
/// ```
///
pub fn new(language: Language, device: Device) -> TranslationConfig {
let (model_resource, config_resource, vocab_resource, merges_resource, prefix) = match language {
Language::EnglishToFrench => RemoteTranslationResources::ENGLISH2FRENCH,
Language::EnglishToCatalan => RemoteTranslationResources::ENGLISH2CATALAN,
Language::EnglishToSpanish => RemoteTranslationResources::ENGLISH2SPANISH,
Language::EnglishToPortuguese => RemoteTranslationResources::ENGLISH2PORTUGUESE,
Language::EnglishToItalian => RemoteTranslationResources::ENGLISH2ITALIAN,
Language::EnglishToRomanian => RemoteTranslationResources::ENGLISH2ROMANIAN,
Language::EnglishToGerman => RemoteTranslationResources::ENGLISH2GERMAN,
Language::EnglishToRussian => RemoteTranslationResources::ENGLISH2RUSSIAN,
let (model_resource, config_resource, vocab_resource, merges_resource, prefix) =
match language {
Language::EnglishToFrench => RemoteTranslationResources::ENGLISH2FRENCH,
Language::EnglishToCatalan => RemoteTranslationResources::ENGLISH2CATALAN,
Language::EnglishToSpanish => RemoteTranslationResources::ENGLISH2SPANISH,
Language::EnglishToPortuguese => RemoteTranslationResources::ENGLISH2PORTUGUESE,
Language::EnglishToItalian => RemoteTranslationResources::ENGLISH2ITALIAN,
Language::EnglishToRomanian => RemoteTranslationResources::ENGLISH2ROMANIAN,
Language::EnglishToGerman => RemoteTranslationResources::ENGLISH2GERMAN,
Language::EnglishToRussian => RemoteTranslationResources::ENGLISH2RUSSIAN,
Language::FrenchToEnglish => RemoteTranslationResources::FRENCH2ENGLISH,
Language::CatalanToEnglish => RemoteTranslationResources::CATALAN2ENGLISH,
Language::SpanishToEnglish => RemoteTranslationResources::SPANISH2ENGLISH,
Language::PortugueseToEnglish => RemoteTranslationResources::PORTUGUESE2ENGLISH,
Language::ItalianToEnglish => RemoteTranslationResources::ITALIAN2ENGLISH,
Language::RomanianToEnglish => RemoteTranslationResources::ROMANIAN2ENGLISH,
Language::GermanToEnglish => RemoteTranslationResources::GERMAN2ENGLISH,
Language::RussianToEnglish => RemoteTranslationResources::RUSSIAN2ENGLISH,
Language::FrenchToEnglish => RemoteTranslationResources::FRENCH2ENGLISH,
Language::CatalanToEnglish => RemoteTranslationResources::CATALAN2ENGLISH,
Language::SpanishToEnglish => RemoteTranslationResources::SPANISH2ENGLISH,
Language::PortugueseToEnglish => RemoteTranslationResources::PORTUGUESE2ENGLISH,
Language::ItalianToEnglish => RemoteTranslationResources::ITALIAN2ENGLISH,
Language::RomanianToEnglish => RemoteTranslationResources::ROMANIAN2ENGLISH,
Language::GermanToEnglish => RemoteTranslationResources::GERMAN2ENGLISH,
Language::RussianToEnglish => RemoteTranslationResources::RUSSIAN2ENGLISH,
Language::FrenchToGerman => RemoteTranslationResources::FRENCH2GERMAN,
Language::GermanToFrench => RemoteTranslationResources::GERMAN2FRENCH,
};
Language::FrenchToGerman => RemoteTranslationResources::FRENCH2GERMAN,
Language::GermanToFrench => RemoteTranslationResources::GERMAN2FRENCH,
};
let model_resource = Resource::Remote(RemoteResource::from_pretrained(model_resource));
let config_resource = Resource::Remote(RemoteResource::from_pretrained(config_resource));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(vocab_resource));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(merges_resource));
let prefix = match prefix {
Some(value) => Some(value.to_string()),
None => None
None => None,
};
TranslationConfig {
model_resource,
@ -253,33 +454,44 @@ impl TranslationConfig {
/// # Example
///
/// ```no_run
///# fn main() -> failure::Fallible<()> {
/// # fn main() -> failure::Fallible<()> {
/// use rust_bert::pipelines::translation::TranslationConfig;
/// use tch::Device;
/// use rust_bert::resources::{Resource, LocalResource};
/// use rust_bert::resources::{LocalResource, Resource};
/// use std::path::PathBuf;
/// use tch::Device;
///
/// let config_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/config.json") });
/// let model_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/model.ot") });
/// let vocab_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/vocab.json") });
/// let sentence_piece_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/spiece.model") });
/// let config_resource = Resource::Local(LocalResource {
/// local_path: PathBuf::from("path/to/config.json"),
/// });
/// let model_resource = Resource::Local(LocalResource {
/// local_path: PathBuf::from("path/to/model.ot"),
/// });
/// let vocab_resource = Resource::Local(LocalResource {
/// local_path: PathBuf::from("path/to/vocab.json"),
/// });
/// let sentence_piece_resource = Resource::Local(LocalResource {
/// local_path: PathBuf::from("path/to/spiece.model"),
/// });
///
/// let translation_config = TranslationConfig::new_from_resources(model_resource,
/// config_resource,
/// vocab_resource,
/// sentence_piece_resource,
/// Some(">>fr<<".to_string()),
/// Device::cuda_if_available());
///# Ok(())
///# }
/// let translation_config = TranslationConfig::new_from_resources(
/// model_resource,
/// config_resource,
/// vocab_resource,
/// sentence_piece_resource,
/// Some(">>fr<<".to_string()),
/// Device::cuda_if_available(),
/// );
/// # Ok(())
/// # }
/// ```
///
pub fn new_from_resources(model_resource: Resource,
config_resource: Resource,
vocab_resource: Resource,
sentence_piece_resource: Resource,
prefix: Option<String>,
device: Device) -> TranslationConfig {
pub fn new_from_resources(
model_resource: Resource,
config_resource: Resource,
vocab_resource: Resource,
sentence_piece_resource: Resource,
prefix: Option<String>,
device: Device,
) -> TranslationConfig {
TranslationConfig {
model_resource,
config_resource,
@ -319,18 +531,17 @@ impl TranslationModel {
/// # Example
///
/// ```no_run
///# fn main() -> failure::Fallible<()> {
/// use rust_bert::pipelines::translation::{TranslationModel, TranslationConfig, Language};
/// # fn main() -> failure::Fallible<()> {
/// use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
/// use tch::Device;
///
/// let translation_config = TranslationConfig::new(Language::FrenchToEnglish, Device::cuda_if_available());
/// let mut summarization_model = TranslationModel::new(translation_config)?;
///# Ok(())
///# }
/// let translation_config =
/// TranslationConfig::new(Language::FrenchToEnglish, Device::cuda_if_available());
/// let mut summarization_model = TranslationModel::new(translation_config)?;
/// # Ok(())
/// # }
/// ```
///
pub fn new(translation_config: TranslationConfig)
-> failure::Fallible<TranslationModel> {
pub fn new(translation_config: TranslationConfig) -> failure::Fallible<TranslationModel> {
let generate_config = GenerateConfig {
model_resource: translation_config.model_resource,
config_resource: translation_config.config_resource,
@ -353,7 +564,10 @@ impl TranslationModel {
let model = MarianGenerator::new(generate_config)?;
Ok(TranslationModel { model, prefix: translation_config.prefix })
Ok(TranslationModel {
model,
prefix: translation_config.prefix,
})
}
/// Translates texts provided
@ -368,31 +582,32 @@ impl TranslationModel {
/// # Example
///
/// ```no_run
///# fn main() -> failure::Fallible<()> {
/// # fn main() -> failure::Fallible<()> {
/// use rust_bert::pipelines::generation::LanguageGenerator;
/// use rust_bert::pipelines::translation::{TranslationModel, TranslationConfig, Language};
/// use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
/// use tch::Device;
///
/// let translation_config = TranslationConfig::new(Language::EnglishToFrench, Device::cuda_if_available());
/// let translation_config =
/// TranslationConfig::new(Language::EnglishToFrench, Device::cuda_if_available());
/// let model = TranslationModel::new(translation_config)?;
///
/// let input = ["This is a sentence to be translated"];
///
/// let output = model.translate(&input);
///# Ok(())
///# }
/// # Ok(())
/// # }
/// ```
///
pub fn translate(&self, texts: &[&str]) -> Vec<String> {
match &self.prefix {
Some(value) => {
let texts: Vec<String> = texts
.into_iter()
.map(|&v| { format!("{} {}", value, v) })
.map(|&v| format!("{} {}", value, v))
.collect();
self.model.generate(Some(texts.iter().map(AsRef::as_ref).collect()), None)
self.model
.generate(Some(texts.iter().map(AsRef::as_ref).collect()), None)
}
None => self.model.generate(Some(texts.to_vec()), None)
None => self.model.generate(Some(texts.to_vec()), None),
}
}
}
}

View File

@ -11,10 +11,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use tch::{nn, Tensor, Kind};
use crate::common::dropout::Dropout;
use tch::nn::{EmbeddingConfig, embedding};
use crate::bert::{BertConfig, BertEmbedding};
use crate::common::dropout::Dropout;
use tch::nn::{embedding, EmbeddingConfig};
use tch::{nn, Kind, Tensor};
#[derive(Debug)]
/// # BertEmbeddings implementation for RoBERTa model
@ -36,8 +36,12 @@ impl RobertaEmbeddings {
fn create_position_ids_from_embeddings(&self, x: &Tensor) -> Tensor {
let input_shape = x.size();
let input_shape = vec!(input_shape[0], input_shape[1]);
let position_ids: Tensor = Tensor::arange1(self.padding_index + 1, input_shape[0], (Kind::Int64, x.device()));
let input_shape = vec![input_shape[0], input_shape[1]];
let position_ids: Tensor = Tensor::arange1(
self.padding_index + 1,
input_shape[0],
(Kind::Int64, x.device()),
);
position_ids.unsqueeze(0).expand(&input_shape, true)
}
}
@ -54,10 +58,10 @@ impl BertEmbedding for RobertaEmbeddings {
///
/// ```no_run
/// use rust_bert::bert::{BertConfig, BertEmbedding};
/// use tch::{nn, Device};
/// use rust_bert::roberta::RobertaEmbeddings;
/// use rust_bert::Config;
/// use std::path::Path;
/// use rust_bert::roberta::RobertaEmbeddings;
/// use tch::{nn, Device};
///
/// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu;
@ -65,29 +69,48 @@ impl BertEmbedding for RobertaEmbeddings {
/// let config = BertConfig::from_file(config_path);
/// let robert_embeddings = RobertaEmbeddings::new(&(&p.root() / "bert_embeddings"), &config);
/// ```
///
fn new(p: &nn::Path, config: &BertConfig) -> RobertaEmbeddings {
let embedding_config = EmbeddingConfig { padding_idx: 1, ..Default::default() };
let embedding_config = EmbeddingConfig {
padding_idx: 1,
..Default::default()
};
let word_embeddings: nn::Embedding = embedding(p / "word_embeddings",
config.vocab_size,
config.hidden_size,
embedding_config);
let word_embeddings: nn::Embedding = embedding(
p / "word_embeddings",
config.vocab_size,
config.hidden_size,
embedding_config,
);
let position_embeddings: nn::Embedding = embedding(p / "position_embeddings",
config.max_position_embeddings,
config.hidden_size,
Default::default());
let position_embeddings: nn::Embedding = embedding(
p / "position_embeddings",
config.max_position_embeddings,
config.hidden_size,
Default::default(),
);
let token_type_embeddings: nn::Embedding = embedding(p / "token_type_embeddings",
config.type_vocab_size,
config.hidden_size,
Default::default());
let token_type_embeddings: nn::Embedding = embedding(
p / "token_type_embeddings",
config.type_vocab_size,
config.hidden_size,
Default::default(),
);
let layer_norm_config = nn::LayerNormConfig { eps: 1e-12, ..Default::default() };
let layer_norm: nn::LayerNorm = nn::layer_norm(p / "LayerNorm", vec![config.hidden_size], layer_norm_config);
let layer_norm_config = nn::LayerNormConfig {
eps: 1e-12,
..Default::default()
};
let layer_norm: nn::LayerNorm =
nn::layer_norm(p / "LayerNorm", vec![config.hidden_size], layer_norm_config);
let dropout: Dropout = Dropout::new(config.hidden_dropout_prob);
RobertaEmbeddings { word_embeddings, position_embeddings, token_type_embeddings, layer_norm, dropout, padding_index: 1 }
RobertaEmbeddings {
word_embeddings,
position_embeddings,
token_type_embeddings,
layer_norm,
dropout,
padding_index: 1,
}
}
/// Forward pass through the embedding layer.
@ -108,68 +131,82 @@ impl BertEmbedding for RobertaEmbeddings {
/// # Example
///
/// ```no_run
///# use rust_bert::bert::{BertConfig, BertEmbedding};
///# use tch::{nn, Device, Tensor, no_grad};
///# use rust_bert::Config;
///# use std::path::Path;
///# use tch::kind::Kind::Int64;
/// # use rust_bert::bert::{BertConfig, BertEmbedding};
/// # use tch::{nn, Device, Tensor, no_grad};
/// # use rust_bert::Config;
/// # use std::path::Path;
/// # use tch::kind::Kind::Int64;
/// use rust_bert::roberta::RobertaEmbeddings;
///# let config_path = Path::new("path/to/config.json");
///# let vocab_path = Path::new("path/to/vocab.txt");
///# let device = Device::Cpu;
///# let vs = nn::VarStore::new(device);
///# let config = BertConfig::from_file(config_path);
///# let roberta_embeddings = RobertaEmbeddings::new(&vs.root(), &config);
/// let (batch_size, sequence_length) = (64, 128);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true);
/// # let config_path = Path::new("path/to/config.json");
/// # let vocab_path = Path::new("path/to/vocab.txt");
/// # let device = Device::Cpu;
/// # let vs = nn::VarStore::new(device);
/// # let config = BertConfig::from_file(config_path);
/// # let roberta_embeddings = RobertaEmbeddings::new(&vs.root(), &config);
/// let (batch_size, sequence_length) = (64, 128);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
/// .expand(&[batch_size, sequence_length], true);
///
/// let embedded_output = no_grad(|| {
/// roberta_embeddings
/// .forward_t(Some(input_tensor),
/// Some(token_type_ids),
/// Some(position_ids),
/// None,
/// false).unwrap()
/// });
/// let embedded_output = no_grad(|| {
/// roberta_embeddings
/// .forward_t(
/// Some(input_tensor),
/// Some(token_type_ids),
/// Some(position_ids),
/// None,
/// false,
/// )
/// .unwrap()
/// });
/// ```
///
fn forward_t(&self,
input_ids: Option<Tensor>,
token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool) -> Result<Tensor, &'static str> {
fn forward_t(
&self,
input_ids: Option<Tensor>,
token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool,
) -> Result<Tensor, &'static str> {
let (input_embeddings, input_shape) = match &input_ids {
Some(input_value) => match &input_embeds {
Some(_) => { return Err("Only one of input ids or input embeddings may be set"); }
None => (input_value.apply_t(&self.word_embeddings, train), input_value.size())
}
Some(_) => {
return Err("Only one of input ids or input embeddings may be set");
}
None => (
input_value.apply_t(&self.word_embeddings, train),
input_value.size(),
),
},
None => match &input_embeds {
Some(embeds) => (embeds.copy(), vec!(embeds.size()[0], embeds.size()[1])),
None => { return Err("Only one of input ids or input embeddings may be set"); }
}
Some(embeds) => (embeds.copy(), vec![embeds.size()[0], embeds.size()[1]]),
None => {
return Err("Only one of input ids or input embeddings may be set");
}
},
};
let position_ids = match position_ids {
Some(value) => value,
None => match input_ids {
Some(value) => self.create_position_ids_from_input_ids(&value),
None => self.create_position_ids_from_embeddings(&input_embeds.unwrap())
}
None => self.create_position_ids_from_embeddings(&input_embeds.unwrap()),
},
};
let token_type_ids = match token_type_ids {
Some(value) => value,
None => Tensor::zeros(&input_shape, (Kind::Int64, input_embeddings.device()))
None => Tensor::zeros(&input_shape, (Kind::Int64, input_embeddings.device())),
};
let position_embeddings = position_ids.apply(&self.position_embeddings);
let token_type_embeddings = token_type_ids.apply(&self.token_type_embeddings);
let input_embeddings: Tensor = input_embeddings + position_embeddings + token_type_embeddings;
Ok(input_embeddings.apply(&self.layer_norm).apply_t(&self.dropout, train))
let input_embeddings: Tensor =
input_embeddings + position_embeddings + token_type_embeddings;
Ok(input_embeddings
.apply(&self.layer_norm)
.apply_t(&self.dropout, train))
}
}
}

View File

@ -19,20 +19,28 @@
//! Pretrained models are available and can be downloaded using RemoteResources.
//!
//! ```no_run
//!# fn main() -> failure::Fallible<()> {
//!#
//! # fn main() -> failure::Fallible<()> {
//! #
//! use rust_tokenizers::RobertaTokenizer;
//! use tch::{nn, Device};
//!# use std::path::PathBuf;
//! # use std::path::PathBuf;
//! use rust_bert::bert::BertConfig;
//! use rust_bert::Config;
//! use rust_bert::resources::{download_resource, LocalResource, Resource};
//! use rust_bert::roberta::RobertaForMaskedLM;
//! use rust_bert::resources::{Resource, download_resource, LocalResource};
//! use rust_bert::Config;
//!
//! let config_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/config.json")});
//! let vocab_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/vocab.txt")});
//! let merges_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/vocab.txt")});
//! let weights_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/model.ot")});
//! let config_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/config.json"),
//! });
//! let vocab_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/vocab.txt"),
//! });
//! let merges_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/vocab.txt"),
//! });
//! let weights_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/model.ot"),
//! });
//! let config_path = download_resource(&config_resource)?;
//! let vocab_path = download_resource(&vocab_resource)?;
//! let merges_path = download_resource(&merges_resource)?;
@ -40,19 +48,25 @@
//!
//! let device = Device::cuda_if_available();
//! let mut vs = nn::VarStore::new(device);
//! let tokenizer: RobertaTokenizer = RobertaTokenizer::from_file(vocab_path.to_str().unwrap(), merges_path.to_str().unwrap(), true);
//! let tokenizer: RobertaTokenizer = RobertaTokenizer::from_file(
//! vocab_path.to_str().unwrap(),
//! merges_path.to_str().unwrap(),
//! true,
//! );
//! let config = BertConfig::from_file(config_path);
//! let bert_model = RobertaForMaskedLM::new(&vs.root(), &config);
//! vs.load(weights_path)?;
//!
//!# Ok(())
//!# }
//! # Ok(())
//! # }
//! ```
mod embeddings;
mod roberta;
pub use roberta::{RobertaModelResources, RobertaConfigResources, RobertaVocabResources, RobertaMergesResources,
RobertaForMaskedLM, RobertaForMultipleChoice, RobertaForTokenClassification, RobertaForQuestionAnswering, RobertaForSequenceClassification};
pub use embeddings::RobertaEmbeddings;
pub use embeddings::RobertaEmbeddings;
pub use roberta::{
RobertaConfigResources, RobertaForMaskedLM, RobertaForMultipleChoice,
RobertaForQuestionAnswering, RobertaForSequenceClassification, RobertaForTokenClassification,
RobertaMergesResources, RobertaModelResources, RobertaVocabResources,
};

View File

@ -11,13 +11,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use tch::{nn, Tensor};
use crate::common::linear::{linear_no_bias, LinearNoBias};
use tch::nn::Init;
use crate::common::activations::_gelu;
use crate::roberta::embeddings::RobertaEmbeddings;
use crate::common::dropout::Dropout;
use crate::bert::{BertConfig, BertModel};
use crate::common::activations::_gelu;
use crate::common::dropout::Dropout;
use crate::common::linear::{linear_no_bias, LinearNoBias};
use crate::roberta::embeddings::RobertaEmbeddings;
use tch::nn::Init;
use tch::{nn, Tensor};
/// # RoBERTa Pretrained model weight files
pub struct RobertaModelResources;
@ -33,22 +33,34 @@ pub struct RobertaMergesResources;
impl RobertaModelResources {
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
pub const ROBERTA: (&'static str, &'static str) = ("roberta/model.ot", "https://cdn.huggingface.co/roberta-base-rust_model.ot");
pub const ROBERTA: (&'static str, &'static str) = (
"roberta/model.ot",
"https://cdn.huggingface.co/roberta-base-rust_model.ot",
);
}
impl RobertaConfigResources {
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
pub const ROBERTA: (&'static str, &'static str) = ("roberta/config.json", "https://cdn.huggingface.co/roberta-base-config.json");
pub const ROBERTA: (&'static str, &'static str) = (
"roberta/config.json",
"https://cdn.huggingface.co/roberta-base-config.json",
);
}
impl RobertaVocabResources {
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
pub const ROBERTA: (&'static str, &'static str) = ("roberta/vocab.txt", "https://cdn.huggingface.co/roberta-base-vocab.json");
pub const ROBERTA: (&'static str, &'static str) = (
"roberta/vocab.txt",
"https://cdn.huggingface.co/roberta-base-vocab.json",
);
}
impl RobertaMergesResources {
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
pub const ROBERTA: (&'static str, &'static str) = ("roberta/merges.txt", "https://cdn.huggingface.co/roberta-base-merges.txt");
pub const ROBERTA: (&'static str, &'static str) = (
"roberta/merges.txt",
"https://cdn.huggingface.co/roberta-base-merges.txt",
);
}
pub struct RobertaLMHead {
@ -60,17 +72,42 @@ pub struct RobertaLMHead {
impl RobertaLMHead {
pub fn new(p: &nn::Path, config: &BertConfig) -> RobertaLMHead {
let dense = nn::linear(p / "dense", config.hidden_size, config.hidden_size, Default::default());
let layer_norm_config = nn::LayerNormConfig { eps: 1e-12, ..Default::default() };
let layer_norm = nn::layer_norm(p / "layer_norm", vec![config.hidden_size], layer_norm_config);
let decoder = linear_no_bias(&(p / "decoder"), config.hidden_size, config.vocab_size, Default::default());
let dense = nn::linear(
p / "dense",
config.hidden_size,
config.hidden_size,
Default::default(),
);
let layer_norm_config = nn::LayerNormConfig {
eps: 1e-12,
..Default::default()
};
let layer_norm = nn::layer_norm(
p / "layer_norm",
vec![config.hidden_size],
layer_norm_config,
);
let decoder = linear_no_bias(
&(p / "decoder"),
config.hidden_size,
config.vocab_size,
Default::default(),
);
let bias = p.var("bias", &[config.vocab_size], Init::KaimingUniform);
RobertaLMHead { dense, decoder, layer_norm, bias }
RobertaLMHead {
dense,
decoder,
layer_norm,
bias,
}
}
pub fn forward(&self, hidden_states: &Tensor) -> Tensor {
(_gelu(&hidden_states.apply(&self.dense))).apply(&self.layer_norm).apply(&self.decoder) + &self.bias
(_gelu(&hidden_states.apply(&self.dense)))
.apply(&self.layer_norm)
.apply(&self.decoder)
+ &self.bias
}
}
@ -96,10 +133,10 @@ impl RobertaForMaskedLM {
///
/// ```no_run
/// use rust_bert::bert::BertConfig;
/// use tch::{nn, Device};
/// use rust_bert::roberta::RobertaForMaskedLM;
/// use rust_bert::Config;
/// use std::path::Path;
/// use rust_bert::roberta::RobertaForMaskedLM;
/// use tch::{nn, Device};
///
/// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu;
@ -107,7 +144,6 @@ impl RobertaForMaskedLM {
/// let config = BertConfig::from_file(config_path);
/// let roberta = RobertaForMaskedLM::new(&(&p.root() / "roberta"), &config);
/// ```
///
pub fn new(p: &nn::Path, config: &BertConfig) -> RobertaForMaskedLM {
let roberta = BertModel::<RobertaEmbeddings>::new(&(p / "roberta"), config);
let lm_head = RobertaLMHead::new(&(p / "lm_head"), config);
@ -137,49 +173,62 @@ impl RobertaForMaskedLM {
/// # Example
///
/// ```no_run
///# use rust_bert::bert::BertConfig;
///# use tch::{nn, Device, Tensor, no_grad};
///# use rust_bert::Config;
///# use std::path::Path;
///# use tch::kind::Kind::Int64;
/// # use rust_bert::bert::BertConfig;
/// # use tch::{nn, Device, Tensor, no_grad};
/// # use rust_bert::Config;
/// # use std::path::Path;
/// # use tch::kind::Kind::Int64;
/// use rust_bert::roberta::RobertaForMaskedLM;
///# let config_path = Path::new("path/to/config.json");
///# let vocab_path = Path::new("path/to/vocab.txt");
///# let device = Device::Cpu;
///# let vs = nn::VarStore::new(device);
///# let config = BertConfig::from_file(config_path);
///# let roberta_model = RobertaForMaskedLM::new(&vs.root(), &config);
/// let (batch_size, sequence_length) = (64, 128);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true);
///
/// let (output, all_hidden_states, all_attentions) = no_grad(|| {
/// roberta_model
/// .forward_t(Some(input_tensor),
/// Some(mask),
/// Some(token_type_ids),
/// Some(position_ids),
/// None,
/// &None,
/// &None,
/// false)
/// });
/// # let config_path = Path::new("path/to/config.json");
/// # let vocab_path = Path::new("path/to/vocab.txt");
/// # let device = Device::Cpu;
/// # let vs = nn::VarStore::new(device);
/// # let config = BertConfig::from_file(config_path);
/// # let roberta_model = RobertaForMaskedLM::new(&vs.root(), &config);
/// let (batch_size, sequence_length) = (64, 128);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
/// .expand(&[batch_size, sequence_length], true);
///
/// let (output, all_hidden_states, all_attentions) = no_grad(|| {
/// roberta_model.forward_t(
/// Some(input_tensor),
/// Some(mask),
/// Some(token_type_ids),
/// Some(position_ids),
/// None,
/// &None,
/// &None,
/// false,
/// )
/// });
/// ```
///
pub fn forward_t(&self,
input_ids: Option<Tensor>,
mask: Option<Tensor>,
token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
encoder_hidden_states: &Option<Tensor>,
encoder_mask: &Option<Tensor>,
train: bool) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
let (hidden_state, _, all_hidden_states, all_attentions) = self.roberta.forward_t(input_ids, mask, token_type_ids, position_ids,
input_embeds, encoder_hidden_states, encoder_mask, train).unwrap();
pub fn forward_t(
&self,
input_ids: Option<Tensor>,
mask: Option<Tensor>,
token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
encoder_hidden_states: &Option<Tensor>,
encoder_mask: &Option<Tensor>,
train: bool,
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
let (hidden_state, _, all_hidden_states, all_attentions) = self
.roberta
.forward_t(
input_ids,
mask,
token_type_ids,
position_ids,
input_embeds,
encoder_hidden_states,
encoder_mask,
train,
)
.unwrap();
let prediction_scores = self.lm_head.forward(&hidden_state);
(prediction_scores, all_hidden_states, all_attentions)
@ -194,12 +243,30 @@ pub struct RobertaClassificationHead {
impl RobertaClassificationHead {
pub fn new(p: &nn::Path, config: &BertConfig) -> RobertaClassificationHead {
let dense = nn::linear(p / "dense", config.hidden_size, config.hidden_size, Default::default());
let num_labels = config.id2label.as_ref().expect("num_labels not provided in configuration").len() as i64;
let out_proj = nn::linear(p / "out_proj", config.hidden_size, num_labels, Default::default());
let dense = nn::linear(
p / "dense",
config.hidden_size,
config.hidden_size,
Default::default(),
);
let num_labels = config
.id2label
.as_ref()
.expect("num_labels not provided in configuration")
.len() as i64;
let out_proj = nn::linear(
p / "out_proj",
config.hidden_size,
num_labels,
Default::default(),
);
let dropout = Dropout::new(config.hidden_dropout_prob);
RobertaClassificationHead { dense, dropout, out_proj }
RobertaClassificationHead {
dense,
dropout,
out_proj,
}
}
pub fn forward_t(&self, hidden_states: &Tensor, train: bool) -> Tensor {
@ -235,10 +302,10 @@ impl RobertaForSequenceClassification {
///
/// ```no_run
/// use rust_bert::bert::BertConfig;
/// use tch::{nn, Device};
/// use rust_bert::roberta::RobertaForSequenceClassification;
/// use rust_bert::Config;
/// use std::path::Path;
/// use rust_bert::roberta::RobertaForSequenceClassification;
/// use tch::{nn, Device};
///
/// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu;
@ -246,12 +313,14 @@ impl RobertaForSequenceClassification {
/// let config = BertConfig::from_file(config_path);
/// let roberta = RobertaForSequenceClassification::new(&(&p.root() / "roberta"), &config);
/// ```
///
pub fn new(p: &nn::Path, config: &BertConfig) -> RobertaForSequenceClassification {
let roberta = BertModel::<RobertaEmbeddings>::new(&(p / "roberta"), config);
let classifier = RobertaClassificationHead::new(&(p / "classifier"), config);
RobertaForSequenceClassification { roberta, classifier }
RobertaForSequenceClassification {
roberta,
classifier,
}
}
/// Forward pass through the model
@ -274,45 +343,58 @@ impl RobertaForSequenceClassification {
/// # Example
///
/// ```no_run
///# use rust_bert::bert::BertConfig;
///# use tch::{nn, Device, Tensor, no_grad};
///# use rust_bert::Config;
///# use std::path::Path;
///# use tch::kind::Kind::Int64;
/// # use rust_bert::bert::BertConfig;
/// # use tch::{nn, Device, Tensor, no_grad};
/// # use rust_bert::Config;
/// # use std::path::Path;
/// # use tch::kind::Kind::Int64;
/// use rust_bert::roberta::RobertaForSequenceClassification;
///# let config_path = Path::new("path/to/config.json");
///# let vocab_path = Path::new("path/to/vocab.txt");
///# let device = Device::Cpu;
///# let vs = nn::VarStore::new(device);
///# let config = BertConfig::from_file(config_path);
///# let roberta_model = RobertaForSequenceClassification::new(&vs.root(), &config);
/// let (batch_size, sequence_length) = (64, 128);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true);
///
/// let (labels, all_hidden_states, all_attentions) = no_grad(|| {
/// roberta_model
/// .forward_t(Some(input_tensor),
/// Some(mask),
/// Some(token_type_ids),
/// Some(position_ids),
/// None,
/// false)
/// });
/// # let config_path = Path::new("path/to/config.json");
/// # let vocab_path = Path::new("path/to/vocab.txt");
/// # let device = Device::Cpu;
/// # let vs = nn::VarStore::new(device);
/// # let config = BertConfig::from_file(config_path);
/// # let roberta_model = RobertaForSequenceClassification::new(&vs.root(), &config);
/// let (batch_size, sequence_length) = (64, 128);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
/// .expand(&[batch_size, sequence_length], true);
///
/// let (labels, all_hidden_states, all_attentions) = no_grad(|| {
/// roberta_model.forward_t(
/// Some(input_tensor),
/// Some(mask),
/// Some(token_type_ids),
/// Some(position_ids),
/// None,
/// false,
/// )
/// });
/// ```
///
pub fn forward_t(&self,
input_ids: Option<Tensor>,
mask: Option<Tensor>,
token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
let (hidden_state, _, all_hidden_states, all_attentions) = self.roberta.forward_t(input_ids, mask, token_type_ids, position_ids,
input_embeds, &None, &None, train).unwrap();
pub fn forward_t(
&self,
input_ids: Option<Tensor>,
mask: Option<Tensor>,
token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool,
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
let (hidden_state, _, all_hidden_states, all_attentions) = self
.roberta
.forward_t(
input_ids,
mask,
token_type_ids,
position_ids,
input_embeds,
&None,
&None,
train,
)
.unwrap();
let output = self.classifier.forward_t(&hidden_state, train);
(output, all_hidden_states, all_attentions)
@ -344,10 +426,10 @@ impl RobertaForMultipleChoice {
///
/// ```no_run
/// use rust_bert::bert::BertConfig;
/// use tch::{nn, Device};
/// use rust_bert::roberta::RobertaForMultipleChoice;
/// use rust_bert::Config;
/// use std::path::Path;
/// use rust_bert::roberta::RobertaForMultipleChoice;
/// use tch::{nn, Device};
///
/// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu;
@ -355,13 +437,16 @@ impl RobertaForMultipleChoice {
/// let config = BertConfig::from_file(config_path);
/// let roberta = RobertaForMultipleChoice::new(&(&p.root() / "roberta"), &config);
/// ```
///
pub fn new(p: &nn::Path, config: &BertConfig) -> RobertaForMultipleChoice {
let roberta = BertModel::<RobertaEmbeddings>::new(&(p / "roberta"), config);
let dropout = Dropout::new(config.hidden_dropout_prob);
let classifier = nn::linear(p / "classifier", config.hidden_size, 1, Default::default());
RobertaForMultipleChoice { roberta, dropout, classifier }
RobertaForMultipleChoice {
roberta,
dropout,
classifier,
}
}
/// Forward pass through the model
@ -383,61 +468,77 @@ impl RobertaForMultipleChoice {
/// # Example
///
/// ```no_run
///# use rust_bert::bert::BertConfig;
///# use tch::{nn, Device, Tensor, no_grad};
///# use rust_bert::Config;
///# use std::path::Path;
///# use tch::kind::Kind::Int64;
/// # use rust_bert::bert::BertConfig;
/// # use tch::{nn, Device, Tensor, no_grad};
/// # use rust_bert::Config;
/// # use std::path::Path;
/// # use tch::kind::Kind::Int64;
/// use rust_bert::roberta::RobertaForMultipleChoice;
///# let config_path = Path::new("path/to/config.json");
///# let vocab_path = Path::new("path/to/vocab.txt");
///# let device = Device::Cpu;
///# let vs = nn::VarStore::new(device);
///# let config = BertConfig::from_file(config_path);
///# let roberta_model = RobertaForMultipleChoice::new(&vs.root(), &config);
/// let (num_choices, sequence_length) = (3, 128);
/// let input_tensor = Tensor::rand(&[num_choices, sequence_length], (Int64, device));
/// let mask = Tensor::zeros(&[num_choices, sequence_length], (Int64, device));
/// let token_type_ids = Tensor::zeros(&[num_choices, sequence_length], (Int64, device));
/// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[num_choices, sequence_length], true);
///
/// let (choices, all_hidden_states, all_attentions) = no_grad(|| {
/// roberta_model
/// .forward_t(input_tensor,
/// Some(mask),
/// Some(token_type_ids),
/// Some(position_ids),
/// false)
/// });
/// # let config_path = Path::new("path/to/config.json");
/// # let vocab_path = Path::new("path/to/vocab.txt");
/// # let device = Device::Cpu;
/// # let vs = nn::VarStore::new(device);
/// # let config = BertConfig::from_file(config_path);
/// # let roberta_model = RobertaForMultipleChoice::new(&vs.root(), &config);
/// let (num_choices, sequence_length) = (3, 128);
/// let input_tensor = Tensor::rand(&[num_choices, sequence_length], (Int64, device));
/// let mask = Tensor::zeros(&[num_choices, sequence_length], (Int64, device));
/// let token_type_ids = Tensor::zeros(&[num_choices, sequence_length], (Int64, device));
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
/// .expand(&[num_choices, sequence_length], true);
///
/// let (choices, all_hidden_states, all_attentions) = no_grad(|| {
/// roberta_model.forward_t(
/// input_tensor,
/// Some(mask),
/// Some(token_type_ids),
/// Some(position_ids),
/// false,
/// )
/// });
/// ```
///
pub fn forward_t(&self,
input_ids: Tensor,
mask: Option<Tensor>,
token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>,
train: bool) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
pub fn forward_t(
&self,
input_ids: Tensor,
mask: Option<Tensor>,
token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>,
train: bool,
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
let num_choices = input_ids.size()[1];
let flat_input_ids = Some(input_ids.view((-1i64, *input_ids.size().last().unwrap())));
let flat_position_ids = match position_ids {
Some(value) => Some(value.view((-1i64, *value.size().last().unwrap()))),
None => None
None => None,
};
let flat_token_type_ids = match token_type_ids {
Some(value) => Some(value.view((-1i64, *value.size().last().unwrap()))),
None => None
None => None,
};
let flat_mask = match mask {
Some(value) => Some(value.view((-1i64, *value.size().last().unwrap()))),
None => None
None => None,
};
let (_, pooled_output, all_hidden_states, all_attentions) = self.roberta.forward_t(flat_input_ids, flat_mask, flat_token_type_ids, flat_position_ids,
None, &None, &None, train).unwrap();
let (_, pooled_output, all_hidden_states, all_attentions) = self
.roberta
.forward_t(
flat_input_ids,
flat_mask,
flat_token_type_ids,
flat_position_ids,
None,
&None,
&None,
train,
)
.unwrap();
let output = pooled_output.apply_t(&self.dropout, train).apply(&self.classifier).view((-1, num_choices));
let output = pooled_output
.apply_t(&self.dropout, train)
.apply(&self.classifier)
.view((-1, num_choices));
(output, all_hidden_states, all_attentions)
}
}
@ -466,10 +567,10 @@ impl RobertaForTokenClassification {
///
/// ```no_run
/// use rust_bert::bert::BertConfig;
/// use tch::{nn, Device};
/// use rust_bert::roberta::RobertaForTokenClassification;
/// use rust_bert::Config;
/// use std::path::Path;
/// use rust_bert::roberta::RobertaForTokenClassification;
/// use tch::{nn, Device};
///
/// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu;
@ -477,14 +578,26 @@ impl RobertaForTokenClassification {
/// let config = BertConfig::from_file(config_path);
/// let roberta = RobertaForTokenClassification::new(&(&p.root() / "roberta"), &config);
/// ```
///
pub fn new(p: &nn::Path, config: &BertConfig) -> RobertaForTokenClassification {
let roberta = BertModel::<RobertaEmbeddings>::new(&(p / "roberta"), config);
let dropout = Dropout::new(config.hidden_dropout_prob);
let num_labels = config.id2label.as_ref().expect("num_labels not provided in configuration").len() as i64;
let classifier = nn::linear(p / "classifier", config.hidden_size, num_labels, Default::default());
let num_labels = config
.id2label
.as_ref()
.expect("num_labels not provided in configuration")
.len() as i64;
let classifier = nn::linear(
p / "classifier",
config.hidden_size,
num_labels,
Default::default(),
);
RobertaForTokenClassification { roberta, dropout, classifier }
RobertaForTokenClassification {
roberta,
dropout,
classifier,
}
}
/// Forward pass through the model
@ -507,47 +620,62 @@ impl RobertaForTokenClassification {
/// # Example
///
/// ```no_run
///# use rust_bert::bert::BertConfig;
///# use tch::{nn, Device, Tensor, no_grad};
///# use rust_bert::Config;
///# use std::path::Path;
///# use tch::kind::Kind::Int64;
/// # use rust_bert::bert::BertConfig;
/// # use tch::{nn, Device, Tensor, no_grad};
/// # use rust_bert::Config;
/// # use std::path::Path;
/// # use tch::kind::Kind::Int64;
/// use rust_bert::roberta::RobertaForTokenClassification;
///# let config_path = Path::new("path/to/config.json");
///# let vocab_path = Path::new("path/to/vocab.txt");
///# let device = Device::Cpu;
///# let vs = nn::VarStore::new(device);
///# let config = BertConfig::from_file(config_path);
///# let roberta_model = RobertaForTokenClassification::new(&vs.root(), &config);
/// let (batch_size, sequence_length) = (64, 128);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true);
///
/// let (token_labels, all_hidden_states, all_attentions) = no_grad(|| {
/// roberta_model
/// .forward_t(Some(input_tensor),
/// Some(mask),
/// Some(token_type_ids),
/// Some(position_ids),
/// None,
/// false)
/// });
/// # let config_path = Path::new("path/to/config.json");
/// # let vocab_path = Path::new("path/to/vocab.txt");
/// # let device = Device::Cpu;
/// # let vs = nn::VarStore::new(device);
/// # let config = BertConfig::from_file(config_path);
/// # let roberta_model = RobertaForTokenClassification::new(&vs.root(), &config);
/// let (batch_size, sequence_length) = (64, 128);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
/// .expand(&[batch_size, sequence_length], true);
///
/// let (token_labels, all_hidden_states, all_attentions) = no_grad(|| {
/// roberta_model.forward_t(
/// Some(input_tensor),
/// Some(mask),
/// Some(token_type_ids),
/// Some(position_ids),
/// None,
/// false,
/// )
/// });
/// ```
///
pub fn forward_t(&self,
input_ids: Option<Tensor>,
mask: Option<Tensor>,
token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
let (hidden_state, _, all_hidden_states, all_attentions) = self.roberta.forward_t(input_ids, mask, token_type_ids, position_ids,
input_embeds, &None, &None, train).unwrap();
pub fn forward_t(
&self,
input_ids: Option<Tensor>,
mask: Option<Tensor>,
token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool,
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
let (hidden_state, _, all_hidden_states, all_attentions) = self
.roberta
.forward_t(
input_ids,
mask,
token_type_ids,
position_ids,
input_embeds,
&None,
&None,
train,
)
.unwrap();
let sequence_output = hidden_state.apply_t(&self.dropout, train).apply(&self.classifier);
let sequence_output = hidden_state
.apply_t(&self.dropout, train)
.apply(&self.classifier);
(sequence_output, all_hidden_states, all_attentions)
}
}
@ -576,10 +704,10 @@ impl RobertaForQuestionAnswering {
///
/// ```no_run
/// use rust_bert::bert::BertConfig;
/// use tch::{nn, Device};
/// use rust_bert::roberta::RobertaForQuestionAnswering;
/// use rust_bert::Config;
/// use std::path::Path;
/// use rust_bert::roberta::RobertaForQuestionAnswering;
/// use tch::{nn, Device};
///
/// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu;
@ -587,13 +715,20 @@ impl RobertaForQuestionAnswering {
/// let config = BertConfig::from_file(config_path);
/// let roberta = RobertaForQuestionAnswering::new(&(&p.root() / "roberta"), &config);
/// ```
///
pub fn new(p: &nn::Path, config: &BertConfig) -> RobertaForQuestionAnswering {
let roberta = BertModel::<RobertaEmbeddings>::new(&(p / "roberta"), config);
let num_labels = 2;
let qa_outputs = nn::linear(p / "qa_outputs", config.hidden_size, num_labels, Default::default());
let qa_outputs = nn::linear(
p / "qa_outputs",
config.hidden_size,
num_labels,
Default::default(),
);
RobertaForQuestionAnswering { roberta, qa_outputs }
RobertaForQuestionAnswering {
roberta,
qa_outputs,
}
}
/// Forward pass through the model
@ -617,45 +752,58 @@ impl RobertaForQuestionAnswering {
/// # Example
///
/// ```no_run
///# use rust_bert::bert::BertConfig;
///# use tch::{nn, Device, Tensor, no_grad};
///# use rust_bert::Config;
///# use std::path::Path;
///# use tch::kind::Kind::Int64;
/// # use rust_bert::bert::BertConfig;
/// # use tch::{nn, Device, Tensor, no_grad};
/// # use rust_bert::Config;
/// # use std::path::Path;
/// # use tch::kind::Kind::Int64;
/// use rust_bert::roberta::RobertaForQuestionAnswering;
///# let config_path = Path::new("path/to/config.json");
///# let vocab_path = Path::new("path/to/vocab.txt");
///# let device = Device::Cpu;
///# let vs = nn::VarStore::new(device);
///# let config = BertConfig::from_file(config_path);
///# let roberta_model = RobertaForQuestionAnswering::new(&vs.root(), &config);
/// let (batch_size, sequence_length) = (64, 128);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true);
///
/// let (start_scores, end_scores, all_hidden_states, all_attentions) = no_grad(|| {
/// roberta_model
/// .forward_t(Some(input_tensor),
/// Some(mask),
/// Some(token_type_ids),
/// Some(position_ids),
/// None,
/// false)
/// });
/// # let config_path = Path::new("path/to/config.json");
/// # let vocab_path = Path::new("path/to/vocab.txt");
/// # let device = Device::Cpu;
/// # let vs = nn::VarStore::new(device);
/// # let config = BertConfig::from_file(config_path);
/// # let roberta_model = RobertaForQuestionAnswering::new(&vs.root(), &config);
/// let (batch_size, sequence_length) = (64, 128);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
/// .expand(&[batch_size, sequence_length], true);
///
/// let (start_scores, end_scores, all_hidden_states, all_attentions) = no_grad(|| {
/// roberta_model.forward_t(
/// Some(input_tensor),
/// Some(mask),
/// Some(token_type_ids),
/// Some(position_ids),
/// None,
/// false,
/// )
/// });
/// ```
///
pub fn forward_t(&self,
input_ids: Option<Tensor>,
mask: Option<Tensor>,
token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool) -> (Tensor, Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
let (hidden_state, _, all_hidden_states, all_attentions) = self.roberta.forward_t(input_ids, mask, token_type_ids, position_ids,
input_embeds, &None, &None, train).unwrap();
pub fn forward_t(
&self,
input_ids: Option<Tensor>,
mask: Option<Tensor>,
token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool,
) -> (Tensor, Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
let (hidden_state, _, all_hidden_states, all_attentions) = self
.roberta
.forward_t(
input_ids,
mask,
token_type_ids,
position_ids,
input_embeds,
&None,
&None,
train,
)
.unwrap();
let sequence_output = hidden_state.apply(&self.qa_outputs);
let logits = sequence_output.split(1, -1);
@ -665,4 +813,4 @@ impl RobertaForQuestionAnswering {
(start_logits, end_logits, all_hidden_states, all_attentions)
}
}
}

View File

@ -1,67 +1,77 @@
extern crate failure;
extern crate dirs;
extern crate failure;
use tch::{Device, nn, Tensor, no_grad};
use rust_tokenizers::{TruncationStrategy, Tokenizer, Vocab, AlbertTokenizer};
use rust_bert::albert::{
AlbertConfig, AlbertConfigResources, AlbertForMaskedLM, AlbertForMultipleChoice,
AlbertForQuestionAnswering, AlbertForSequenceClassification, AlbertForTokenClassification,
AlbertModelResources, AlbertVocabResources,
};
use rust_bert::resources::{download_resource, RemoteResource, Resource};
use rust_bert::Config;
use rust_bert::resources::{Resource, RemoteResource, download_resource};
use rust_bert::albert::{AlbertConfigResources, AlbertVocabResources, AlbertModelResources, AlbertConfig, AlbertForMaskedLM, AlbertForSequenceClassification, AlbertForMultipleChoice, AlbertForTokenClassification, AlbertForQuestionAnswering};
use rust_tokenizers::{AlbertTokenizer, Tokenizer, TruncationStrategy, Vocab};
use std::collections::HashMap;
use tch::{nn, no_grad, Device, Tensor};
#[test]
fn albert_masked_lm() -> failure::Fallible<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(AlbertConfigResources::ALBERT_BASE_V2));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(AlbertVocabResources::ALBERT_BASE_V2));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(AlbertModelResources::ALBERT_BASE_V2));
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
AlbertConfigResources::ALBERT_BASE_V2,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
AlbertVocabResources::ALBERT_BASE_V2,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
AlbertModelResources::ALBERT_BASE_V2,
));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let weights_path = download_resource(&weights_resource)?;
// Set-up masked LM model
// Set-up masked LM model
let device = Device::Cpu;
let mut vs = nn::VarStore::new(device);
let tokenizer: AlbertTokenizer = AlbertTokenizer::from_file(vocab_path.to_str().unwrap(), true, false);
let tokenizer: AlbertTokenizer =
AlbertTokenizer::from_file(vocab_path.to_str().unwrap(), true, false);
let config = AlbertConfig::from_file(config_path);
let albert_model = AlbertForMaskedLM::new(&vs.root(), &config);
vs.load(weights_path)?;
// Define input
let input = ["Looks like one [MASK] is missing", "It\'s like comparing [MASK] to apples"];
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
let tokenized_input = tokenized_input.
iter().
map(|input| input.token_ids.clone()).
map(|mut input| {
// Define input
let input = [
"Looks like one [MASK] is missing",
"It\'s like comparing [MASK] to apples",
];
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let tokenized_input = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
}).
map(|input|
Tensor::of_slice(&(input))).
collect::<Vec<_>>();
})
.map(|input| Tensor::of_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let (output, _, _) = no_grad(|| {
albert_model
.forward_t(Some(input_tensor),
None,
None,
None,
None,
false)
});
// Forward pass
let (output, _, _) =
no_grad(|| albert_model.forward_t(Some(input_tensor), None, None, None, None, false));
// Print masked tokens
// Print masked tokens
let index_1 = output.get(0).get(4).argmax(0, false);
let index_2 = output.get(1).get(6).argmax(0, false);
let word_1 = tokenizer.vocab().id_to_token(&index_1.int64_value(&[]));
let word_2 = tokenizer.vocab().id_to_token(&index_2.int64_value(&[]));
assert_eq!("▁them", word_1); // Outputs "_them" : "Looks like one [them] is missing (? this is identical with the original implementation)"
assert_eq!("▁grapes", word_2);// Outputs "grapes" : "It\'s like comparing [grapes] to apples"
assert_eq!("▁grapes", word_2); // Outputs "grapes" : "It\'s like comparing [grapes] to apples"
assert!((output.double_value(&[0, 0, 0]) - 4.6143).abs() < 1e-4);
Ok(())
}
@ -69,15 +79,20 @@ fn albert_masked_lm() -> failure::Fallible<()> {
#[test]
fn albert_for_sequence_classification() -> failure::Fallible<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(AlbertConfigResources::ALBERT_BASE_V2));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(AlbertVocabResources::ALBERT_BASE_V2));
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
AlbertConfigResources::ALBERT_BASE_V2,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
AlbertVocabResources::ALBERT_BASE_V2,
));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
// Set-up model
// Set-up model
let device = Device::Cpu;
let vs = nn::VarStore::new(device);
let tokenizer: AlbertTokenizer = AlbertTokenizer::from_file(vocab_path.to_str().unwrap(), true, false);
let tokenizer: AlbertTokenizer =
AlbertTokenizer::from_file(vocab_path.to_str().unwrap(), true, false);
let mut config = AlbertConfig::from_file(config_path);
let mut dummy_label_mapping = HashMap::new();
dummy_label_mapping.insert(0, String::from("Positive"));
@ -88,37 +103,42 @@ fn albert_for_sequence_classification() -> failure::Fallible<()> {
config.output_hidden_states = Some(true);
let albert_model = AlbertForSequenceClassification::new(&vs.root(), &config);
// Define input
let input = ["Looks like one thing is missing", "It\'s like comparing oranges to apples"];
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
let tokenized_input = tokenized_input.
iter().
map(|input| input.token_ids.clone()).
map(|mut input| {
// Define input
let input = [
"Looks like one thing is missing",
"It\'s like comparing oranges to apples",
];
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let tokenized_input = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
}).
map(|input|
Tensor::of_slice(&(input))).
collect::<Vec<_>>();
})
.map(|input| Tensor::of_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let (output, all_hidden_states, all_attentions) = no_grad(|| {
albert_model
.forward_t(Some(input_tensor),
None,
None,
None,
None,
false)
});
// Forward pass
let (output, all_hidden_states, all_attentions) =
no_grad(|| albert_model.forward_t(Some(input_tensor), None, None, None, None, false));
assert_eq!(output.size(), &[2, 3]);
assert_eq!(config.num_hidden_layers as usize, all_hidden_states.unwrap().len());
assert_eq!(config.num_hidden_layers as usize, all_attentions.unwrap().len());
assert_eq!(
config.num_hidden_layers as usize,
all_hidden_states.unwrap().len()
);
assert_eq!(
config.num_hidden_layers as usize,
all_attentions.unwrap().len()
);
Ok(())
}
@ -126,50 +146,66 @@ fn albert_for_sequence_classification() -> failure::Fallible<()> {
#[test]
fn albert_for_multiple_choice() -> failure::Fallible<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(AlbertConfigResources::ALBERT_BASE_V2));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(AlbertVocabResources::ALBERT_BASE_V2));
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
AlbertConfigResources::ALBERT_BASE_V2,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
AlbertVocabResources::ALBERT_BASE_V2,
));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
// Set-up model
// Set-up model
let device = Device::Cpu;
let vs = nn::VarStore::new(device);
let tokenizer: AlbertTokenizer = AlbertTokenizer::from_file(vocab_path.to_str().unwrap(), true, false);
let tokenizer: AlbertTokenizer =
AlbertTokenizer::from_file(vocab_path.to_str().unwrap(), true, false);
let mut config = AlbertConfig::from_file(config_path);
config.output_attentions = Some(true);
config.output_hidden_states = Some(true);
let albert_model = AlbertForMultipleChoice::new(&vs.root(), &config);
// Define input
let input = ["Looks like one thing is missing", "It\'s like comparing oranges to apples"];
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
let tokenized_input = tokenized_input.
iter().
map(|input| input.token_ids.clone()).
map(|mut input| {
// Define input
let input = [
"Looks like one thing is missing",
"It\'s like comparing oranges to apples",
];
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let tokenized_input = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
}).
map(|input|
Tensor::of_slice(&(input))).
collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device).unsqueeze(0);
})
.map(|input| Tensor::of_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0)
.to(device)
.unsqueeze(0);
// Forward pass
// Forward pass
let (output, all_hidden_states, all_attentions) = no_grad(|| {
albert_model
.forward_t(Some(input_tensor),
None,
None,
None,
None,
false).unwrap()
.forward_t(Some(input_tensor), None, None, None, None, false)
.unwrap()
});
assert_eq!(output.size(), &[1, 2]);
assert_eq!(config.num_hidden_layers as usize, all_hidden_states.unwrap().len());
assert_eq!(config.num_hidden_layers as usize, all_attentions.unwrap().len());
assert_eq!(
config.num_hidden_layers as usize,
all_hidden_states.unwrap().len()
);
assert_eq!(
config.num_hidden_layers as usize,
all_attentions.unwrap().len()
);
Ok(())
}
@ -177,15 +213,20 @@ fn albert_for_multiple_choice() -> failure::Fallible<()> {
#[test]
fn albert_for_token_classification() -> failure::Fallible<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(AlbertConfigResources::ALBERT_BASE_V2));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(AlbertVocabResources::ALBERT_BASE_V2));
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
AlbertConfigResources::ALBERT_BASE_V2,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
AlbertVocabResources::ALBERT_BASE_V2,
));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
// Set-up model
// Set-up model
let device = Device::Cpu;
let vs = nn::VarStore::new(device);
let tokenizer: AlbertTokenizer = AlbertTokenizer::from_file(vocab_path.to_str().unwrap(), true, false);
let tokenizer: AlbertTokenizer =
AlbertTokenizer::from_file(vocab_path.to_str().unwrap(), true, false);
let mut config = AlbertConfig::from_file(config_path);
let mut dummy_label_mapping = HashMap::new();
dummy_label_mapping.insert(0, String::from("O"));
@ -197,37 +238,42 @@ fn albert_for_token_classification() -> failure::Fallible<()> {
config.output_hidden_states = Some(true);
let bert_model = AlbertForTokenClassification::new(&vs.root(), &config);
// Define input
let input = ["Looks like one thing is missing", "It\'s like comparing oranges to apples"];
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
let tokenized_input = tokenized_input.
iter().
map(|input| input.token_ids.clone()).
map(|mut input| {
// Define input
let input = [
"Looks like one thing is missing",
"It\'s like comparing oranges to apples",
];
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let tokenized_input = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
}).
map(|input|
Tensor::of_slice(&(input))).
collect::<Vec<_>>();
})
.map(|input| Tensor::of_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let (output, all_hidden_states, all_attentions) = no_grad(|| {
bert_model
.forward_t(Some(input_tensor),
None,
None,
None,
None,
false)
});
// Forward pass
let (output, all_hidden_states, all_attentions) =
no_grad(|| bert_model.forward_t(Some(input_tensor), None, None, None, None, false));
assert_eq!(output.size(), &[2, 12, 4]);
assert_eq!(config.num_hidden_layers as usize, all_hidden_states.unwrap().len());
assert_eq!(config.num_hidden_layers as usize, all_attentions.unwrap().len());
assert_eq!(
config.num_hidden_layers as usize,
all_hidden_states.unwrap().len()
);
assert_eq!(
config.num_hidden_layers as usize,
all_attentions.unwrap().len()
);
Ok(())
}
@ -235,51 +281,62 @@ fn albert_for_token_classification() -> failure::Fallible<()> {
#[test]
fn albert_for_question_answering() -> failure::Fallible<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(AlbertConfigResources::ALBERT_BASE_V2));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(AlbertVocabResources::ALBERT_BASE_V2));
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
AlbertConfigResources::ALBERT_BASE_V2,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
AlbertVocabResources::ALBERT_BASE_V2,
));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
// Set-up model
// Set-up model
let device = Device::Cpu;
let vs = nn::VarStore::new(device);
let tokenizer: AlbertTokenizer = AlbertTokenizer::from_file(vocab_path.to_str().unwrap(), true, false);
let tokenizer: AlbertTokenizer =
AlbertTokenizer::from_file(vocab_path.to_str().unwrap(), true, false);
let mut config = AlbertConfig::from_file(config_path);
config.output_attentions = Some(true);
config.output_hidden_states = Some(true);
let albert_model = AlbertForQuestionAnswering::new(&vs.root(), &config);
// Define input
let input = ["Looks like one thing is missing", "It\'s like comparing oranges to apples"];
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
let tokenized_input = tokenized_input.
iter().
map(|input| input.token_ids.clone()).
map(|mut input| {
// Define input
let input = [
"Looks like one thing is missing",
"It\'s like comparing oranges to apples",
];
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let tokenized_input = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
}).
map(|input|
Tensor::of_slice(&(input))).
collect::<Vec<_>>();
})
.map(|input| Tensor::of_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let (start_scores, end_scores, all_hidden_states, all_attentions) = no_grad(|| {
albert_model
.forward_t(Some(input_tensor),
None,
None,
None,
None,
false)
});
// Forward pass
let (start_scores, end_scores, all_hidden_states, all_attentions) =
no_grad(|| albert_model.forward_t(Some(input_tensor), None, None, None, None, false));
assert_eq!(start_scores.size(), &[2, 12]);
assert_eq!(end_scores.size(), &[2, 12]);
assert_eq!(config.num_hidden_layers as usize, all_hidden_states.unwrap().len());
assert_eq!(config.num_hidden_layers as usize, all_attentions.unwrap().len());
assert_eq!(
config.num_hidden_layers as usize,
all_hidden_states.unwrap().len()
);
assert_eq!(
config.num_hidden_layers as usize,
all_attentions.unwrap().len()
);
Ok(())
}

View File

@ -1,56 +1,65 @@
use tch::{Device, nn, Tensor};
use rust_tokenizers::{TruncationStrategy, Tokenizer, RobertaTokenizer};
use rust_bert::Config;
use rust_bert::bart::{BartConfig, BartConfigResources, BartVocabResources, BartModelResources, BartMergesResources, BartModel};
use rust_bert::bart::{
BartConfig, BartConfigResources, BartMergesResources, BartModel, BartModelResources,
BartVocabResources,
};
use rust_bert::pipelines::summarization::{SummarizationConfig, SummarizationModel};
use rust_bert::resources::{Resource, RemoteResource, download_resource};
use rust_bert::resources::{download_resource, RemoteResource, Resource};
use rust_bert::Config;
use rust_tokenizers::{RobertaTokenizer, Tokenizer, TruncationStrategy};
use tch::{nn, Device, Tensor};
#[test]
#[cfg_attr(not(feature = "all-tests"), ignore)]
fn bart_lm_model() -> failure::Fallible<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(BartConfigResources::BART));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(BartVocabResources::BART));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(BartMergesResources::BART));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(BartModelResources::BART));
let config_resource =
Resource::Remote(RemoteResource::from_pretrained(BartConfigResources::BART));
let vocab_resource =
Resource::Remote(RemoteResource::from_pretrained(BartVocabResources::BART));
let merges_resource =
Resource::Remote(RemoteResource::from_pretrained(BartMergesResources::BART));
let weights_resource =
Resource::Remote(RemoteResource::from_pretrained(BartModelResources::BART));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let merges_path = download_resource(&merges_resource)?;
let weights_path = download_resource(&weights_resource)?;
// Set-up masked LM model
// Set-up masked LM model
let device = Device::Cpu;
let mut vs = nn::VarStore::new(device);
let tokenizer: RobertaTokenizer = RobertaTokenizer::from_file(vocab_path.to_str().unwrap(), merges_path.to_str().unwrap(), false);
let tokenizer: RobertaTokenizer = RobertaTokenizer::from_file(
vocab_path.to_str().unwrap(),
merges_path.to_str().unwrap(),
false,
);
let config = BartConfig::from_file(config_path);
let bart_model = BartModel::new(&vs.root(), &config, false);
vs.load(weights_path)?;
// Define input
// Define input
let input = ["One two three four"];
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
let tokenized_input = tokenized_input.
iter().
map(|input| input.token_ids.clone()).
map(|mut input| {
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let tokenized_input = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
}).
map(|input|
Tensor::of_slice(&(input))).
collect::<Vec<_>>();
})
.map(|input| Tensor::of_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let (output, encoder_outputs, _, _, _, _, _) = bart_model.forward_t(
Some(&input_tensor),
None,
None,
None,
None,
None,
false);
// Forward pass
let (output, encoder_outputs, _, _, _, _, _) =
bart_model.forward_t(Some(&input_tensor), None, None, None, None, None, false);
assert_eq!(output.size(), vec!(1, 6, 1024));
assert_eq!(encoder_outputs.size(), vec!(1, 6, 1024));
@ -58,12 +67,10 @@ fn bart_lm_model() -> failure::Fallible<()> {
Ok(())
}
#[test]
#[cfg_attr(not(feature = "all-tests"), ignore)]
fn bart_summarization_greedy() -> failure::Fallible<()> {
// Set-up masked LM model
// Set-up masked LM model
let summarization_config = SummarizationConfig {
num_beams: 1,
device: Device::Cpu,
@ -93,7 +100,7 @@ on K2-18b lasts 33 Earth days. According to The Guardian, astronomers were optim
telescope scheduled for launch in 2021 and the European Space Agency's 2028 ARIEL program, could reveal more \
about exoplanets like K2-18b."];
// Credits: WikiNews, CC BY 2.5 license (https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b)
// Credits: WikiNews, CC BY 2.5 license (https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b)
let output = model.summarize(&input);
assert_eq!(output.len(), 1);
@ -107,8 +114,7 @@ about exoplanets like K2-18b."];
#[test]
#[cfg_attr(not(feature = "all-tests"), ignore)]
fn bart_summarization_beam_search() -> failure::Fallible<()> {
// Set-up masked LM model
// Set-up masked LM model
let summarization_config = SummarizationConfig {
num_beams: 3,
device: Device::Cpu,
@ -138,7 +144,7 @@ on K2-18b lasts 33 Earth days. According to The Guardian, astronomers were optim
telescope scheduled for launch in 2021 and the European Space Agency's 2028 ARIEL program, could reveal more \
about exoplanets like K2-18b."];
// Credits: WikiNews, CC BY 2.5 license (https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b)
// Credits: WikiNews, CC BY 2.5 license (https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b)
let output = model.summarize(&input);
assert_eq!(output.len(), 1);
@ -148,4 +154,4 @@ about exoplanets like K2-18b."];
star as the planet passed between it and Earth.");
Ok(())
}
}

View File

@ -1,27 +1,32 @@
extern crate failure;
extern crate dirs;
extern crate failure;
use tch::{Device, nn, Tensor, no_grad};
use rust_tokenizers::{BertTokenizer, TruncationStrategy, Tokenizer, Vocab};
use rust_bert::Config;
use rust_bert::bert::{BertConfig, BertForMaskedLM, BertForSequenceClassification, BertForMultipleChoice, BertForTokenClassification, BertForQuestionAnswering,
BertConfigResources, BertVocabResources, BertModelResources};
use rust_bert::bert::{
BertConfig, BertConfigResources, BertForMaskedLM, BertForMultipleChoice,
BertForQuestionAnswering, BertForSequenceClassification, BertForTokenClassification,
BertModelResources, BertVocabResources,
};
use rust_bert::pipelines::ner::NERModel;
use rust_bert::resources::{Resource, RemoteResource, download_resource};
use rust_bert::resources::{download_resource, RemoteResource, Resource};
use rust_bert::Config;
use rust_tokenizers::{BertTokenizer, Tokenizer, TruncationStrategy, Vocab};
use std::collections::HashMap;
use tch::{nn, no_grad, Device, Tensor};
#[test]
fn bert_masked_lm() -> failure::Fallible<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT));
let config_resource =
Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT));
let vocab_resource =
Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
let weights_resource =
Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let weights_path = download_resource(&weights_resource)?;
// Set-up masked LM model
// Set-up masked LM model
let device = Device::Cpu;
let mut vs = nn::VarStore::new(device);
let tokenizer: BertTokenizer = BertTokenizer::from_file(vocab_path.to_str().unwrap(), true);
@ -29,50 +34,58 @@ fn bert_masked_lm() -> failure::Fallible<()> {
let bert_model = BertForMaskedLM::new(&vs.root(), &config);
vs.load(weights_path)?;
// Define input
let input = ["Looks like one thing is missing", "It\'s like comparing oranges to apples"];
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
let mut tokenized_input = tokenized_input.
iter().
map(|input| input.token_ids.clone()).
map(|mut input| {
// Define input
let input = [
"Looks like one thing is missing",
"It\'s like comparing oranges to apples",
];
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let mut tokenized_input = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
}).
collect::<Vec<_>>();
})
.collect::<Vec<_>>();
// Masking the token [thing] of sentence 1 and [oranges] of sentence 2
// Masking the token [thing] of sentence 1 and [oranges] of sentence 2
tokenized_input[0][4] = 103;
tokenized_input[1][6] = 103;
let tokenized_input = tokenized_input.
iter().
map(|input|
Tensor::of_slice(&(input))).
collect::<Vec<_>>();
let tokenized_input = tokenized_input
.iter()
.map(|input| Tensor::of_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
// Forward pass
let (output, _, _) = no_grad(|| {
bert_model
.forward_t(Some(input_tensor),
None,
None,
None,
None,
&None,
&None,
false)
bert_model.forward_t(
Some(input_tensor),
None,
None,
None,
None,
&None,
&None,
false,
)
});
// Print masked tokens
// Print masked tokens
let index_1 = output.get(0).get(4).argmax(0, false);
let index_2 = output.get(1).get(6).argmax(0, false);
let word_1 = tokenizer.vocab().id_to_token(&index_1.int64_value(&[]));
let word_2 = tokenizer.vocab().id_to_token(&index_2.int64_value(&[]));
assert_eq!("person", word_1); // Outputs "person" : "Looks like one [person] is missing"
assert_eq!("orange", word_2);// Outputs "pear" : "It\'s like comparing [pear] to apples"
assert_eq!("orange", word_2); // Outputs "pear" : "It\'s like comparing [pear] to apples"
Ok(())
}
@ -80,12 +93,14 @@ fn bert_masked_lm() -> failure::Fallible<()> {
#[test]
fn bert_for_sequence_classification() -> failure::Fallible<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
let config_resource =
Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT));
let vocab_resource =
Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
// Set-up model
// Set-up model
let device = Device::Cpu;
let vs = nn::VarStore::new(device);
let tokenizer: BertTokenizer = BertTokenizer::from_file(vocab_path.to_str().unwrap(), true);
@ -99,37 +114,42 @@ fn bert_for_sequence_classification() -> failure::Fallible<()> {
config.output_hidden_states = Some(true);
let bert_model = BertForSequenceClassification::new(&vs.root(), &config);
// Define input
let input = ["Looks like one thing is missing", "It\'s like comparing oranges to apples"];
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
let tokenized_input = tokenized_input.
iter().
map(|input| input.token_ids.clone()).
map(|mut input| {
// Define input
let input = [
"Looks like one thing is missing",
"It\'s like comparing oranges to apples",
];
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let tokenized_input = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
}).
map(|input|
Tensor::of_slice(&(input))).
collect::<Vec<_>>();
})
.map(|input| Tensor::of_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let (output, all_hidden_states, all_attentions) = no_grad(|| {
bert_model
.forward_t(Some(input_tensor),
None,
None,
None,
None,
false)
});
// Forward pass
let (output, all_hidden_states, all_attentions) =
no_grad(|| bert_model.forward_t(Some(input_tensor), None, None, None, None, false));
assert_eq!(output.size(), &[2, 3]);
assert_eq!(config.num_hidden_layers as usize, all_hidden_states.unwrap().len());
assert_eq!(config.num_hidden_layers as usize, all_attentions.unwrap().len());
assert_eq!(
config.num_hidden_layers as usize,
all_hidden_states.unwrap().len()
);
assert_eq!(
config.num_hidden_layers as usize,
all_attentions.unwrap().len()
);
Ok(())
}
@ -137,12 +157,14 @@ fn bert_for_sequence_classification() -> failure::Fallible<()> {
#[test]
fn bert_for_multiple_choice() -> failure::Fallible<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
let config_resource =
Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT));
let vocab_resource =
Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
// Set-up model
// Set-up model
let device = Device::Cpu;
let vs = nn::VarStore::new(device);
let tokenizer: BertTokenizer = BertTokenizer::from_file(vocab_path.to_str().unwrap(), true);
@ -151,35 +173,44 @@ fn bert_for_multiple_choice() -> failure::Fallible<()> {
config.output_hidden_states = Some(true);
let bert_model = BertForMultipleChoice::new(&vs.root(), &config);
// Define input
let input = ["Looks like one thing is missing", "It\'s like comparing oranges to apples"];
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
let tokenized_input = tokenized_input.
iter().
map(|input| input.token_ids.clone()).
map(|mut input| {
// Define input
let input = [
"Looks like one thing is missing",
"It\'s like comparing oranges to apples",
];
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let tokenized_input = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
}).
map(|input|
Tensor::of_slice(&(input))).
collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device).unsqueeze(0);
})
.map(|input| Tensor::of_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0)
.to(device)
.unsqueeze(0);
// Forward pass
let (output, all_hidden_states, all_attentions) = no_grad(|| {
bert_model
.forward_t(input_tensor,
None,
None,
None,
false)
});
// Forward pass
let (output, all_hidden_states, all_attentions) =
no_grad(|| bert_model.forward_t(input_tensor, None, None, None, false));
assert_eq!(output.size(), &[1, 2]);
assert_eq!(config.num_hidden_layers as usize, all_hidden_states.unwrap().len());
assert_eq!(config.num_hidden_layers as usize, all_attentions.unwrap().len());
assert_eq!(
config.num_hidden_layers as usize,
all_hidden_states.unwrap().len()
);
assert_eq!(
config.num_hidden_layers as usize,
all_attentions.unwrap().len()
);
Ok(())
}
@ -187,12 +218,14 @@ fn bert_for_multiple_choice() -> failure::Fallible<()> {
#[test]
fn bert_for_token_classification() -> failure::Fallible<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
let config_resource =
Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT));
let vocab_resource =
Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
// Set-up model
// Set-up model
let device = Device::Cpu;
let vs = nn::VarStore::new(device);
let tokenizer: BertTokenizer = BertTokenizer::from_file(vocab_path.to_str().unwrap(), true);
@ -207,37 +240,42 @@ fn bert_for_token_classification() -> failure::Fallible<()> {
config.output_hidden_states = Some(true);
let bert_model = BertForTokenClassification::new(&vs.root(), &config);
// Define input
let input = ["Looks like one thing is missing", "It\'s like comparing oranges to apples"];
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
let tokenized_input = tokenized_input.
iter().
map(|input| input.token_ids.clone()).
map(|mut input| {
// Define input
let input = [
"Looks like one thing is missing",
"It\'s like comparing oranges to apples",
];
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let tokenized_input = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
}).
map(|input|
Tensor::of_slice(&(input))).
collect::<Vec<_>>();
})
.map(|input| Tensor::of_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let (output, all_hidden_states, all_attentions) = no_grad(|| {
bert_model
.forward_t(Some(input_tensor),
None,
None,
None,
None,
false)
});
// Forward pass
let (output, all_hidden_states, all_attentions) =
no_grad(|| bert_model.forward_t(Some(input_tensor), None, None, None, None, false));
assert_eq!(output.size(), &[2, 11, 4]);
assert_eq!(config.num_hidden_layers as usize, all_hidden_states.unwrap().len());
assert_eq!(config.num_hidden_layers as usize, all_attentions.unwrap().len());
assert_eq!(
config.num_hidden_layers as usize,
all_hidden_states.unwrap().len()
);
assert_eq!(
config.num_hidden_layers as usize,
all_attentions.unwrap().len()
);
Ok(())
}
@ -245,12 +283,14 @@ fn bert_for_token_classification() -> failure::Fallible<()> {
#[test]
fn bert_for_question_answering() -> failure::Fallible<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
let config_resource =
Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT));
let vocab_resource =
Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
// Set-up model
// Set-up model
let device = Device::Cpu;
let vs = nn::VarStore::new(device);
let tokenizer: BertTokenizer = BertTokenizer::from_file(vocab_path.to_str().unwrap(), true);
@ -259,53 +299,59 @@ fn bert_for_question_answering() -> failure::Fallible<()> {
config.output_hidden_states = Some(true);
let bert_model = BertForQuestionAnswering::new(&vs.root(), &config);
// Define input
let input = ["Looks like one thing is missing", "It\'s like comparing oranges to apples"];
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
let tokenized_input = tokenized_input.
iter().
map(|input| input.token_ids.clone()).
map(|mut input| {
// Define input
let input = [
"Looks like one thing is missing",
"It\'s like comparing oranges to apples",
];
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let tokenized_input = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
}).
map(|input|
Tensor::of_slice(&(input))).
collect::<Vec<_>>();
})
.map(|input| Tensor::of_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let (start_scores, end_scores, all_hidden_states, all_attentions) = no_grad(|| {
bert_model
.forward_t(Some(input_tensor),
None,
None,
None,
None,
false)
});
// Forward pass
let (start_scores, end_scores, all_hidden_states, all_attentions) =
no_grad(|| bert_model.forward_t(Some(input_tensor), None, None, None, None, false));
assert_eq!(start_scores.size(), &[2, 11]);
assert_eq!(end_scores.size(), &[2, 11]);
assert_eq!(config.num_hidden_layers as usize, all_hidden_states.unwrap().len());
assert_eq!(config.num_hidden_layers as usize, all_attentions.unwrap().len());
assert_eq!(
config.num_hidden_layers as usize,
all_hidden_states.unwrap().len()
);
assert_eq!(
config.num_hidden_layers as usize,
all_attentions.unwrap().len()
);
Ok(())
}
#[test]
fn bert_pre_trained_ner() -> failure::Fallible<()> {
// Set-up model
// Set-up model
let ner_model = NERModel::new(Default::default())?;
// Define input
// Define input
let input = [
"My name is Amy. I live in Paris.",
"Paris is a city in France."
"Paris is a city in France.",
];
// Run model
// Run model
let output = ner_model.predict(&input);
assert_eq!(output.len(), 4);
@ -327,4 +373,4 @@ fn bert_pre_trained_ner() -> failure::Fallible<()> {
assert_eq!(output[3].label, "I-LOC");
Ok(())
}
}

View File

@ -1,22 +1,26 @@
use tch::{Device, Tensor, nn, no_grad};
use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::{Tokenizer, TruncationStrategy};
use rust_tokenizers::bert_tokenizer::BertTokenizer;
use rust_tokenizers::preprocessing::vocab::base_vocab::Vocab;
use rust_bert::Config;
use rust_bert::distilbert::{DistilBertConfig, DistilBertModelMaskedLM, DistilBertForQuestionAnswering, DistilBertForTokenClassification, DistilBertModelResources, DistilBertConfigResources, DistilBertVocabResources};
use rust_bert::distilbert::{
DistilBertConfig, DistilBertConfigResources, DistilBertForQuestionAnswering,
DistilBertForTokenClassification, DistilBertModelMaskedLM, DistilBertModelResources,
DistilBertVocabResources,
};
use rust_bert::pipelines::question_answering::{QaInput, QuestionAnsweringModel};
use rust_bert::pipelines::sentiment::{SentimentModel, SentimentPolarity};
use rust_bert::pipelines::question_answering::{QuestionAnsweringModel, QaInput};
use rust_bert::resources::{Resource, RemoteResource, download_resource};
use rust_bert::resources::{download_resource, RemoteResource, Resource};
use rust_bert::Config;
use rust_tokenizers::bert_tokenizer::BertTokenizer;
use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::{Tokenizer, TruncationStrategy};
use rust_tokenizers::preprocessing::vocab::base_vocab::Vocab;
use std::collections::HashMap;
use tch::{nn, no_grad, Device, Tensor};
extern crate failure;
#[test]
fn distilbert_sentiment_classifier() -> failure::Fallible<()> {
// Set-up classifier
// Set-up classifier
let sentiment_classifier = SentimentModel::new(Default::default())?;
// Get sentiments
// Get sentiments
let input = [
"Probably my all-time favorite movie, a story of selflessness, sacrifice and dedication to a noble cause, but it's not preachy or boring.",
"This film tried to be too many things all at once: stinging political satire, Hollywood blockbuster, sappy romantic comedy, family values promo...",
@ -36,18 +40,23 @@ fn distilbert_sentiment_classifier() -> failure::Fallible<()> {
Ok(())
}
#[test]
fn distilbert_masked_lm() -> failure::Fallible<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertConfigResources::DISTIL_BERT));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertVocabResources::DISTIL_BERT));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertModelResources::DISTIL_BERT));
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
DistilBertConfigResources::DISTIL_BERT,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
DistilBertVocabResources::DISTIL_BERT,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
DistilBertModelResources::DISTIL_BERT,
));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let weights_path = download_resource(&weights_resource)?;
// Set-up masked LM model
// Set-up masked LM model
let device = Device::cuda_if_available();
let mut vs = nn::VarStore::new(device);
let tokenizer: BertTokenizer = BertTokenizer::from_file(vocab_path.to_str().unwrap(), true);
@ -55,59 +64,68 @@ fn distilbert_masked_lm() -> failure::Fallible<()> {
let distil_bert_model = DistilBertModelMaskedLM::new(&vs.root(), &config);
vs.load(weights_path)?;
// Define input
let input = ["Looks like one thing is missing", "It\'s like comparing oranges to apples"];
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
let mut tokenized_input = tokenized_input.
iter().
map(|input| input.token_ids.clone()).
map(|mut input| {
// Define input
let input = [
"Looks like one thing is missing",
"It\'s like comparing oranges to apples",
];
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let mut tokenized_input = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
}).
collect::<Vec<_>>();
})
.collect::<Vec<_>>();
// Masking the token [thing] of sentence 1 and [oranges] of sentence 2
// Masking the token [thing] of sentence 1 and [oranges] of sentence 2
tokenized_input[0][4] = 103;
tokenized_input[1][6] = 103;
let tokenized_input = tokenized_input.
iter().
map(|input|
Tensor::of_slice(&(input))).
collect::<Vec<_>>();
let tokenized_input = tokenized_input
.iter()
.map(|input| Tensor::of_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
// Forward pass
let (output, _, _) = no_grad(|| {
distil_bert_model
.forward_t(Some(input_tensor), None, None, false)
.unwrap()
});
// Print masked tokens
// Print masked tokens
let index_1 = output.get(0).get(4).argmax(0, false);
let index_2 = output.get(1).get(6).argmax(0, false);
let word_1 = tokenizer.vocab().id_to_token(&index_1.int64_value(&[]));
let word_2 = tokenizer.vocab().id_to_token(&index_2.int64_value(&[]));
assert_eq!("person", word_1); // Outputs "person" : "Looks like one [person] is missing"
assert_eq!("pear", word_2);// Outputs "pear" : "It\'s like comparing [pear] to apples"
assert_eq!("pear", word_2); // Outputs "pear" : "It\'s like comparing [pear] to apples"
Ok(())
}
#[test]
fn distilbert_for_question_answering() -> failure::Fallible<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertConfigResources::DISTIL_BERT_SQUAD));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertVocabResources::DISTIL_BERT_SQUAD));
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
DistilBertConfigResources::DISTIL_BERT_SQUAD,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
DistilBertVocabResources::DISTIL_BERT_SQUAD,
));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
// Set-up masked LM model
// Set-up masked LM model
let device = Device::cuda_if_available();
let vs = nn::VarStore::new(device);
let tokenizer: BertTokenizer = BertTokenizer::from_file(vocab_path.to_str().unwrap(), true);
@ -116,23 +134,30 @@ fn distilbert_for_question_answering() -> failure::Fallible<()> {
config.output_hidden_states = Some(true);
let distil_bert_model = DistilBertForQuestionAnswering::new(&vs.root(), &config);
// Define input
let input = ["Looks like one thing is missing", "It\'s like comparing oranges to apples"];
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
let tokenized_input = tokenized_input.
iter().
map(|input| input.token_ids.clone()).
map(|mut input| {
// Define input
let input = [
"Looks like one thing is missing",
"It\'s like comparing oranges to apples",
];
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let tokenized_input = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
}).
map(|input|
Tensor::of_slice(&(input))).
collect::<Vec<_>>();
})
.map(|input| Tensor::of_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
// Forward pass
let (start_scores, end_scores, all_hidden_states, all_attentions) = no_grad(|| {
distil_bert_model
.forward_t(Some(input_tensor), None, None, false)
@ -149,14 +174,17 @@ fn distilbert_for_question_answering() -> failure::Fallible<()> {
#[test]
fn distilbert_for_token_classification() -> failure::Fallible<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertConfigResources::DISTIL_BERT));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertVocabResources::DISTIL_BERT));
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
DistilBertConfigResources::DISTIL_BERT,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
DistilBertVocabResources::DISTIL_BERT,
));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
// Set-up masked LM model
// Set-up masked LM model
let device = Device::cuda_if_available();
let vs = nn::VarStore::new(device);
let tokenizer: BertTokenizer = BertTokenizer::from_file(vocab_path.to_str().unwrap(), true);
@ -171,23 +199,30 @@ fn distilbert_for_token_classification() -> failure::Fallible<()> {
config.id2label = Some(dummy_label_mapping);
let distil_bert_model = DistilBertForTokenClassification::new(&vs.root(), &config);
// Define input
let input = ["Looks like one thing is missing", "It\'s like comparing oranges to apples"];
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
let tokenized_input = tokenized_input.
iter().
map(|input| input.token_ids.clone()).
map(|mut input| {
// Define input
let input = [
"Looks like one thing is missing",
"It\'s like comparing oranges to apples",
];
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let tokenized_input = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
}).
map(|input|
Tensor::of_slice(&(input))).
collect::<Vec<_>>();
})
.map(|input| Tensor::of_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
// Forward pass
let (output, all_hidden_states, all_attentions) = no_grad(|| {
distil_bert_model
.forward_t(Some(input_tensor), None, None, false)
@ -203,15 +238,15 @@ fn distilbert_for_token_classification() -> failure::Fallible<()> {
#[test]
fn distilbert_question_answering() -> failure::Fallible<()> {
// Set-up question answering model
// Set-up question answering model
let qa_model = QuestionAnsweringModel::new(Default::default())?;
// Define input
// Define input
let question = String::from("Where does Amy live ?");
let context = String::from("Amy lives in Amsterdam");
let qa_input = QaInput { question, context };
let answers = qa_model.predict(&vec!(qa_input), 1, 32);
let answers = qa_model.predict(&vec![qa_input], 1, 32);
assert_eq!(answers.len(), 1 as usize);
assert_eq!(answers[0].len(), 1 as usize);
@ -221,4 +256,4 @@ fn distilbert_question_answering() -> failure::Fallible<()> {
assert_eq!(answers[0][0].answer, "Amsterdam");
Ok(())
}
}

View File

@ -1,73 +1,100 @@
use tch::{Device, nn, Tensor};
use rust_tokenizers::{Gpt2Tokenizer, TruncationStrategy, Tokenizer};
use rust_bert::gpt2::{
GPT2LMHeadModel, Gpt2Config, Gpt2ConfigResources, Gpt2MergesResources, Gpt2ModelResources,
Gpt2VocabResources,
};
use rust_bert::pipelines::generation::{Cache, LMHeadModel};
use rust_bert::resources::{download_resource, RemoteResource, Resource};
use rust_bert::Config;
use rust_bert::gpt2::{Gpt2Config, GPT2LMHeadModel, Gpt2ConfigResources, Gpt2VocabResources, Gpt2MergesResources, Gpt2ModelResources};
use rust_bert::pipelines::generation::{LMHeadModel, Cache};
use rust_bert::resources::{Resource, download_resource, RemoteResource};
use rust_tokenizers::{Gpt2Tokenizer, Tokenizer, TruncationStrategy};
use tch::{nn, Device, Tensor};
#[test]
fn distilgpt2_lm_model() -> failure::Fallible<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::DISTIL_GPT2));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::DISTIL_GPT2));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::DISTIL_GPT2));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::DISTIL_GPT2));
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
Gpt2ConfigResources::DISTIL_GPT2,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
Gpt2VocabResources::DISTIL_GPT2,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
Gpt2MergesResources::DISTIL_GPT2,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
Gpt2ModelResources::DISTIL_GPT2,
));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let merges_path = download_resource(&merges_resource)?;
let weights_path = download_resource(&weights_resource)?;
// Set-up masked LM model
// Set-up masked LM model
let device = Device::Cpu;
let mut vs = nn::VarStore::new(device);
let tokenizer: Gpt2Tokenizer = Gpt2Tokenizer::from_file(vocab_path.to_str().unwrap(), merges_path.to_str().unwrap(), false);
let tokenizer: Gpt2Tokenizer = Gpt2Tokenizer::from_file(
vocab_path.to_str().unwrap(),
merges_path.to_str().unwrap(),
false,
);
let config = Gpt2Config::from_file(config_path);
let gpt2_model = GPT2LMHeadModel::new(&vs.root(), &config);
vs.load(weights_path)?;
// Define input
// Define input
let input = ["One two three four five six seven eight nine ten eleven"];
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
let tokenized_input = tokenized_input.
iter().
map(|input| input.token_ids.clone()).
map(|mut input| {
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let tokenized_input = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
}).
map(|input|
Tensor::of_slice(&(input))).
collect::<Vec<_>>();
})
.map(|input| Tensor::of_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let (output, _, past, _, _) = gpt2_model.forward_t(
&Some(input_tensor),
Cache::None,
&None,
&None,
&None,
&None,
None,
&None,
false).unwrap();
// Forward pass
let (output, _, past, _, _) = gpt2_model
.forward_t(
&Some(input_tensor),
Cache::None,
&None,
&None,
&None,
&None,
None,
&None,
false,
)
.unwrap();
let next_word_id = output.get(0).get(-1).argmax(-1, true).int64_value(&[0]);
let next_word = tokenizer.decode(vec!(next_word_id), true, true);
let next_word = tokenizer.decode(vec![next_word_id], true, true);
assert_eq!(output.size(), vec!(1, 11, 50257));
match past {
Cache::GPT2Cache(past) => {
assert!(past.is_some());
assert_eq!(past.as_ref().unwrap().len(), config.n_layer as usize);
assert_eq!(past.as_ref().unwrap()[0].size(), vec!(2, 1, config.n_head, 11, 64));
assert_eq!(
past.as_ref().unwrap()[0].size(),
vec!(2, 1, config.n_head, 11, 64)
);
}
_ => panic!("Wrong cache returned for GPT2")
_ => panic!("Wrong cache returned for GPT2"),
}
assert!((output.double_value(&[0, output.size()[1] - 1, next_word_id]) - (-48.7065)).abs() < 1e-4);
assert!(
(output.double_value(&[0, output.size()[1] - 1, next_word_id]) - (-48.7065)).abs() < 1e-4
);
assert_eq!(next_word_id, 14104i64);
assert_eq!(next_word, String::from(" twelve"));
Ok(())
}
}

View File

@ -1,20 +1,29 @@
use rust_bert::resources::{Resource, download_resource, RemoteResource};
use rust_bert::electra::{ElectraConfigResources, ElectraVocabResources, ElectraModelResources, ElectraConfig, ElectraForMaskedLM, ElectraDiscriminator};
use tch::{Device, nn, Tensor, no_grad};
use rust_tokenizers::{BertTokenizer, TruncationStrategy, Tokenizer, Vocab};
use rust_bert::electra::{
ElectraConfig, ElectraConfigResources, ElectraDiscriminator, ElectraForMaskedLM,
ElectraModelResources, ElectraVocabResources,
};
use rust_bert::resources::{download_resource, RemoteResource, Resource};
use rust_bert::Config;
use rust_tokenizers::{BertTokenizer, Tokenizer, TruncationStrategy, Vocab};
use tch::{nn, no_grad, Device, Tensor};
#[test]
fn electra_masked_lm() -> failure::Fallible<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(ElectraConfigResources::BASE_GENERATOR));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(ElectraVocabResources::BASE_GENERATOR));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(ElectraModelResources::BASE_GENERATOR));
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
ElectraConfigResources::BASE_GENERATOR,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
ElectraVocabResources::BASE_GENERATOR,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
ElectraModelResources::BASE_GENERATOR,
));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let weights_path = download_resource(&weights_resource)?;
// Set-up masked LM model
// Set-up masked LM model
let device = Device::Cpu;
let mut vs = nn::VarStore::new(device);
let tokenizer: BertTokenizer = BertTokenizer::from_file(vocab_path.to_str().unwrap(), true);
@ -24,60 +33,70 @@ fn electra_masked_lm() -> failure::Fallible<()> {
let electra_model = ElectraForMaskedLM::new(&vs.root(), &config);
vs.load(weights_path)?;
// Define input
let input = ["Looks like one [MASK] is missing", "It was a very nice and [MASK] day"];
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
let tokenized_input = tokenized_input.
iter().
map(|input| input.token_ids.clone()).
map(|mut input| {
// Define input
let input = [
"Looks like one [MASK] is missing",
"It was a very nice and [MASK] day",
];
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let tokenized_input = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
}).
map(|input|
Tensor::of_slice(&(input))).
collect::<Vec<_>>();
})
.map(|input| Tensor::of_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let (output,
all_hidden_states,
all_attentions) = no_grad(|| {
electra_model
.forward_t(Some(input_tensor),
None,
None,
None,
None,
false)
});
// Forward pass
let (output, all_hidden_states, all_attentions) =
no_grad(|| electra_model.forward_t(Some(input_tensor), None, None, None, None, false));
// Decode output
// Decode output
let index_1 = output.get(0).get(4).argmax(0, false);
let index_2 = output.get(1).get(7).argmax(0, false);
let word_1 = tokenizer.vocab().id_to_token(&index_1.int64_value(&[]));
let word_2 = tokenizer.vocab().id_to_token(&index_2.int64_value(&[]));
assert_eq!(output.size(), &[2, 10, config.vocab_size]);
assert_eq!(config.num_hidden_layers as usize, all_hidden_states.unwrap().len());
assert_eq!(config.num_hidden_layers as usize, all_attentions.unwrap().len());
assert_eq!(
config.num_hidden_layers as usize,
all_hidden_states.unwrap().len()
);
assert_eq!(
config.num_hidden_layers as usize,
all_attentions.unwrap().len()
);
assert_eq!("thing", word_1); // Outputs "person" : "Looks like one [person] is missing"
assert_eq!("sunny", word_2);// Outputs "pear" : "It was a very nice and [sunny] day"
assert_eq!("sunny", word_2); // Outputs "pear" : "It was a very nice and [sunny] day"
Ok(())
}
#[test]
fn electra_discriminator() -> failure::Fallible<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(ElectraConfigResources::BASE_DISCRIMINATOR));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(ElectraVocabResources::BASE_DISCRIMINATOR));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(ElectraModelResources::BASE_DISCRIMINATOR));
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
ElectraConfigResources::BASE_DISCRIMINATOR,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
ElectraVocabResources::BASE_DISCRIMINATOR,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
ElectraModelResources::BASE_DISCRIMINATOR,
));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let weights_path = download_resource(&weights_resource)?;
// Set-up masked LM model
// Set-up masked LM model
let device = Device::Cpu;
let mut vs = nn::VarStore::new(device);
let tokenizer: BertTokenizer = BertTokenizer::from_file(vocab_path.to_str().unwrap(), true);
@ -85,35 +104,34 @@ fn electra_discriminator() -> failure::Fallible<()> {
let electra_model = ElectraDiscriminator::new(&vs.root(), &config);
vs.load(weights_path)?;
// Define input
// Define input
let input = ["One Two Three Ten Five Six Seven Eight"];
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
let encoded_input = tokenized_input.
iter().
map(|input| input.token_ids.clone()).
map(|mut input| {
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let encoded_input = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
}).
map(|input|
Tensor::of_slice(&(input))).
collect::<Vec<_>>();
})
.map(|input| Tensor::of_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(encoded_input.as_slice(), 0).to(device);
// Forward pass
let (output, _, _) = no_grad(|| {
electra_model
.forward_t(Some(input_tensor),
None,
None,
None,
None,
false)
});
let (output, _, _) =
no_grad(|| electra_model.forward_t(Some(input_tensor), None, None, None, None, false));
// Validate model predictions
let expected_probabilities = vec!(0.0101, 0.0030, 0.0010, 0.0018, 0.9489, 0.0067, 0.0026, 0.0017, 0.0311, 0.0101);
// Validate model predictions
let expected_probabilities = vec![
0.0101, 0.0030, 0.0010, 0.0018, 0.9489, 0.0067, 0.0026, 0.0017, 0.0311, 0.0101,
];
let probabilities = output.iter::<f64>().unwrap().collect::<Vec<f64>>();
assert_eq!(output.size(), &[10]);
@ -122,4 +140,4 @@ fn electra_discriminator() -> failure::Fallible<()> {
}
Ok(())
}
}

View File

@ -1,87 +1,115 @@
use tch::{Device, nn, Tensor};
use rust_tokenizers::{Gpt2Tokenizer, TruncationStrategy, Tokenizer};
use rust_bert::gpt2::{
GPT2LMHeadModel, Gpt2Config, Gpt2ConfigResources, Gpt2MergesResources, Gpt2ModelResources,
Gpt2VocabResources,
};
use rust_bert::pipelines::generation::{
Cache, GPT2Generator, GenerateConfig, LMHeadModel, LanguageGenerator,
};
use rust_bert::resources::{download_resource, RemoteResource, Resource};
use rust_bert::Config;
use rust_bert::pipelines::generation::{GPT2Generator, LanguageGenerator, GenerateConfig, LMHeadModel, Cache};
use rust_bert::gpt2::{Gpt2Config, GPT2LMHeadModel, Gpt2ConfigResources, Gpt2MergesResources, Gpt2VocabResources, Gpt2ModelResources};
use rust_bert::resources::{RemoteResource, Resource, download_resource};
use rust_tokenizers::{Gpt2Tokenizer, Tokenizer, TruncationStrategy};
use tch::{nn, Device, Tensor};
#[test]
fn gpt2_lm_model() -> failure::Fallible<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
let config_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
let vocab_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
let merges_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
let weights_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let merges_path = download_resource(&merges_resource)?;
let weights_path = download_resource(&weights_resource)?;
// Set-up masked LM model
// Set-up masked LM model
let device = Device::Cpu;
let mut vs = nn::VarStore::new(device);
let tokenizer: Gpt2Tokenizer = Gpt2Tokenizer::from_file(vocab_path.to_str().unwrap(), merges_path.to_str().unwrap(), false);
let tokenizer: Gpt2Tokenizer = Gpt2Tokenizer::from_file(
vocab_path.to_str().unwrap(),
merges_path.to_str().unwrap(),
false,
);
let config = Gpt2Config::from_file(config_path);
let gpt2_model = GPT2LMHeadModel::new(&vs.root(), &config);
vs.load(weights_path)?;
// Define input
// Define input
let input = ["One two three four"];
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
let tokenized_input = tokenized_input.
iter().
map(|input| input.token_ids.clone()).
map(|mut input| {
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let tokenized_input = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
}).
map(|input|
Tensor::of_slice(&(input))).
collect::<Vec<_>>();
})
.map(|input| Tensor::of_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let (output, _, past, _, _) = gpt2_model.forward_t(
&Some(input_tensor),
Cache::None,
&None,
&None,
&None,
&None,
None,
&None,
false).unwrap();
// Forward pass
let (output, _, past, _, _) = gpt2_model
.forward_t(
&Some(input_tensor),
Cache::None,
&None,
&None,
&None,
&None,
None,
&None,
false,
)
.unwrap();
let next_word_id = output.get(0).get(-1).argmax(-1, true).int64_value(&[0]);
let next_word = tokenizer.decode(vec!(next_word_id), true, true);
let next_word = tokenizer.decode(vec![next_word_id], true, true);
assert_eq!(output.size(), vec!(1, 4, 50257));
match past {
Cache::GPT2Cache(past) => {
assert!(past.is_some());
assert_eq!(past.as_ref().unwrap().len(), config.n_layer as usize);
assert_eq!(past.as_ref().unwrap()[0].size(), vec!(2, 1, config.n_head, 4, 64));
assert_eq!(past.as_ref().unwrap().len(), config.n_layer as usize);
assert_eq!(
past.as_ref().unwrap()[0].size(),
vec!(2, 1, config.n_head, 4, 64)
);
}
_ => panic!("Wrong cache returned for GPT2")
_ => panic!("Wrong cache returned for GPT2"),
}
assert!((output.double_value(&[0, output.size()[1] - 1, next_word_id]) - (-69.4948)).abs() < 1e-4);
assert!(
(output.double_value(&[0, output.size()[1] - 1, next_word_id]) - (-69.4948)).abs() < 1e-4
);
assert_eq!(next_word_id, 1936i64);
assert_eq!(next_word, String::from(" five"));
Ok(())
}
#[test]
fn gpt2_generation_greedy() -> failure::Fallible<()> {
// Resources definition
let config_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
let model_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
let config_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
let vocab_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
let merges_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
let model_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
// Set-up masked LM model
// Set-up masked LM model
let generate_config = GenerateConfig {
model_resource,
config_resource,
@ -97,7 +125,7 @@ fn gpt2_generation_greedy() -> failure::Fallible<()> {
let model = GPT2Generator::new(generate_config)?;
let input_context = "The cat";
let output = model.generate(Some(vec!(input_context)), None);
let output = model.generate(Some(vec![input_context]), None);
assert_eq!(output.len(), 1);
assert_eq!(output[0], "The cat was found in a field near the town of Keflavik, about 30 miles (48 kilometers) south-east of Moscow.\n\n\n");
@ -108,12 +136,16 @@ fn gpt2_generation_greedy() -> failure::Fallible<()> {
#[test]
fn gpt2_generation_beam_search() -> failure::Fallible<()> {
// Resources definition
let config_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
let model_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
let config_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
let vocab_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
let merges_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
let model_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
// Set-up masked LM model
// Set-up masked LM model
let generate_config = GenerateConfig {
model_resource,
config_resource,
@ -129,12 +161,21 @@ fn gpt2_generation_beam_search() -> failure::Fallible<()> {
let model = GPT2Generator::new(generate_config)?;
let input_context = "The dog";
let output = model.generate(Some(vec!(input_context)), None);
let output = model.generate(Some(vec![input_context]), None);
assert_eq!(output.len(), 3);
assert_eq!(output[0], "The dog was found in the backyard of a home in the 6200 block of South Main Street.");
assert_eq!(output[1], "The dog was found in the backyard of a home in the 6500 block of South Main Street.");
assert_eq!(output[2], "The dog was found in the backyard of a home in the 6200 block of South Main Street,");
assert_eq!(
output[0],
"The dog was found in the backyard of a home in the 6200 block of South Main Street."
);
assert_eq!(
output[1],
"The dog was found in the backyard of a home in the 6500 block of South Main Street."
);
assert_eq!(
output[2],
"The dog was found in the backyard of a home in the 6200 block of South Main Street,"
);
Ok(())
}
@ -142,12 +183,16 @@ fn gpt2_generation_beam_search() -> failure::Fallible<()> {
#[test]
fn gpt2_generation_beam_search_multiple_prompts_without_padding() -> failure::Fallible<()> {
// Resources definition
let config_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
let model_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
let config_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
let vocab_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
let merges_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
let model_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
// Set-up masked LM model
// Set-up masked LM model
let generate_config = GenerateConfig {
model_resource,
config_resource,
@ -164,15 +209,33 @@ fn gpt2_generation_beam_search_multiple_prompts_without_padding() -> failure::Fa
let input_context_1 = "The dog";
let input_context_2 = "The cat";
let output = model.generate(Some(vec!(input_context_1, input_context_2)), None);
let output = model.generate(Some(vec![input_context_1, input_context_2]), None);
assert_eq!(output.len(), 6);
assert_eq!(output[0], "The dog was found in the backyard of a home in the 6200 block of South Main Street.");
assert_eq!(output[1], "The dog was found in the backyard of a home in the 6500 block of South Main Street.");
assert_eq!(output[2], "The dog was found in the backyard of a home in the 6200 block of South Main Street,");
assert_eq!(output[3], "The cat-and-mouse game.\n\n\"I think it\'s going to be interesting to");
assert_eq!(output[4], "The cat-and-mouse game.\n\n\"I think it\'s going to be a very");
assert_eq!(output[5], "The cat-and-mouse game.\n\n\"I think it\'s going to be very interesting");
assert_eq!(
output[0],
"The dog was found in the backyard of a home in the 6200 block of South Main Street."
);
assert_eq!(
output[1],
"The dog was found in the backyard of a home in the 6500 block of South Main Street."
);
assert_eq!(
output[2],
"The dog was found in the backyard of a home in the 6200 block of South Main Street,"
);
assert_eq!(
output[3],
"The cat-and-mouse game.\n\n\"I think it\'s going to be interesting to"
);
assert_eq!(
output[4],
"The cat-and-mouse game.\n\n\"I think it\'s going to be a very"
);
assert_eq!(
output[5],
"The cat-and-mouse game.\n\n\"I think it\'s going to be very interesting"
);
Ok(())
}
@ -180,12 +243,16 @@ fn gpt2_generation_beam_search_multiple_prompts_without_padding() -> failure::Fa
#[test]
fn gpt2_generation_beam_search_multiple_prompts_with_padding() -> failure::Fallible<()> {
// Resources definition
let config_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
let model_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
let config_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
let vocab_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
let merges_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
let model_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
// Set-up masked LM model
// Set-up masked LM model
let generate_config = GenerateConfig {
model_resource,
config_resource,
@ -202,15 +269,33 @@ fn gpt2_generation_beam_search_multiple_prompts_with_padding() -> failure::Falli
let input_context_1 = "The dog";
let input_context_2 = "The cat was";
let output = model.generate(Some(vec!(input_context_1, input_context_2)), None);
let output = model.generate(Some(vec![input_context_1, input_context_2]), None);
assert_eq!(output.len(), 6);
assert_eq!(output[0], "The dog was found dead on the side of the road in the middle of the night.\n");
assert_eq!(output[1], "The dog was found dead on the side of the road in the middle of the night on Sunday");
assert_eq!(output[2], "The dog was found dead on the side of the road in the middle of the night on Saturday");
assert_eq!(output[3], "The cat was taken to a local hospital, where it was treated and released.\n\nPolice said");
assert_eq!(output[4], "The cat was taken to a local hospital, where it was treated and released.\n\n\"It");
assert_eq!(output[5], "The cat was taken to a local hospital, where it was treated and released.\n\n\"We");
assert_eq!(
output[0],
"The dog was found dead on the side of the road in the middle of the night.\n"
);
assert_eq!(
output[1],
"The dog was found dead on the side of the road in the middle of the night on Sunday"
);
assert_eq!(
output[2],
"The dog was found dead on the side of the road in the middle of the night on Saturday"
);
assert_eq!(
output[3],
"The cat was taken to a local hospital, where it was treated and released.\n\nPolice said"
);
assert_eq!(
output[4],
"The cat was taken to a local hospital, where it was treated and released.\n\n\"It"
);
assert_eq!(
output[5],
"The cat was taken to a local hospital, where it was treated and released.\n\n\"We"
);
Ok(())
}
}

View File

@ -1,12 +1,11 @@
use rust_bert::pipelines::translation::{TranslationConfig, Language, TranslationModel};
use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
use tch::Device;
#[test]
#[cfg_attr(not(feature = "all-tests"), ignore)]
fn test_translation() -> failure::Fallible<()> {
// Set-up translation model
let translation_config = TranslationConfig::new(Language::EnglishToFrench, Device::Cpu);
// Set-up translation model
let translation_config = TranslationConfig::new(Language::EnglishToFrench, Device::Cpu);
let model = TranslationModel::new(translation_config)?;
let input_context_1 = "The quick brown fox jumps over the lazy dog";
@ -15,8 +14,11 @@ fn test_translation() -> failure::Fallible<()> {
let output = model.translate(&[input_context_1, input_context_2]);
assert_eq!(output.len(), 2);
assert_eq!(output[0], " Le rapide renard brun saute sur le chien paresseux");
assert_eq!(
output[0],
" Le rapide renard brun saute sur le chien paresseux"
);
assert_eq!(output[1], " Le chien ne s'est pas réveillé.");
Ok(())
}
}

View File

@ -1,64 +1,90 @@
use tch::{Device, nn, Tensor};
use rust_tokenizers::{TruncationStrategy, Tokenizer, OpenAiGptTokenizer};
use rust_bert::Config;
use rust_bert::pipelines::generation::{OpenAIGenerator, LanguageGenerator, GenerateConfig, LMHeadModel, Cache};
use rust_bert::gpt2::Gpt2Config;
use rust_bert::openai_gpt::{OpenAIGPTLMHeadModel, OpenAiGptConfigResources, OpenAiGptVocabResources, OpenAiGptMergesResources, OpenAiGptModelResources};
use rust_bert::resources::{RemoteResource, Resource, download_resource};
use rust_bert::openai_gpt::{
OpenAIGPTLMHeadModel, OpenAiGptConfigResources, OpenAiGptMergesResources,
OpenAiGptModelResources, OpenAiGptVocabResources,
};
use rust_bert::pipelines::generation::{
Cache, GenerateConfig, LMHeadModel, LanguageGenerator, OpenAIGenerator,
};
use rust_bert::resources::{download_resource, RemoteResource, Resource};
use rust_bert::Config;
use rust_tokenizers::{OpenAiGptTokenizer, Tokenizer, TruncationStrategy};
use tch::{nn, Device, Tensor};
#[test]
fn openai_gpt_lm_model() -> failure::Fallible<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptConfigResources::GPT));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptVocabResources::GPT));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptMergesResources::GPT));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptModelResources::GPT));
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
OpenAiGptConfigResources::GPT,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
OpenAiGptVocabResources::GPT,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
OpenAiGptMergesResources::GPT,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
OpenAiGptModelResources::GPT,
));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let merges_path = download_resource(&merges_resource)?;
let weights_path = download_resource(&weights_resource)?;
// Set-up masked LM model
// Set-up masked LM model
let device = Device::Cpu;
let mut vs = nn::VarStore::new(device);
let tokenizer = OpenAiGptTokenizer::from_file(vocab_path.to_str().unwrap(), merges_path.to_str().unwrap(), true);
let tokenizer = OpenAiGptTokenizer::from_file(
vocab_path.to_str().unwrap(),
merges_path.to_str().unwrap(),
true,
);
let config = Gpt2Config::from_file(config_path);
let openai_gpt = OpenAIGPTLMHeadModel::new(&vs.root(), &config);
vs.load(weights_path)?;
// Define input
// Define input
let input = ["Wondering what the next word will"];
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
let tokenized_input = tokenized_input.
iter().
map(|input| input.token_ids.clone()).
map(|mut input| {
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let tokenized_input = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
}).
map(|input|
Tensor::of_slice(&(input))).
collect::<Vec<_>>();
})
.map(|input| Tensor::of_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let (output, _, _, _, _) = openai_gpt.forward_t(
&Some(input_tensor),
Cache::None,
&None,
&None,
&None,
&None,
None,
&None,
false).unwrap();
// Forward pass
let (output, _, _, _, _) = openai_gpt
.forward_t(
&Some(input_tensor),
Cache::None,
&None,
&None,
&None,
&None,
None,
&None,
false,
)
.unwrap();
let next_word_id = output.get(0).get(-1).argmax(-1, true).int64_value(&[0]);
let next_word = tokenizer.decode(vec!(next_word_id), true, true);
let next_word = tokenizer.decode(vec![next_word_id], true, true);
assert_eq!(output.size(), vec!(1, 6, 40478));
assert!((output.double_value(&[0, output.size()[1] - 1, next_word_id]) - (9.1056)).abs() < 1e-4);
assert!(
(output.double_value(&[0, output.size()[1] - 1, next_word_id]) - (9.1056)).abs() < 1e-4
);
assert_eq!(next_word_id, 580i64);
assert_eq!(next_word, String::from("be"));
@ -68,12 +94,20 @@ fn openai_gpt_lm_model() -> failure::Fallible<()> {
#[test]
fn openai_gpt_generation_greedy() -> failure::Fallible<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptConfigResources::GPT));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptVocabResources::GPT));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptMergesResources::GPT));
let model_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptModelResources::GPT));
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
OpenAiGptConfigResources::GPT,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
OpenAiGptVocabResources::GPT,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
OpenAiGptMergesResources::GPT,
));
let model_resource = Resource::Remote(RemoteResource::from_pretrained(
OpenAiGptModelResources::GPT,
));
// Set-up masked LM model
// Set-up masked LM model
let generate_config = GenerateConfig {
model_resource,
config_resource,
@ -90,7 +124,7 @@ fn openai_gpt_generation_greedy() -> failure::Fallible<()> {
let model = OpenAIGenerator::new(generate_config)?;
let input_context = "It was an intense machine dialogue. ";
let output = model.generate(Some(vec!(input_context)), None);
let output = model.generate(Some(vec![input_context]), None);
assert_eq!(output.len(), 1);
assert_eq!(output[0], "it was an intense machine dialogue. \n \" i\'m sorry, but we have to go now! the police are on their way and they\'re going after you - or at least that\'s what my");
@ -101,12 +135,20 @@ fn openai_gpt_generation_greedy() -> failure::Fallible<()> {
#[test]
fn openai_gpt_generation_beam_search() -> failure::Fallible<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptConfigResources::GPT));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptVocabResources::GPT));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptMergesResources::GPT));
let model_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptModelResources::GPT));
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
OpenAiGptConfigResources::GPT,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
OpenAiGptVocabResources::GPT,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
OpenAiGptMergesResources::GPT,
));
let model_resource = Resource::Remote(RemoteResource::from_pretrained(
OpenAiGptModelResources::GPT,
));
// Set-up masked LM model
// Set-up masked LM model
let generate_config = GenerateConfig {
model_resource,
config_resource,
@ -122,12 +164,21 @@ fn openai_gpt_generation_beam_search() -> failure::Fallible<()> {
let model = OpenAIGenerator::new(generate_config)?;
let input_context = "The dog is";
let output = model.generate(Some(vec!(input_context)), None);
let output = model.generate(Some(vec![input_context]), None);
assert_eq!(output.len(), 3);
assert_eq!(output[0], "the dog isn\'t going anywhere. i\'m going to take care of him. i \'ll be right");
assert_eq!(output[1], "the dog isn\'t going anywhere. i\'m going to take care of him. i \'ll be back");
assert_eq!(output[2], "the dog isn\'t going anywhere. i\'m going to take care of him. \" \n \" i");
assert_eq!(
output[0],
"the dog isn\'t going anywhere. i\'m going to take care of him. i \'ll be right"
);
assert_eq!(
output[1],
"the dog isn\'t going anywhere. i\'m going to take care of him. i \'ll be back"
);
assert_eq!(
output[2],
"the dog isn\'t going anywhere. i\'m going to take care of him. \" \n \" i"
);
Ok(())
}
@ -135,12 +186,20 @@ fn openai_gpt_generation_beam_search() -> failure::Fallible<()> {
#[test]
fn openai_gpt_generation_beam_search_multiple_prompts_without_padding() -> failure::Fallible<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptConfigResources::GPT));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptVocabResources::GPT));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptMergesResources::GPT));
let model_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptModelResources::GPT));
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
OpenAiGptConfigResources::GPT,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
OpenAiGptVocabResources::GPT,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
OpenAiGptMergesResources::GPT,
));
let model_resource = Resource::Remote(RemoteResource::from_pretrained(
OpenAiGptModelResources::GPT,
));
// Set-up masked LM model
// Set-up masked LM model
let generate_config = GenerateConfig {
model_resource,
config_resource,
@ -157,18 +216,36 @@ fn openai_gpt_generation_beam_search_multiple_prompts_without_padding() -> failu
let input_context_1 = "The dog is";
let input_context_2 = "The cat";
let output = model.generate(Some(vec!(input_context_1, input_context_2)), None);
let output = model.generate(Some(vec![input_context_1, input_context_2]), None);
assert_eq!(output.len(), 6);
// Unpadded sequence (generation for `The dog is`) is identical to the
assert_eq!(output[0], "the dog isn\'t going anywhere. i\'m going to take care of him. i \'ll be right");
assert_eq!(output[1], "the dog isn\'t going anywhere. i\'m going to take care of him. i \'ll be back");
assert_eq!(output[2], "the dog isn\'t going anywhere. i\'m going to take care of him. \" \n \" i");
// Unpadded sequence (generation for `The dog is`) is identical to the
assert_eq!(
output[0],
"the dog isn\'t going anywhere. i\'m going to take care of him. i \'ll be right"
);
assert_eq!(
output[1],
"the dog isn\'t going anywhere. i\'m going to take care of him. i \'ll be back"
);
assert_eq!(
output[2],
"the dog isn\'t going anywhere. i\'m going to take care of him. \" \n \" i"
);
assert_eq!(output[3], "the cat. \" \n \" i don\'t know what you\'re talking about. i don\'t");
assert_eq!(output[4], "the cat. \" \n \" i don\'t know what you\'re talking about. i\'m not");
assert_eq!(output[5], "the cat. \" \n \" i don\'t know what you\'re talking about. i do know");
assert_eq!(
output[3],
"the cat. \" \n \" i don\'t know what you\'re talking about. i don\'t"
);
assert_eq!(
output[4],
"the cat. \" \n \" i don\'t know what you\'re talking about. i\'m not"
);
assert_eq!(
output[5],
"the cat. \" \n \" i don\'t know what you\'re talking about. i do know"
);
Ok(())
}
@ -176,12 +253,20 @@ fn openai_gpt_generation_beam_search_multiple_prompts_without_padding() -> failu
#[test]
fn openai_gpt_generation_beam_search_multiple_prompts_with_padding() -> failure::Fallible<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptConfigResources::GPT));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptVocabResources::GPT));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptMergesResources::GPT));
let model_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptModelResources::GPT));
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
OpenAiGptConfigResources::GPT,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
OpenAiGptVocabResources::GPT,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
OpenAiGptMergesResources::GPT,
));
let model_resource = Resource::Remote(RemoteResource::from_pretrained(
OpenAiGptModelResources::GPT,
));
// Set-up masked LM model
// Set-up masked LM model
let generate_config = GenerateConfig {
model_resource,
config_resource,
@ -198,16 +283,34 @@ fn openai_gpt_generation_beam_search_multiple_prompts_with_padding() -> failure:
let input_context_1 = "The dog is";
let input_context_2 = "The cat was in";
let output = model.generate(Some(vec!(input_context_1, input_context_2)), None);
let output = model.generate(Some(vec![input_context_1, input_context_2]), None);
assert_eq!(output.len(), 6);
// Left padding impacts the generated sentences output
assert_eq!(output[0], "the dog is a dog. \" \n \" i don\'t know what you\'re talking about.");
assert_eq!(output[1], "the dog is a dog. \" \n \" i don\'t know what you\'re talking about,");
assert_eq!(output[2], "the dog is a dog. \" \n \" i don\'t know what you\'re talking about!");
assert_eq!(output[3], "the cat was in the room with them. \n \" what\'s going on? \" i asked.");
assert_eq!(output[4], "the cat was in the room with them. \n \" what\'s going on? \" she asked.");
assert_eq!(output[5], "the cat was in the room with them. \n \" what\'s going on? why are you all");
// Left padding impacts the generated sentences output
assert_eq!(
output[0],
"the dog is a dog. \" \n \" i don\'t know what you\'re talking about."
);
assert_eq!(
output[1],
"the dog is a dog. \" \n \" i don\'t know what you\'re talking about,"
);
assert_eq!(
output[2],
"the dog is a dog. \" \n \" i don\'t know what you\'re talking about!"
);
assert_eq!(
output[3],
"the cat was in the room with them. \n \" what\'s going on? \" i asked."
);
assert_eq!(
output[4],
"the cat was in the room with them. \n \" what\'s going on? \" she asked."
);
assert_eq!(
output[5],
"the cat was in the room with them. \n \" what\'s going on? why are you all"
);
Ok(())
}

View File

@ -1,75 +1,99 @@
use tch::{Device, nn, Tensor, no_grad};
use rust_tokenizers::{RobertaTokenizer, TruncationStrategy, Tokenizer, Vocab};
use rust_bert::Config;
use rust_bert::bert::BertConfig;
use rust_bert::roberta::{RobertaForMaskedLM, RobertaForSequenceClassification, RobertaForMultipleChoice, RobertaForTokenClassification, RobertaForQuestionAnswering, RobertaConfigResources, RobertaVocabResources, RobertaMergesResources, RobertaModelResources};
use rust_bert::resources::{download_resource, RemoteResource, Resource};
use rust_bert::roberta::{
RobertaConfigResources, RobertaForMaskedLM, RobertaForMultipleChoice,
RobertaForQuestionAnswering, RobertaForSequenceClassification, RobertaForTokenClassification,
RobertaMergesResources, RobertaModelResources, RobertaVocabResources,
};
use rust_bert::Config;
use rust_tokenizers::{RobertaTokenizer, Tokenizer, TruncationStrategy, Vocab};
use std::collections::HashMap;
use rust_bert::resources::{RemoteResource, Resource, download_resource};
use tch::{nn, no_grad, Device, Tensor};
#[test]
fn roberta_masked_lm() -> failure::Fallible<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaConfigResources::ROBERTA));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaVocabResources::ROBERTA));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaMergesResources::ROBERTA));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaModelResources::ROBERTA));
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaConfigResources::ROBERTA,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaVocabResources::ROBERTA,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaMergesResources::ROBERTA,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaModelResources::ROBERTA,
));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let merges_path = download_resource(&merges_resource)?;
let weights_path = download_resource(&weights_resource)?;
// Set-up masked LM model
// Set-up masked LM model
let device = Device::Cpu;
let mut vs = nn::VarStore::new(device);
let tokenizer: RobertaTokenizer = RobertaTokenizer::from_file(vocab_path.to_str().unwrap(), merges_path.to_str().unwrap(), true);
let tokenizer: RobertaTokenizer = RobertaTokenizer::from_file(
vocab_path.to_str().unwrap(),
merges_path.to_str().unwrap(),
true,
);
let config = BertConfig::from_file(config_path);
let roberta_model = RobertaForMaskedLM::new(&vs.root(), &config);
vs.load(weights_path)?;
// Define input
let input = ["<pad> Looks like one thing is missing", "It\'s like comparing oranges to apples"];
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
let mut tokenized_input = tokenized_input.
iter().
map(|input| input.token_ids.clone()).
map(|mut input| {
// Define input
let input = [
"<pad> Looks like one thing is missing",
"It\'s like comparing oranges to apples",
];
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let mut tokenized_input = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
}).
collect::<Vec<_>>();
})
.collect::<Vec<_>>();
// Masking the token [thing] of sentence 1 and [oranges] of sentence 2
// Masking the token [thing] of sentence 1 and [oranges] of sentence 2
tokenized_input[0][4] = 103;
tokenized_input[1][5] = 103;
let tokenized_input = tokenized_input.
iter().
map(|input|
Tensor::of_slice(&(input))).
collect::<Vec<_>>();
let tokenized_input = tokenized_input
.iter()
.map(|input| Tensor::of_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
// Forward pass
let (output, _, _) = no_grad(|| {
roberta_model
.forward_t(Some(input_tensor),
None,
None,
None,
None,
&None,
&None,
false)
roberta_model.forward_t(
Some(input_tensor),
None,
None,
None,
None,
&None,
&None,
false,
)
});
// Print masked tokens
// Print masked tokens
let index_1 = output.get(0).get(4).argmax(0, false);
let index_2 = output.get(1).get(5).argmax(0, false);
let word_1 = tokenizer.vocab().id_to_token(&index_1.int64_value(&[]));
let word_2 = tokenizer.vocab().id_to_token(&index_2.int64_value(&[]));
assert_eq!("Ġsome", word_1); // Outputs "person" : "Looks like [some] thing is missing"
assert_eq!("Ġapples", word_2);// Outputs "pear" : "It\'s like comparing [apples] to apples"
assert_eq!("Ġapples", word_2); // Outputs "pear" : "It\'s like comparing [apples] to apples"
Ok(())
}
@ -77,17 +101,27 @@ fn roberta_masked_lm() -> failure::Fallible<()> {
#[test]
fn roberta_for_sequence_classification() -> failure::Fallible<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaConfigResources::ROBERTA));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaVocabResources::ROBERTA));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaMergesResources::ROBERTA));
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaConfigResources::ROBERTA,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaVocabResources::ROBERTA,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaMergesResources::ROBERTA,
));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let merges_path = download_resource(&merges_resource)?;
// Set-up model
// Set-up model
let device = Device::Cpu;
let vs = nn::VarStore::new(device);
let tokenizer: RobertaTokenizer = RobertaTokenizer::from_file(vocab_path.to_str().unwrap(), merges_path.to_str().unwrap(), true);
let tokenizer: RobertaTokenizer = RobertaTokenizer::from_file(
vocab_path.to_str().unwrap(),
merges_path.to_str().unwrap(),
true,
);
let mut config = BertConfig::from_file(config_path);
let mut dummy_label_mapping = HashMap::new();
dummy_label_mapping.insert(0, String::from("Positive"));
@ -98,37 +132,42 @@ fn roberta_for_sequence_classification() -> failure::Fallible<()> {
config.output_hidden_states = Some(true);
let roberta_model = RobertaForSequenceClassification::new(&vs.root(), &config);
// Define input
let input = ["Looks like one thing is missing", "It\'s like comparing oranges to apples"];
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
let tokenized_input = tokenized_input.
iter().
map(|input| input.token_ids.clone()).
map(|mut input| {
// Define input
let input = [
"Looks like one thing is missing",
"It\'s like comparing oranges to apples",
];
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let tokenized_input = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
}).
map(|input|
Tensor::of_slice(&(input))).
collect::<Vec<_>>();
})
.map(|input| Tensor::of_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let (output, all_hidden_states, all_attentions) = no_grad(|| {
roberta_model
.forward_t(Some(input_tensor),
None,
None,
None,
None,
false)
});
// Forward pass
let (output, all_hidden_states, all_attentions) =
no_grad(|| roberta_model.forward_t(Some(input_tensor), None, None, None, None, false));
assert_eq!(output.size(), &[2, 3]);
assert_eq!(config.num_hidden_layers as usize, all_hidden_states.unwrap().len());
assert_eq!(config.num_hidden_layers as usize, all_attentions.unwrap().len());
assert_eq!(
config.num_hidden_layers as usize,
all_hidden_states.unwrap().len()
);
assert_eq!(
config.num_hidden_layers as usize,
all_attentions.unwrap().len()
);
Ok(())
}
@ -136,52 +175,70 @@ fn roberta_for_sequence_classification() -> failure::Fallible<()> {
#[test]
fn roberta_for_multiple_choice() -> failure::Fallible<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaConfigResources::ROBERTA));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaVocabResources::ROBERTA));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaMergesResources::ROBERTA));
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaConfigResources::ROBERTA,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaVocabResources::ROBERTA,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaMergesResources::ROBERTA,
));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let merges_path = download_resource(&merges_resource)?;
// Set-up model
// Set-up model
let device = Device::Cpu;
let vs = nn::VarStore::new(device);
let tokenizer: RobertaTokenizer = RobertaTokenizer::from_file(vocab_path.to_str().unwrap(), merges_path.to_str().unwrap(), true);
let tokenizer: RobertaTokenizer = RobertaTokenizer::from_file(
vocab_path.to_str().unwrap(),
merges_path.to_str().unwrap(),
true,
);
let mut config = BertConfig::from_file(config_path);
config.output_attentions = Some(true);
config.output_hidden_states = Some(true);
let roberta_model = RobertaForMultipleChoice::new(&vs.root(), &config);
// Define input
let input = ["Looks like one thing is missing", "It\'s like comparing oranges to apples"];
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
let tokenized_input = tokenized_input.
iter().
map(|input| input.token_ids.clone()).
map(|mut input| {
// Define input
let input = [
"Looks like one thing is missing",
"It\'s like comparing oranges to apples",
];
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let tokenized_input = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
}).
map(|input|
Tensor::of_slice(&(input))).
collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device).unsqueeze(0);
})
.map(|input| Tensor::of_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0)
.to(device)
.unsqueeze(0);
// Forward pass
let (output, all_hidden_states, all_attentions) = no_grad(|| {
roberta_model
.forward_t(input_tensor,
None,
None,
None,
false)
});
// Forward pass
let (output, all_hidden_states, all_attentions) =
no_grad(|| roberta_model.forward_t(input_tensor, None, None, None, false));
assert_eq!(output.size(), &[1, 2]);
assert_eq!(config.num_hidden_layers as usize, all_hidden_states.unwrap().len());
assert_eq!(config.num_hidden_layers as usize, all_attentions.unwrap().len());
assert_eq!(
config.num_hidden_layers as usize,
all_hidden_states.unwrap().len()
);
assert_eq!(
config.num_hidden_layers as usize,
all_attentions.unwrap().len()
);
Ok(())
}
@ -189,17 +246,27 @@ fn roberta_for_multiple_choice() -> failure::Fallible<()> {
#[test]
fn roberta_for_token_classification() -> failure::Fallible<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaConfigResources::ROBERTA));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaVocabResources::ROBERTA));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaMergesResources::ROBERTA));
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaConfigResources::ROBERTA,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaVocabResources::ROBERTA,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaMergesResources::ROBERTA,
));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let merges_path = download_resource(&merges_resource)?;
// Set-up model
// Set-up model
let device = Device::Cpu;
let vs = nn::VarStore::new(device);
let tokenizer: RobertaTokenizer = RobertaTokenizer::from_file(vocab_path.to_str().unwrap(), merges_path.to_str().unwrap(), true);
let tokenizer: RobertaTokenizer = RobertaTokenizer::from_file(
vocab_path.to_str().unwrap(),
merges_path.to_str().unwrap(),
true,
);
let mut config = BertConfig::from_file(config_path);
let mut dummy_label_mapping = HashMap::new();
dummy_label_mapping.insert(0, String::from("O"));
@ -211,55 +278,70 @@ fn roberta_for_token_classification() -> failure::Fallible<()> {
config.output_hidden_states = Some(true);
let roberta_model = RobertaForTokenClassification::new(&vs.root(), &config);
// Define input
let input = ["Looks like one thing is missing", "It\'s like comparing oranges to apples"];
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
let tokenized_input = tokenized_input.
iter().
map(|input| input.token_ids.clone()).
map(|mut input| {
// Define input
let input = [
"Looks like one thing is missing",
"It\'s like comparing oranges to apples",
];
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let tokenized_input = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
}).
map(|input|
Tensor::of_slice(&(input))).
collect::<Vec<_>>();
})
.map(|input| Tensor::of_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let (output, all_hidden_states, all_attentions) = no_grad(|| {
roberta_model
.forward_t(Some(input_tensor),
None,
None,
None,
None,
false)
});
// Forward pass
let (output, all_hidden_states, all_attentions) =
no_grad(|| roberta_model.forward_t(Some(input_tensor), None, None, None, None, false));
assert_eq!(output.size(), &[2, 9, 4]);
assert_eq!(config.num_hidden_layers as usize, all_hidden_states.unwrap().len());
assert_eq!(config.num_hidden_layers as usize, all_attentions.unwrap().len());
assert_eq!(
config.num_hidden_layers as usize,
all_hidden_states.unwrap().len()
);
assert_eq!(
config.num_hidden_layers as usize,
all_attentions.unwrap().len()
);
Ok(())
}
#[test]
fn roberta_for_question_answering() -> failure::Fallible<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaConfigResources::ROBERTA));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaVocabResources::ROBERTA));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaMergesResources::ROBERTA));
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaConfigResources::ROBERTA,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaVocabResources::ROBERTA,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaMergesResources::ROBERTA,
));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let merges_path = download_resource(&merges_resource)?;
// Set-up model
// Set-up model
let device = Device::Cpu;
let vs = nn::VarStore::new(device);
let tokenizer: RobertaTokenizer = RobertaTokenizer::from_file(vocab_path.to_str().unwrap(), merges_path.to_str().unwrap(), true);
let tokenizer: RobertaTokenizer = RobertaTokenizer::from_file(
vocab_path.to_str().unwrap(),
merges_path.to_str().unwrap(),
true,
);
let mut config = BertConfig::from_file(config_path);
let mut dummy_label_mapping = HashMap::new();
dummy_label_mapping.insert(0, String::from("Positive"));
@ -270,38 +352,43 @@ fn roberta_for_question_answering() -> failure::Fallible<()> {
config.output_hidden_states = Some(true);
let roberta_model = RobertaForQuestionAnswering::new(&vs.root(), &config);
// Define input
let input = ["Looks like one thing is missing", "It\'s like comparing oranges to apples"];
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
let tokenized_input = tokenized_input.
iter().
map(|input| input.token_ids.clone()).
map(|mut input| {
// Define input
let input = [
"Looks like one thing is missing",
"It\'s like comparing oranges to apples",
];
let tokenized_input =
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let tokenized_input = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
}).
map(|input|
Tensor::of_slice(&(input))).
collect::<Vec<_>>();
})
.map(|input| Tensor::of_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let (start_scores, end_scores, all_hidden_states, all_attentions) = no_grad(|| {
roberta_model
.forward_t(Some(input_tensor),
None,
None,
None,
None,
false)
});
// Forward pass
let (start_scores, end_scores, all_hidden_states, all_attentions) =
no_grad(|| roberta_model.forward_t(Some(input_tensor), None, None, None, None, false));
assert_eq!(start_scores.size(), &[2, 9]);
assert_eq!(end_scores.size(), &[2, 9]);
assert_eq!(config.num_hidden_layers as usize, all_hidden_states.unwrap().len());
assert_eq!(config.num_hidden_layers as usize, all_attentions.unwrap().len());
assert_eq!(
config.num_hidden_layers as usize,
all_hidden_states.unwrap().len()
);
assert_eq!(
config.num_hidden_layers as usize,
all_attentions.unwrap().len()
);
Ok(())
}
}