mirror of
https://github.com/guillaume-be/rust-bert.git
synced 2024-10-05 16:47:24 +03:00
Code formatted using rustfmt
This commit is contained in:
parent
0624a5368c
commit
47e36c4e8c
@ -13,18 +13,26 @@
|
||||
|
||||
extern crate failure;
|
||||
|
||||
use tch::{Device, nn, Tensor, no_grad};
|
||||
use rust_tokenizers::{AlbertTokenizer, TruncationStrategy, Tokenizer, Vocab};
|
||||
use rust_bert::albert::{
|
||||
AlbertConfig, AlbertConfigResources, AlbertForMaskedLM, AlbertModelResources,
|
||||
AlbertVocabResources,
|
||||
};
|
||||
use rust_bert::resources::{download_resource, RemoteResource, Resource};
|
||||
use rust_bert::Config;
|
||||
use rust_bert::resources::{Resource, download_resource, RemoteResource};
|
||||
use rust_bert::albert::{AlbertConfig, AlbertForMaskedLM, AlbertConfigResources, AlbertVocabResources, AlbertModelResources};
|
||||
|
||||
use rust_tokenizers::{AlbertTokenizer, Tokenizer, TruncationStrategy, Vocab};
|
||||
use tch::{nn, no_grad, Device, Tensor};
|
||||
|
||||
fn main() -> failure::Fallible<()> {
|
||||
// Resources paths
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(AlbertConfigResources::ALBERT_BASE_V2));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(AlbertVocabResources::ALBERT_BASE_V2));
|
||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(AlbertModelResources::ALBERT_BASE_V2));
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
AlbertConfigResources::ALBERT_BASE_V2,
|
||||
));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
AlbertVocabResources::ALBERT_BASE_V2,
|
||||
));
|
||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
AlbertModelResources::ALBERT_BASE_V2,
|
||||
));
|
||||
let config_path = download_resource(&config_resource)?;
|
||||
let vocab_path = download_resource(&vocab_resource)?;
|
||||
let weights_path = download_resource(&weights_resource)?;
|
||||
@ -32,37 +40,38 @@ fn main() -> failure::Fallible<()> {
|
||||
// Set-up masked LM model
|
||||
let device = Device::Cpu;
|
||||
let mut vs = nn::VarStore::new(device);
|
||||
let tokenizer: AlbertTokenizer = AlbertTokenizer::from_file(vocab_path.to_str().unwrap(), true, false);
|
||||
let tokenizer: AlbertTokenizer =
|
||||
AlbertTokenizer::from_file(vocab_path.to_str().unwrap(), true, false);
|
||||
let config = AlbertConfig::from_file(config_path);
|
||||
let albert_model = AlbertForMaskedLM::new(&vs.root(), &config);
|
||||
vs.load(weights_path)?;
|
||||
|
||||
// Define input
|
||||
let input = ["Looks like one [MASK] is missing", "It was a very nice and [MASK] day"];
|
||||
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
||||
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
|
||||
let tokenized_input = tokenized_input.
|
||||
iter().
|
||||
map(|input| input.token_ids.clone()).
|
||||
map(|mut input| {
|
||||
let input = [
|
||||
"Looks like one [MASK] is missing",
|
||||
"It was a very nice and [MASK] day",
|
||||
];
|
||||
let tokenized_input =
|
||||
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
||||
let max_len = tokenized_input
|
||||
.iter()
|
||||
.map(|input| input.token_ids.len())
|
||||
.max()
|
||||
.unwrap();
|
||||
let tokenized_input = tokenized_input
|
||||
.iter()
|
||||
.map(|input| input.token_ids.clone())
|
||||
.map(|mut input| {
|
||||
input.extend(vec![0; max_len - input.len()]);
|
||||
input
|
||||
}).
|
||||
map(|input|
|
||||
Tensor::of_slice(&(input))).
|
||||
collect::<Vec<_>>();
|
||||
})
|
||||
.map(|input| Tensor::of_slice(&(input)))
|
||||
.collect::<Vec<_>>();
|
||||
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
|
||||
|
||||
// Forward pass
|
||||
let (output, _, _) = no_grad(|| {
|
||||
albert_model
|
||||
.forward_t(Some(input_tensor),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
false)
|
||||
});
|
||||
let (output, _, _) =
|
||||
no_grad(|| albert_model.forward_t(Some(input_tensor), None, None, None, None, false));
|
||||
println!("{:?}", output.double_value(&[0, 0, 0]));
|
||||
// Print masked tokens
|
||||
let index_1 = output.get(0).get(4).argmax(0, false);
|
||||
|
@ -12,19 +12,25 @@
|
||||
|
||||
extern crate failure;
|
||||
|
||||
use tch::{Device, nn, Tensor, no_grad};
|
||||
use rust_tokenizers::{RobertaTokenizer, TruncationStrategy, Tokenizer};
|
||||
use rust_bert::bart::{BartConfig, BartConfigResources, BartVocabResources, BartMergesResources, BartModelResources, BartModel};
|
||||
use rust_bert::bart::{
|
||||
BartConfig, BartConfigResources, BartMergesResources, BartModel, BartModelResources,
|
||||
BartVocabResources,
|
||||
};
|
||||
use rust_bert::resources::{download_resource, RemoteResource, Resource};
|
||||
use rust_bert::Config;
|
||||
use rust_bert::resources::{Resource, download_resource, RemoteResource};
|
||||
|
||||
use rust_tokenizers::{RobertaTokenizer, Tokenizer, TruncationStrategy};
|
||||
use tch::{nn, no_grad, Device, Tensor};
|
||||
|
||||
fn main() -> failure::Fallible<()> {
|
||||
// Resources paths
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(BartConfigResources::BART));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(BartVocabResources::BART));
|
||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(BartMergesResources::BART));
|
||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(BartModelResources::BART));
|
||||
let config_resource =
|
||||
Resource::Remote(RemoteResource::from_pretrained(BartConfigResources::BART));
|
||||
let vocab_resource =
|
||||
Resource::Remote(RemoteResource::from_pretrained(BartVocabResources::BART));
|
||||
let merges_resource =
|
||||
Resource::Remote(RemoteResource::from_pretrained(BartMergesResources::BART));
|
||||
let weights_resource =
|
||||
Resource::Remote(RemoteResource::from_pretrained(BartModelResources::BART));
|
||||
let config_path = download_resource(&config_resource)?;
|
||||
let vocab_path = download_resource(&vocab_resource)?;
|
||||
let merges_path = download_resource(&merges_resource)?;
|
||||
@ -33,7 +39,11 @@ fn main() -> failure::Fallible<()> {
|
||||
// Set-up masked LM model
|
||||
let device = Device::cuda_if_available();
|
||||
let mut vs = nn::VarStore::new(device);
|
||||
let tokenizer = RobertaTokenizer::from_file(vocab_path.to_str().unwrap(), merges_path.to_str().unwrap(), false);
|
||||
let tokenizer = RobertaTokenizer::from_file(
|
||||
vocab_path.to_str().unwrap(),
|
||||
merges_path.to_str().unwrap(),
|
||||
false,
|
||||
);
|
||||
let config = BartConfig::from_file(config_path);
|
||||
let bart_model = BartModel::new(&vs.root(), &config, false);
|
||||
vs.load(weights_path)?;
|
||||
@ -63,31 +73,27 @@ about exoplanets like K2-18b."];
|
||||
|
||||
// Credits: WikiNews, CC BY 2.5 license (https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b)
|
||||
|
||||
let tokenized_input = tokenizer.encode_list(input.to_vec(), 1024, &TruncationStrategy::LongestFirst, 0);
|
||||
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
|
||||
let tokenized_input = tokenized_input.
|
||||
iter().
|
||||
map(|input| input.token_ids.clone()).
|
||||
map(|mut input| {
|
||||
let tokenized_input =
|
||||
tokenizer.encode_list(input.to_vec(), 1024, &TruncationStrategy::LongestFirst, 0);
|
||||
let max_len = tokenized_input
|
||||
.iter()
|
||||
.map(|input| input.token_ids.len())
|
||||
.max()
|
||||
.unwrap();
|
||||
let tokenized_input = tokenized_input
|
||||
.iter()
|
||||
.map(|input| input.token_ids.clone())
|
||||
.map(|mut input| {
|
||||
input.extend(vec![0; max_len - input.len()]);
|
||||
input
|
||||
}).
|
||||
map(|input|
|
||||
Tensor::of_slice(&(input))).
|
||||
collect::<Vec<_>>();
|
||||
})
|
||||
.map(|input| Tensor::of_slice(&(input)))
|
||||
.collect::<Vec<_>>();
|
||||
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
|
||||
|
||||
// Forward pass
|
||||
let (decoder_output, encoder_output, _, _, _, _, _) = no_grad(|| {
|
||||
bart_model
|
||||
.forward_t(Some(&input_tensor),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
false)
|
||||
});
|
||||
let (decoder_output, encoder_output, _, _, _, _, _) =
|
||||
no_grad(|| bart_model.forward_t(Some(&input_tensor), None, None, None, None, None, false));
|
||||
|
||||
// Print masked tokens
|
||||
println!("{:?}", encoder_output);
|
||||
|
@ -12,18 +12,22 @@
|
||||
|
||||
extern crate failure;
|
||||
|
||||
use tch::{Device, nn, Tensor, no_grad};
|
||||
use rust_tokenizers::{BertTokenizer, TruncationStrategy, Tokenizer, Vocab};
|
||||
use rust_bert::bert::{
|
||||
BertConfig, BertConfigResources, BertForMaskedLM, BertModelResources, BertVocabResources,
|
||||
};
|
||||
use rust_bert::resources::{download_resource, RemoteResource, Resource};
|
||||
use rust_bert::Config;
|
||||
use rust_bert::bert::{BertConfig, BertForMaskedLM, BertConfigResources, BertVocabResources, BertModelResources};
|
||||
use rust_bert::resources::{Resource, download_resource, RemoteResource};
|
||||
|
||||
use rust_tokenizers::{BertTokenizer, Tokenizer, TruncationStrategy, Vocab};
|
||||
use tch::{nn, no_grad, Device, Tensor};
|
||||
|
||||
fn main() -> failure::Fallible<()> {
|
||||
// Resources paths
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
|
||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT));
|
||||
let config_resource =
|
||||
Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT));
|
||||
let vocab_resource =
|
||||
Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
|
||||
let weights_resource =
|
||||
Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT));
|
||||
let config_path = download_resource(&config_resource)?;
|
||||
let vocab_path = download_resource(&vocab_resource)?;
|
||||
let weights_path = download_resource(&weights_resource)?;
|
||||
@ -37,32 +41,40 @@ fn main() -> failure::Fallible<()> {
|
||||
vs.load(weights_path)?;
|
||||
|
||||
// Define input
|
||||
let input = ["Looks like one [MASK] is missing", "It was a very nice and [MASK] day"];
|
||||
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
||||
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
|
||||
let tokenized_input = tokenized_input.
|
||||
iter().
|
||||
map(|input| input.token_ids.clone()).
|
||||
map(|mut input| {
|
||||
let input = [
|
||||
"Looks like one [MASK] is missing",
|
||||
"It was a very nice and [MASK] day",
|
||||
];
|
||||
let tokenized_input =
|
||||
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
||||
let max_len = tokenized_input
|
||||
.iter()
|
||||
.map(|input| input.token_ids.len())
|
||||
.max()
|
||||
.unwrap();
|
||||
let tokenized_input = tokenized_input
|
||||
.iter()
|
||||
.map(|input| input.token_ids.clone())
|
||||
.map(|mut input| {
|
||||
input.extend(vec![0; max_len - input.len()]);
|
||||
input
|
||||
}).
|
||||
map(|input|
|
||||
Tensor::of_slice(&(input))).
|
||||
collect::<Vec<_>>();
|
||||
})
|
||||
.map(|input| Tensor::of_slice(&(input)))
|
||||
.collect::<Vec<_>>();
|
||||
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
|
||||
|
||||
// Forward pass
|
||||
let (output, _, _) = no_grad(|| {
|
||||
bert_model
|
||||
.forward_t(Some(input_tensor),
|
||||
bert_model.forward_t(
|
||||
Some(input_tensor),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
&None,
|
||||
&None,
|
||||
false)
|
||||
false,
|
||||
)
|
||||
});
|
||||
|
||||
// Print masked tokens
|
||||
|
@ -11,20 +11,28 @@
|
||||
// limitations under the License.
|
||||
extern crate failure;
|
||||
|
||||
use tch::{Device, Tensor, nn, no_grad};
|
||||
use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::{Tokenizer, TruncationStrategy};
|
||||
use rust_tokenizers::bert_tokenizer::BertTokenizer;
|
||||
use rust_tokenizers::preprocessing::vocab::base_vocab::Vocab;
|
||||
use rust_bert::distilbert::{
|
||||
DistilBertConfig, DistilBertConfigResources, DistilBertModelMaskedLM, DistilBertModelResources,
|
||||
DistilBertVocabResources,
|
||||
};
|
||||
use rust_bert::resources::{download_resource, RemoteResource, Resource};
|
||||
use rust_bert::Config;
|
||||
use rust_bert::distilbert::{DistilBertConfig, DistilBertModelMaskedLM, DistilBertConfigResources, DistilBertVocabResources, DistilBertModelResources};
|
||||
use rust_bert::resources::{Resource, download_resource, RemoteResource};
|
||||
|
||||
use rust_tokenizers::bert_tokenizer::BertTokenizer;
|
||||
use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::{Tokenizer, TruncationStrategy};
|
||||
use rust_tokenizers::preprocessing::vocab::base_vocab::Vocab;
|
||||
use tch::{nn, no_grad, Device, Tensor};
|
||||
|
||||
fn main() -> failure::Fallible<()> {
|
||||
// Resources paths
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertConfigResources::DISTIL_BERT));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertVocabResources::DISTIL_BERT));
|
||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertModelResources::DISTIL_BERT));
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
DistilBertConfigResources::DISTIL_BERT,
|
||||
));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
DistilBertVocabResources::DISTIL_BERT,
|
||||
));
|
||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
DistilBertModelResources::DISTIL_BERT,
|
||||
));
|
||||
let config_path = download_resource(&config_resource)?;
|
||||
let vocab_path = download_resource(&vocab_resource)?;
|
||||
let weights_path = download_resource(&weights_resource)?;
|
||||
@ -38,29 +46,35 @@ fn main() -> failure::Fallible<()> {
|
||||
vs.load(weights_path)?;
|
||||
|
||||
// Define input
|
||||
let input = ["Looks like one thing is missing", "It\'s like comparing oranges to apples"];
|
||||
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
||||
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
|
||||
let mut tokenized_input = tokenized_input.
|
||||
iter().
|
||||
map(|input| input.token_ids.clone()).
|
||||
map(|mut input| {
|
||||
let input = [
|
||||
"Looks like one thing is missing",
|
||||
"It\'s like comparing oranges to apples",
|
||||
];
|
||||
let tokenized_input =
|
||||
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
||||
let max_len = tokenized_input
|
||||
.iter()
|
||||
.map(|input| input.token_ids.len())
|
||||
.max()
|
||||
.unwrap();
|
||||
let mut tokenized_input = tokenized_input
|
||||
.iter()
|
||||
.map(|input| input.token_ids.clone())
|
||||
.map(|mut input| {
|
||||
input.extend(vec![0; max_len - input.len()]);
|
||||
input
|
||||
}).
|
||||
collect::<Vec<_>>();
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
// Masking the token [thing] of sentence 1 and [oranges] of sentence 2
|
||||
tokenized_input[0][4] = 103;
|
||||
tokenized_input[1][6] = 103;
|
||||
let tokenized_input = tokenized_input.
|
||||
iter().
|
||||
map(|input|
|
||||
Tensor::of_slice(&(input))).
|
||||
collect::<Vec<_>>();
|
||||
let tokenized_input = tokenized_input
|
||||
.iter()
|
||||
.map(|input| Tensor::of_slice(&(input)))
|
||||
.collect::<Vec<_>>();
|
||||
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
|
||||
|
||||
|
||||
// Forward pass
|
||||
let (output, _, _) = no_grad(|| {
|
||||
distil_bert_model
|
||||
|
@ -1,26 +1,44 @@
|
||||
extern crate failure;
|
||||
|
||||
use rust_bert::gpt2::{Gpt2ConfigResources, Gpt2VocabResources, Gpt2MergesResources, Gpt2ModelResources};
|
||||
use rust_bert::distilbert::{DistilBertModelResources, DistilBertConfigResources, DistilBertVocabResources};
|
||||
use rust_bert::openai_gpt::{OpenAiGptConfigResources, OpenAiGptVocabResources, OpenAiGptMergesResources, OpenAiGptModelResources};
|
||||
use rust_bert::roberta::{RobertaConfigResources, RobertaVocabResources, RobertaMergesResources, RobertaModelResources};
|
||||
use rust_bert::bert::{BertConfigResources, BertVocabResources, BertModelResources};
|
||||
use rust_bert::bart::{BartConfigResources, BartVocabResources, BartMergesResources, BartModelResources};
|
||||
use rust_bert::resources::{Resource, download_resource, RemoteResource};
|
||||
use rust_bert::electra::{ElectraConfigResources, ElectraVocabResources, ElectraModelResources};
|
||||
use rust_bert::albert::{AlbertConfigResources, AlbertVocabResources, AlbertModelResources};
|
||||
use rust_bert::albert::{AlbertConfigResources, AlbertModelResources, AlbertVocabResources};
|
||||
use rust_bert::bart::{
|
||||
BartConfigResources, BartMergesResources, BartModelResources, BartVocabResources,
|
||||
};
|
||||
use rust_bert::bert::{BertConfigResources, BertModelResources, BertVocabResources};
|
||||
use rust_bert::distilbert::{
|
||||
DistilBertConfigResources, DistilBertModelResources, DistilBertVocabResources,
|
||||
};
|
||||
use rust_bert::electra::{ElectraConfigResources, ElectraModelResources, ElectraVocabResources};
|
||||
use rust_bert::gpt2::{
|
||||
Gpt2ConfigResources, Gpt2MergesResources, Gpt2ModelResources, Gpt2VocabResources,
|
||||
};
|
||||
use rust_bert::openai_gpt::{
|
||||
OpenAiGptConfigResources, OpenAiGptMergesResources, OpenAiGptModelResources,
|
||||
OpenAiGptVocabResources,
|
||||
};
|
||||
use rust_bert::resources::{download_resource, RemoteResource, Resource};
|
||||
use rust_bert::roberta::{
|
||||
RobertaConfigResources, RobertaMergesResources, RobertaModelResources, RobertaVocabResources,
|
||||
};
|
||||
|
||||
/// This example downloads and caches all dependencies used in model tests. This allows for safe
|
||||
/// multi threaded testing (two test using the same resource would otherwise download the file to
|
||||
/// the same location).
|
||||
|
||||
|
||||
fn download_distil_gpt2() -> failure::Fallible<()> {
|
||||
// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::DISTIL_GPT2));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::DISTIL_GPT2));
|
||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::DISTIL_GPT2));
|
||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::DISTIL_GPT2));
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
Gpt2ConfigResources::DISTIL_GPT2,
|
||||
));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
Gpt2VocabResources::DISTIL_GPT2,
|
||||
));
|
||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
Gpt2MergesResources::DISTIL_GPT2,
|
||||
));
|
||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
Gpt2ModelResources::DISTIL_GPT2,
|
||||
));
|
||||
let _ = download_resource(&config_resource)?;
|
||||
let _ = download_resource(&vocab_resource)?;
|
||||
let _ = download_resource(&merges_resource)?;
|
||||
@ -30,9 +48,15 @@ fn download_distil_gpt2() -> failure::Fallible<()> {
|
||||
|
||||
fn download_distilbert_sst2() -> failure::Fallible<()> {
|
||||
// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
|
||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertModelResources::DISTIL_BERT_SST2));
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertConfigResources::DISTIL_BERT_SST2));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertVocabResources::DISTIL_BERT_SST2));
|
||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
DistilBertModelResources::DISTIL_BERT_SST2,
|
||||
));
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
DistilBertConfigResources::DISTIL_BERT_SST2,
|
||||
));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
DistilBertVocabResources::DISTIL_BERT_SST2,
|
||||
));
|
||||
let _ = download_resource(&config_resource)?;
|
||||
let _ = download_resource(&vocab_resource)?;
|
||||
let _ = download_resource(&weights_resource)?;
|
||||
@ -41,9 +65,15 @@ fn download_distilbert_sst2() -> failure::Fallible<()> {
|
||||
|
||||
fn download_distilbert_qa() -> failure::Fallible<()> {
|
||||
// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
|
||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertModelResources::DISTIL_BERT_SQUAD));
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertConfigResources::DISTIL_BERT_SQUAD));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertVocabResources::DISTIL_BERT_SQUAD));
|
||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
DistilBertModelResources::DISTIL_BERT_SQUAD,
|
||||
));
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
DistilBertConfigResources::DISTIL_BERT_SQUAD,
|
||||
));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
DistilBertVocabResources::DISTIL_BERT_SQUAD,
|
||||
));
|
||||
let _ = download_resource(&config_resource)?;
|
||||
let _ = download_resource(&vocab_resource)?;
|
||||
let _ = download_resource(&weights_resource)?;
|
||||
@ -52,9 +82,15 @@ fn download_distilbert_qa() -> failure::Fallible<()> {
|
||||
|
||||
fn download_distilbert() -> failure::Fallible<()> {
|
||||
// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
|
||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertModelResources::DISTIL_BERT));
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertConfigResources::DISTIL_BERT));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertVocabResources::DISTIL_BERT));
|
||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
DistilBertModelResources::DISTIL_BERT,
|
||||
));
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
DistilBertConfigResources::DISTIL_BERT,
|
||||
));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
DistilBertVocabResources::DISTIL_BERT,
|
||||
));
|
||||
let _ = download_resource(&config_resource)?;
|
||||
let _ = download_resource(&vocab_resource)?;
|
||||
let _ = download_resource(&weights_resource)?;
|
||||
@ -63,10 +99,14 @@ fn download_distilbert() -> failure::Fallible<()> {
|
||||
|
||||
fn download_gpt2() -> failure::Fallible<()> {
|
||||
// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2. Modified with conversion to C-array format.
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
|
||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
|
||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
|
||||
let config_resource =
|
||||
Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
|
||||
let vocab_resource =
|
||||
Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
|
||||
let merges_resource =
|
||||
Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
|
||||
let weights_resource =
|
||||
Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
|
||||
let _ = download_resource(&config_resource)?;
|
||||
let _ = download_resource(&vocab_resource)?;
|
||||
let _ = download_resource(&merges_resource)?;
|
||||
@ -76,10 +116,18 @@ fn download_gpt2() -> failure::Fallible<()> {
|
||||
|
||||
fn download_gpt() -> failure::Fallible<()> {
|
||||
// Shared under MIT license by the OpenAI team at https://github.com/openai/finetune-transformer-lm. Modified with conversion to C-array format.
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptConfigResources::GPT));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptVocabResources::GPT));
|
||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptMergesResources::GPT));
|
||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptModelResources::GPT));
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
OpenAiGptConfigResources::GPT,
|
||||
));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
OpenAiGptVocabResources::GPT,
|
||||
));
|
||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
OpenAiGptMergesResources::GPT,
|
||||
));
|
||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
OpenAiGptModelResources::GPT,
|
||||
));
|
||||
let _ = download_resource(&config_resource)?;
|
||||
let _ = download_resource(&vocab_resource)?;
|
||||
let _ = download_resource(&merges_resource)?;
|
||||
@ -89,10 +137,18 @@ fn download_gpt() -> failure::Fallible<()> {
|
||||
|
||||
fn download_roberta() -> failure::Fallible<()> {
|
||||
// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaConfigResources::ROBERTA));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaVocabResources::ROBERTA));
|
||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaMergesResources::ROBERTA));
|
||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaModelResources::ROBERTA));
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
RobertaConfigResources::ROBERTA,
|
||||
));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
RobertaVocabResources::ROBERTA,
|
||||
));
|
||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
RobertaMergesResources::ROBERTA,
|
||||
));
|
||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
RobertaModelResources::ROBERTA,
|
||||
));
|
||||
let _ = download_resource(&config_resource)?;
|
||||
let _ = download_resource(&vocab_resource)?;
|
||||
let _ = download_resource(&merges_resource)?;
|
||||
@ -102,9 +158,12 @@ fn download_roberta() -> failure::Fallible<()> {
|
||||
|
||||
fn download_bert() -> failure::Fallible<()> {
|
||||
// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/bert. Modified with conversion to C-array format.
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
|
||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT));
|
||||
let config_resource =
|
||||
Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT));
|
||||
let vocab_resource =
|
||||
Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
|
||||
let weights_resource =
|
||||
Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT));
|
||||
let _ = download_resource(&config_resource)?;
|
||||
let _ = download_resource(&vocab_resource)?;
|
||||
let _ = download_resource(&weights_resource)?;
|
||||
@ -113,9 +172,15 @@ fn download_bert() -> failure::Fallible<()> {
|
||||
|
||||
fn download_bert_ner() -> failure::Fallible<()> {
|
||||
// Shared under MIT license by the MDZ Digital Library team at the Bavarian State Library at https://github.com/dbmdz/berts. Modified with conversion to C-array format.
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT_NER));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT_NER));
|
||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT_NER));
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
BertConfigResources::BERT_NER,
|
||||
));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
BertVocabResources::BERT_NER,
|
||||
));
|
||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
BertModelResources::BERT_NER,
|
||||
));
|
||||
let _ = download_resource(&config_resource)?;
|
||||
let _ = download_resource(&vocab_resource)?;
|
||||
let _ = download_resource(&weights_resource)?;
|
||||
@ -124,10 +189,14 @@ fn download_bert_ner() -> failure::Fallible<()> {
|
||||
|
||||
fn download_bart() -> failure::Fallible<()> {
|
||||
// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(BartConfigResources::BART));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(BartVocabResources::BART));
|
||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(BartMergesResources::BART));
|
||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(BartModelResources::BART));
|
||||
let config_resource =
|
||||
Resource::Remote(RemoteResource::from_pretrained(BartConfigResources::BART));
|
||||
let vocab_resource =
|
||||
Resource::Remote(RemoteResource::from_pretrained(BartVocabResources::BART));
|
||||
let merges_resource =
|
||||
Resource::Remote(RemoteResource::from_pretrained(BartMergesResources::BART));
|
||||
let weights_resource =
|
||||
Resource::Remote(RemoteResource::from_pretrained(BartModelResources::BART));
|
||||
let _ = download_resource(&config_resource)?;
|
||||
let _ = download_resource(&vocab_resource)?;
|
||||
let _ = download_resource(&merges_resource)?;
|
||||
@ -137,10 +206,18 @@ fn download_bart() -> failure::Fallible<()> {
|
||||
|
||||
fn download_bart_cnn() -> failure::Fallible<()> {
|
||||
// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(BartConfigResources::BART_CNN));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(BartVocabResources::BART_CNN));
|
||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(BartMergesResources::BART_CNN));
|
||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(BartModelResources::BART_CNN));
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
BartConfigResources::BART_CNN,
|
||||
));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
BartVocabResources::BART_CNN,
|
||||
));
|
||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
BartMergesResources::BART_CNN,
|
||||
));
|
||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
BartModelResources::BART_CNN,
|
||||
));
|
||||
let _ = download_resource(&config_resource)?;
|
||||
let _ = download_resource(&vocab_resource)?;
|
||||
let _ = download_resource(&merges_resource)?;
|
||||
@ -150,9 +227,15 @@ fn download_bart_cnn() -> failure::Fallible<()> {
|
||||
|
||||
fn download_electra_generator() -> failure::Fallible<()> {
|
||||
// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/electra. Modified with conversion to C-array format.
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(ElectraConfigResources::BASE_GENERATOR));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(ElectraVocabResources::BASE_GENERATOR));
|
||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(ElectraModelResources::BASE_GENERATOR));
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
ElectraConfigResources::BASE_GENERATOR,
|
||||
));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
ElectraVocabResources::BASE_GENERATOR,
|
||||
));
|
||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
ElectraModelResources::BASE_GENERATOR,
|
||||
));
|
||||
let _ = download_resource(&config_resource)?;
|
||||
let _ = download_resource(&vocab_resource)?;
|
||||
let _ = download_resource(&weights_resource)?;
|
||||
@ -161,9 +244,15 @@ fn download_electra_generator() -> failure::Fallible<()> {
|
||||
|
||||
fn download_electra_discriminator() -> failure::Fallible<()> {
|
||||
// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/electra. Modified with conversion to C-array format.
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(ElectraConfigResources::BASE_DISCRIMINATOR));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(ElectraVocabResources::BASE_DISCRIMINATOR));
|
||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(ElectraModelResources::BASE_DISCRIMINATOR));
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
ElectraConfigResources::BASE_DISCRIMINATOR,
|
||||
));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
ElectraVocabResources::BASE_DISCRIMINATOR,
|
||||
));
|
||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
ElectraModelResources::BASE_DISCRIMINATOR,
|
||||
));
|
||||
let _ = download_resource(&config_resource)?;
|
||||
let _ = download_resource(&vocab_resource)?;
|
||||
let _ = download_resource(&weights_resource)?;
|
||||
@ -172,9 +261,15 @@ fn download_electra_discriminator() -> failure::Fallible<()> {
|
||||
|
||||
fn download_albert_base_v2() -> failure::Fallible<()> {
|
||||
// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/ALBERT. Modified with conversion to C-array format.
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(AlbertConfigResources::ALBERT_BASE_V2));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(AlbertVocabResources::ALBERT_BASE_V2));
|
||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(AlbertModelResources::ALBERT_BASE_V2));
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
AlbertConfigResources::ALBERT_BASE_V2,
|
||||
));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
AlbertVocabResources::ALBERT_BASE_V2,
|
||||
));
|
||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
AlbertModelResources::ALBERT_BASE_V2,
|
||||
));
|
||||
let _ = download_resource(&config_resource)?;
|
||||
let _ = download_resource(&vocab_resource)?;
|
||||
let _ = download_resource(&weights_resource)?;
|
||||
|
@ -12,18 +12,26 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
|
||||
use rust_bert::resources::{Resource, download_resource, RemoteResource};
|
||||
use rust_bert::electra::{ElectraConfig, ElectraDiscriminator, ElectraConfigResources, ElectraVocabResources, ElectraModelResources};
|
||||
use rust_bert::electra::{
|
||||
ElectraConfig, ElectraConfigResources, ElectraDiscriminator, ElectraModelResources,
|
||||
ElectraVocabResources,
|
||||
};
|
||||
use rust_bert::resources::{download_resource, RemoteResource, Resource};
|
||||
use rust_bert::Config;
|
||||
use rust_tokenizers::{BertTokenizer, Tokenizer, TruncationStrategy};
|
||||
use tch::{Tensor, Device, nn, no_grad};
|
||||
use tch::{nn, no_grad, Device, Tensor};
|
||||
|
||||
fn main() -> failure::Fallible<()> {
|
||||
// Resources paths
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(ElectraConfigResources::BASE_DISCRIMINATOR));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(ElectraVocabResources::BASE_DISCRIMINATOR));
|
||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(ElectraModelResources::BASE_DISCRIMINATOR));
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
ElectraConfigResources::BASE_DISCRIMINATOR,
|
||||
));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
ElectraVocabResources::BASE_DISCRIMINATOR,
|
||||
));
|
||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
ElectraModelResources::BASE_DISCRIMINATOR,
|
||||
));
|
||||
let config_path = download_resource(&config_resource)?;
|
||||
let vocab_path = download_resource(&vocab_resource)?;
|
||||
let weights_path = download_resource(&weights_resource)?;
|
||||
@ -38,43 +46,43 @@ fn main() -> failure::Fallible<()> {
|
||||
|
||||
// Define input
|
||||
let input = ["One Two Three Ten Five Six Seven Eight"];
|
||||
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
||||
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
|
||||
let encoded_input = tokenized_input.
|
||||
iter().
|
||||
map(|input| input.token_ids.clone()).
|
||||
map(|mut input| {
|
||||
let tokenized_input =
|
||||
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
||||
let max_len = tokenized_input
|
||||
.iter()
|
||||
.map(|input| input.token_ids.len())
|
||||
.max()
|
||||
.unwrap();
|
||||
let encoded_input = tokenized_input
|
||||
.iter()
|
||||
.map(|input| input.token_ids.clone())
|
||||
.map(|mut input| {
|
||||
input.extend(vec![0; max_len - input.len()]);
|
||||
input
|
||||
}).
|
||||
map(|input|
|
||||
Tensor::of_slice(&(input))).
|
||||
collect::<Vec<_>>();
|
||||
})
|
||||
.map(|input| Tensor::of_slice(&(input)))
|
||||
.collect::<Vec<_>>();
|
||||
let input_tensor = Tensor::stack(encoded_input.as_slice(), 0).to(device);
|
||||
|
||||
// Forward pass
|
||||
let (output, _, _) = no_grad(|| {
|
||||
electra_model
|
||||
.forward_t(Some(input_tensor),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
false)
|
||||
});
|
||||
let (output, _, _) =
|
||||
no_grad(|| electra_model.forward_t(Some(input_tensor), None, None, None, None, false));
|
||||
|
||||
// Print model predictions
|
||||
for (position, token) in tokenized_input[0].token_ids.iter().enumerate() {
|
||||
let probability = output.double_value(&[position as i64]);
|
||||
let generated = if probability > 0.5 { "generated" } else { "original" };
|
||||
println!("{:?}: {} ({:.1}%)",
|
||||
tokenizer.decode([*token].to_vec(),
|
||||
false,
|
||||
false),
|
||||
let generated = if probability > 0.5 {
|
||||
"generated"
|
||||
} else {
|
||||
"original"
|
||||
};
|
||||
println!(
|
||||
"{:?}: {} ({:.1}%)",
|
||||
tokenizer.decode([*token].to_vec(), false, false),
|
||||
generated,
|
||||
100f64 * probability)
|
||||
100f64 * probability
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
@ -12,18 +12,26 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
|
||||
use rust_bert::resources::{Resource, download_resource, RemoteResource};
|
||||
use rust_bert::electra::{ElectraConfig, ElectraForMaskedLM, ElectraModelResources, ElectraConfigResources, ElectraVocabResources};
|
||||
use rust_bert::electra::{
|
||||
ElectraConfig, ElectraConfigResources, ElectraForMaskedLM, ElectraModelResources,
|
||||
ElectraVocabResources,
|
||||
};
|
||||
use rust_bert::resources::{download_resource, RemoteResource, Resource};
|
||||
use rust_bert::Config;
|
||||
use rust_tokenizers::{BertTokenizer, Tokenizer, TruncationStrategy, Vocab};
|
||||
use tch::{Tensor, Device, nn, no_grad};
|
||||
use tch::{nn, no_grad, Device, Tensor};
|
||||
|
||||
fn main() -> failure::Fallible<()> {
|
||||
// Resources paths
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(ElectraConfigResources::BASE_GENERATOR));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(ElectraVocabResources::BASE_GENERATOR));
|
||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(ElectraModelResources::BASE_GENERATOR));
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
ElectraConfigResources::BASE_GENERATOR,
|
||||
));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
ElectraVocabResources::BASE_GENERATOR,
|
||||
));
|
||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
ElectraModelResources::BASE_GENERATOR,
|
||||
));
|
||||
let config_path = download_resource(&config_resource)?;
|
||||
let vocab_path = download_resource(&vocab_resource)?;
|
||||
let weights_path = download_resource(&weights_resource)?;
|
||||
@ -37,31 +45,31 @@ fn main() -> failure::Fallible<()> {
|
||||
vs.load(weights_path)?;
|
||||
|
||||
// Define input
|
||||
let input = ["Looks like one [MASK] is missing", "It was a very nice and [MASK] day"];
|
||||
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
||||
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
|
||||
let tokenized_input = tokenized_input.
|
||||
iter().
|
||||
map(|input| input.token_ids.clone()).
|
||||
map(|mut input| {
|
||||
let input = [
|
||||
"Looks like one [MASK] is missing",
|
||||
"It was a very nice and [MASK] day",
|
||||
];
|
||||
let tokenized_input =
|
||||
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
||||
let max_len = tokenized_input
|
||||
.iter()
|
||||
.map(|input| input.token_ids.len())
|
||||
.max()
|
||||
.unwrap();
|
||||
let tokenized_input = tokenized_input
|
||||
.iter()
|
||||
.map(|input| input.token_ids.clone())
|
||||
.map(|mut input| {
|
||||
input.extend(vec![0; max_len - input.len()]);
|
||||
input
|
||||
}).
|
||||
map(|input|
|
||||
Tensor::of_slice(&(input))).
|
||||
collect::<Vec<_>>();
|
||||
})
|
||||
.map(|input| Tensor::of_slice(&(input)))
|
||||
.collect::<Vec<_>>();
|
||||
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
|
||||
|
||||
// Forward pass
|
||||
let (output, _, _) = no_grad(|| {
|
||||
electra_model
|
||||
.forward_t(Some(input_tensor),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
false)
|
||||
});
|
||||
let (output, _, _) =
|
||||
no_grad(|| electra_model.forward_t(Some(input_tensor), None, None, None, None, false));
|
||||
|
||||
// Print masked tokens
|
||||
let index_1 = output.get(0).get(4).argmax(0, false);
|
||||
|
@ -12,11 +12,9 @@
|
||||
|
||||
extern crate failure;
|
||||
|
||||
use rust_bert::pipelines::generation::{GPT2Generator, LanguageGenerator, GenerateConfig};
|
||||
|
||||
use rust_bert::pipelines::generation::{GPT2Generator, GenerateConfig, LanguageGenerator};
|
||||
|
||||
fn main() -> failure::Fallible<()> {
|
||||
|
||||
// Set-up masked LM model
|
||||
let generate_config = GenerateConfig {
|
||||
max_length: 30,
|
||||
@ -30,7 +28,7 @@ fn main() -> failure::Fallible<()> {
|
||||
|
||||
let input_context = "The dog";
|
||||
let second_input_context = "The cat was";
|
||||
let output = model.generate(Some(vec!(input_context, second_input_context)), None);
|
||||
let output = model.generate(Some(vec![input_context, second_input_context]), None);
|
||||
|
||||
for sentence in output {
|
||||
println!("{:?}", sentence);
|
||||
|
@ -12,20 +12,26 @@
|
||||
|
||||
extern crate failure;
|
||||
|
||||
use tch::{Device, nn, Tensor};
|
||||
use rust_tokenizers::{TruncationStrategy, Tokenizer, Gpt2Tokenizer};
|
||||
use rust_bert::gpt2::{Gpt2Config, GPT2LMHeadModel, Gpt2ConfigResources, Gpt2VocabResources, Gpt2MergesResources, Gpt2ModelResources};
|
||||
use rust_bert::pipelines::generation::{LMHeadModel, Cache};
|
||||
use rust_bert::resources::{Resource, download_resource, RemoteResource};
|
||||
use rust_bert::gpt2::{
|
||||
GPT2LMHeadModel, Gpt2Config, Gpt2ConfigResources, Gpt2MergesResources, Gpt2ModelResources,
|
||||
Gpt2VocabResources,
|
||||
};
|
||||
use rust_bert::pipelines::generation::{Cache, LMHeadModel};
|
||||
use rust_bert::resources::{download_resource, RemoteResource, Resource};
|
||||
use rust_bert::Config;
|
||||
|
||||
use rust_tokenizers::{Gpt2Tokenizer, Tokenizer, TruncationStrategy};
|
||||
use tch::{nn, Device, Tensor};
|
||||
|
||||
fn main() -> failure::Fallible<()> {
|
||||
// Resources set-up
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
|
||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
|
||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
|
||||
let config_resource =
|
||||
Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
|
||||
let vocab_resource =
|
||||
Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
|
||||
let merges_resource =
|
||||
Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
|
||||
let weights_resource =
|
||||
Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
|
||||
let config_path = download_resource(&config_resource)?;
|
||||
let vocab_path = download_resource(&vocab_resource)?;
|
||||
let merges_path = download_resource(&merges_resource)?;
|
||||
@ -34,29 +40,38 @@ fn main() -> failure::Fallible<()> {
|
||||
// Set-up masked LM model
|
||||
let device = Device::Cpu;
|
||||
let mut vs = nn::VarStore::new(device);
|
||||
let tokenizer: Gpt2Tokenizer = Gpt2Tokenizer::from_file(vocab_path.to_str().unwrap(), merges_path.to_str().unwrap(), false);
|
||||
let tokenizer: Gpt2Tokenizer = Gpt2Tokenizer::from_file(
|
||||
vocab_path.to_str().unwrap(),
|
||||
merges_path.to_str().unwrap(),
|
||||
false,
|
||||
);
|
||||
let config = Gpt2Config::from_file(config_path);
|
||||
let gpt2_model = GPT2LMHeadModel::new(&vs.root(), &config);
|
||||
vs.load(weights_path)?;
|
||||
|
||||
// Define input
|
||||
let input = ["One two three four five six seven eight nine ten eleven"];
|
||||
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
||||
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
|
||||
let tokenized_input = tokenized_input.
|
||||
iter().
|
||||
map(|input| input.token_ids.clone()).
|
||||
map(|mut input| {
|
||||
let tokenized_input =
|
||||
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
||||
let max_len = tokenized_input
|
||||
.iter()
|
||||
.map(|input| input.token_ids.len())
|
||||
.max()
|
||||
.unwrap();
|
||||
let tokenized_input = tokenized_input
|
||||
.iter()
|
||||
.map(|input| input.token_ids.clone())
|
||||
.map(|mut input| {
|
||||
input.extend(vec![0; max_len - input.len()]);
|
||||
input
|
||||
}).
|
||||
map(|input|
|
||||
Tensor::of_slice(&(input))).
|
||||
collect::<Vec<_>>();
|
||||
})
|
||||
.map(|input| Tensor::of_slice(&(input)))
|
||||
.collect::<Vec<_>>();
|
||||
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
|
||||
|
||||
// Forward pass
|
||||
let (output, _, _, _, _) = gpt2_model.forward_t(
|
||||
let (output, _, _, _, _) = gpt2_model
|
||||
.forward_t(
|
||||
&Some(input_tensor),
|
||||
Cache::None,
|
||||
&None,
|
||||
@ -65,10 +80,12 @@ fn main() -> failure::Fallible<()> {
|
||||
&None,
|
||||
None,
|
||||
&None,
|
||||
false).unwrap();
|
||||
false,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let next_word_id = output.get(0).get(-1).argmax(-1, true).int64_value(&[0]);
|
||||
let next_word = tokenizer.decode(vec!(next_word_id), true, true);
|
||||
let next_word = tokenizer.decode(vec![next_word_id], true, true);
|
||||
println!("Provided input: {}", input[0]);
|
||||
println!("Next word: {}", next_word);
|
||||
|
||||
|
@ -21,7 +21,7 @@ fn main() -> failure::Fallible<()> {
|
||||
// Define input
|
||||
let input = [
|
||||
"My name is Amélie. I live in Москва.",
|
||||
"Chongqing is a city in China."
|
||||
"Chongqing is a city in China.",
|
||||
];
|
||||
|
||||
// Run model
|
||||
|
@ -12,21 +12,31 @@
|
||||
|
||||
extern crate failure;
|
||||
|
||||
use tch::{Device, nn, Tensor};
|
||||
use rust_tokenizers::{TruncationStrategy, Tokenizer, OpenAiGptTokenizer};
|
||||
use rust_bert::gpt2::Gpt2Config;
|
||||
use rust_bert::openai_gpt::{OpenAIGPTLMHeadModel, OpenAiGptConfigResources, OpenAiGptVocabResources, OpenAiGptMergesResources, OpenAiGptModelResources};
|
||||
use rust_bert::pipelines::generation::{LMHeadModel, Cache};
|
||||
use rust_bert::resources::{Resource, download_resource, RemoteResource};
|
||||
use rust_bert::openai_gpt::{
|
||||
OpenAIGPTLMHeadModel, OpenAiGptConfigResources, OpenAiGptMergesResources,
|
||||
OpenAiGptModelResources, OpenAiGptVocabResources,
|
||||
};
|
||||
use rust_bert::pipelines::generation::{Cache, LMHeadModel};
|
||||
use rust_bert::resources::{download_resource, RemoteResource, Resource};
|
||||
use rust_bert::Config;
|
||||
|
||||
use rust_tokenizers::{OpenAiGptTokenizer, Tokenizer, TruncationStrategy};
|
||||
use tch::{nn, Device, Tensor};
|
||||
|
||||
fn main() -> failure::Fallible<()> {
|
||||
// Resources paths
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptConfigResources::GPT));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptVocabResources::GPT));
|
||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptMergesResources::GPT));
|
||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptModelResources::GPT));
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
OpenAiGptConfigResources::GPT,
|
||||
));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
OpenAiGptVocabResources::GPT,
|
||||
));
|
||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
OpenAiGptMergesResources::GPT,
|
||||
));
|
||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
OpenAiGptModelResources::GPT,
|
||||
));
|
||||
let config_path = download_resource(&config_resource)?;
|
||||
let vocab_path = download_resource(&vocab_resource)?;
|
||||
let merges_path = download_resource(&merges_resource)?;
|
||||
@ -35,29 +45,38 @@ fn main() -> failure::Fallible<()> {
|
||||
// Set-up masked LM model
|
||||
let device = Device::Cpu;
|
||||
let mut vs = nn::VarStore::new(device);
|
||||
let tokenizer = OpenAiGptTokenizer::from_file(vocab_path.to_str().unwrap(), merges_path.to_str().unwrap(), true);
|
||||
let tokenizer = OpenAiGptTokenizer::from_file(
|
||||
vocab_path.to_str().unwrap(),
|
||||
merges_path.to_str().unwrap(),
|
||||
true,
|
||||
);
|
||||
let config = Gpt2Config::from_file(config_path);
|
||||
let openai_gpt = OpenAIGPTLMHeadModel::new(&vs.root(), &config);
|
||||
vs.load(weights_path)?;
|
||||
|
||||
// Define input
|
||||
let input = ["Wondering what the next word will"];
|
||||
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
||||
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
|
||||
let tokenized_input = tokenized_input.
|
||||
iter().
|
||||
map(|input| input.token_ids.clone()).
|
||||
map(|mut input| {
|
||||
let tokenized_input =
|
||||
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
||||
let max_len = tokenized_input
|
||||
.iter()
|
||||
.map(|input| input.token_ids.len())
|
||||
.max()
|
||||
.unwrap();
|
||||
let tokenized_input = tokenized_input
|
||||
.iter()
|
||||
.map(|input| input.token_ids.clone())
|
||||
.map(|mut input| {
|
||||
input.extend(vec![0; max_len - input.len()]);
|
||||
input
|
||||
}).
|
||||
map(|input|
|
||||
Tensor::of_slice(&(input))).
|
||||
collect::<Vec<_>>();
|
||||
})
|
||||
.map(|input| Tensor::of_slice(&(input)))
|
||||
.collect::<Vec<_>>();
|
||||
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
|
||||
|
||||
// Forward pass
|
||||
let (output, _, _, _, _) = openai_gpt.forward_t(
|
||||
let (output, _, _, _, _) = openai_gpt
|
||||
.forward_t(
|
||||
&Some(input_tensor),
|
||||
Cache::None,
|
||||
&None,
|
||||
@ -66,10 +85,12 @@ fn main() -> failure::Fallible<()> {
|
||||
&None,
|
||||
None,
|
||||
&None,
|
||||
false).unwrap();
|
||||
false,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let next_word_id = output.get(0).get(-1).argmax(-1, true).int64_value(&[0]);
|
||||
let next_word = tokenizer.decode(vec!(next_word_id), true, true);
|
||||
let next_word = tokenizer.decode(vec![next_word_id], true, true);
|
||||
println!("Provided input: {}", input[0]);
|
||||
println!("Next word: {}", next_word);
|
||||
|
||||
|
@ -12,8 +12,7 @@
|
||||
|
||||
extern crate failure;
|
||||
|
||||
use rust_bert::pipelines::question_answering::{QuestionAnsweringModel, QaInput};
|
||||
|
||||
use rust_bert::pipelines::question_answering::{QaInput, QuestionAnsweringModel};
|
||||
|
||||
fn main() -> failure::Fallible<()> {
|
||||
// Set-up Question Answering model
|
||||
@ -24,11 +23,17 @@ fn main() -> failure::Fallible<()> {
|
||||
let context_1 = String::from("Amy lives in Amsterdam");
|
||||
let question_2 = String::from("Where does Eric live");
|
||||
let context_2 = String::from("While Amy lives in Amsterdam, Eric is in The Hague.");
|
||||
let qa_input_1 = QaInput { question: question_1, context: context_1 };
|
||||
let qa_input_2 = QaInput { question: question_2, context: context_2 };
|
||||
let qa_input_1 = QaInput {
|
||||
question: question_1,
|
||||
context: context_1,
|
||||
};
|
||||
let qa_input_2 = QaInput {
|
||||
question: question_2,
|
||||
context: context_2,
|
||||
};
|
||||
|
||||
// Get answer
|
||||
let answers = qa_model.predict(&vec!(qa_input_1, qa_input_2), 1, 32);
|
||||
let answers = qa_model.predict(&vec![qa_input_1, qa_input_2], 1, 32);
|
||||
println!("{:?}", answers);
|
||||
Ok(())
|
||||
}
|
@ -12,20 +12,30 @@
|
||||
|
||||
extern crate failure;
|
||||
|
||||
use tch::{Device, nn, Tensor, no_grad};
|
||||
use rust_tokenizers::{TruncationStrategy, Tokenizer, Vocab, RobertaTokenizer};
|
||||
use rust_bert::Config;
|
||||
use rust_bert::bert::BertConfig;
|
||||
use rust_bert::roberta::{RobertaForMaskedLM, RobertaVocabResources, RobertaConfigResources, RobertaMergesResources, RobertaModelResources};
|
||||
use rust_bert::resources::{Resource, download_resource, RemoteResource};
|
||||
|
||||
use rust_bert::resources::{download_resource, RemoteResource, Resource};
|
||||
use rust_bert::roberta::{
|
||||
RobertaConfigResources, RobertaForMaskedLM, RobertaMergesResources, RobertaModelResources,
|
||||
RobertaVocabResources,
|
||||
};
|
||||
use rust_bert::Config;
|
||||
use rust_tokenizers::{RobertaTokenizer, Tokenizer, TruncationStrategy, Vocab};
|
||||
use tch::{nn, no_grad, Device, Tensor};
|
||||
|
||||
fn main() -> failure::Fallible<()> {
|
||||
// Resources paths
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaConfigResources::ROBERTA));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaVocabResources::ROBERTA));
|
||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaMergesResources::ROBERTA));
|
||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaModelResources::ROBERTA));
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
RobertaConfigResources::ROBERTA,
|
||||
));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
RobertaVocabResources::ROBERTA,
|
||||
));
|
||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
RobertaMergesResources::ROBERTA,
|
||||
));
|
||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
RobertaModelResources::ROBERTA,
|
||||
));
|
||||
let config_path = download_resource(&config_resource)?;
|
||||
let vocab_path = download_resource(&vocab_resource)?;
|
||||
let merges_path = download_resource(&merges_resource)?;
|
||||
@ -34,45 +44,57 @@ fn main() -> failure::Fallible<()> {
|
||||
// Set-up masked LM model
|
||||
let device = Device::Cpu;
|
||||
let mut vs = nn::VarStore::new(device);
|
||||
let tokenizer: RobertaTokenizer = RobertaTokenizer::from_file(vocab_path.to_str().unwrap(), merges_path.to_str().unwrap(), true);
|
||||
let tokenizer: RobertaTokenizer = RobertaTokenizer::from_file(
|
||||
vocab_path.to_str().unwrap(),
|
||||
merges_path.to_str().unwrap(),
|
||||
true,
|
||||
);
|
||||
let config = BertConfig::from_file(config_path);
|
||||
let bert_model = RobertaForMaskedLM::new(&vs.root(), &config);
|
||||
vs.load(weights_path)?;
|
||||
|
||||
// Define input
|
||||
let input = ["<pad> Looks like one thing is missing", "It\'s like comparing oranges to apples"];
|
||||
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
||||
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
|
||||
let mut tokenized_input = tokenized_input.
|
||||
iter().
|
||||
map(|input| input.token_ids.clone()).
|
||||
map(|mut input| {
|
||||
let input = [
|
||||
"<pad> Looks like one thing is missing",
|
||||
"It\'s like comparing oranges to apples",
|
||||
];
|
||||
let tokenized_input =
|
||||
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
||||
let max_len = tokenized_input
|
||||
.iter()
|
||||
.map(|input| input.token_ids.len())
|
||||
.max()
|
||||
.unwrap();
|
||||
let mut tokenized_input = tokenized_input
|
||||
.iter()
|
||||
.map(|input| input.token_ids.clone())
|
||||
.map(|mut input| {
|
||||
input.extend(vec![0; max_len - input.len()]);
|
||||
input
|
||||
}).
|
||||
collect::<Vec<_>>();
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
// Masking the token [thing] of sentence 1 and [oranges] of sentence 2
|
||||
tokenized_input[0][4] = 103;
|
||||
tokenized_input[1][5] = 103;
|
||||
let tokenized_input = tokenized_input.
|
||||
iter().
|
||||
map(|input|
|
||||
Tensor::of_slice(&(input))).
|
||||
collect::<Vec<_>>();
|
||||
let tokenized_input = tokenized_input
|
||||
.iter()
|
||||
.map(|input| Tensor::of_slice(&(input)))
|
||||
.collect::<Vec<_>>();
|
||||
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
|
||||
|
||||
// Forward pass
|
||||
let (output, _, _) = no_grad(|| {
|
||||
bert_model
|
||||
.forward_t(Some(input_tensor),
|
||||
bert_model.forward_t(
|
||||
Some(input_tensor),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
&None,
|
||||
&None,
|
||||
false)
|
||||
false,
|
||||
)
|
||||
});
|
||||
|
||||
// Print masked tokens
|
||||
|
@ -14,7 +14,6 @@ extern crate failure;
|
||||
|
||||
use rust_bert::pipelines::sentiment::SentimentModel;
|
||||
|
||||
|
||||
fn main() -> failure::Fallible<()> {
|
||||
// Set-up classifier
|
||||
let sentiment_classifier = SentimentModel::new(Default::default())?;
|
||||
|
@ -12,10 +12,9 @@
|
||||
|
||||
extern crate failure;
|
||||
|
||||
use std::path::PathBuf;
|
||||
use rust_bert::pipelines::question_answering::{squad_processor, QuestionAnsweringModel};
|
||||
use std::env;
|
||||
use rust_bert::pipelines::question_answering::{QuestionAnsweringModel, squad_processor};
|
||||
|
||||
use std::path::PathBuf;
|
||||
|
||||
fn main() -> failure::Fallible<()> {
|
||||
// Set-up Question Answering model
|
||||
|
@ -10,13 +10,12 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
extern crate failure;
|
||||
extern crate dirs;
|
||||
extern crate failure;
|
||||
|
||||
use std::path::PathBuf;
|
||||
use rust_bert::pipelines::sentiment::{SentimentModel, ss2_processor};
|
||||
use rust_bert::pipelines::sentiment::{ss2_processor, SentimentModel};
|
||||
use std::env;
|
||||
|
||||
use std::path::PathBuf;
|
||||
|
||||
fn main() -> failure::Fallible<()> {
|
||||
// Set-up classifier
|
||||
@ -30,11 +29,19 @@ fn main() -> failure::Fallible<()> {
|
||||
|
||||
// Run model
|
||||
let batch_size = 64;
|
||||
let mut output = vec!();
|
||||
let mut output = vec![];
|
||||
for batch in inputs.chunks(batch_size) {
|
||||
output.push(sentiment_classifier.predict(batch.iter().map(|v| v.as_str()).collect::<Vec<&str>>().as_slice()));
|
||||
output.push(
|
||||
sentiment_classifier.predict(
|
||||
batch
|
||||
.iter()
|
||||
.map(|v| v.as_str())
|
||||
.collect::<Vec<&str>>()
|
||||
.as_slice(),
|
||||
),
|
||||
);
|
||||
}
|
||||
let mut flat_outputs = vec!();
|
||||
let mut flat_outputs = vec![];
|
||||
for batch_output in output.iter_mut() {
|
||||
flat_outputs.append(batch_output);
|
||||
}
|
||||
|
@ -14,7 +14,6 @@ extern crate failure;
|
||||
|
||||
use rust_bert::pipelines::summarization::SummarizationModel;
|
||||
|
||||
|
||||
fn main() -> failure::Fallible<()> {
|
||||
let summarization_model = SummarizationModel::new(Default::default())?;
|
||||
|
||||
@ -44,7 +43,7 @@ about exoplanets like K2-18b."];
|
||||
let _output = summarization_model.summarize(&input);
|
||||
for sentence in _output {
|
||||
println!("{:?}", sentence);
|
||||
};
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
@ -10,18 +10,26 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use rust_bert::pipelines::token_classification::{TokenClassificationModel, TokenClassificationConfig, LabelAggregationOption};
|
||||
use rust_bert::resources::{Resource, RemoteResource};
|
||||
use rust_bert::bert::{BertModelResources, BertVocabResources, BertConfigResources};
|
||||
use rust_bert::bert::{BertConfigResources, BertModelResources, BertVocabResources};
|
||||
use rust_bert::pipelines::common::ModelType;
|
||||
use rust_bert::pipelines::token_classification::{
|
||||
LabelAggregationOption, TokenClassificationConfig, TokenClassificationModel,
|
||||
};
|
||||
use rust_bert::resources::{RemoteResource, Resource};
|
||||
|
||||
fn main() -> failure::Fallible<()> {
|
||||
|
||||
// Load a configuration
|
||||
let config = TokenClassificationConfig::new(ModelType::Bert,
|
||||
Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT_NER)),
|
||||
Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT_NER)),
|
||||
Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT_NER)),
|
||||
let config = TokenClassificationConfig::new(
|
||||
ModelType::Bert,
|
||||
Resource::Remote(RemoteResource::from_pretrained(
|
||||
BertModelResources::BERT_NER,
|
||||
)),
|
||||
Resource::Remote(RemoteResource::from_pretrained(
|
||||
BertConfigResources::BERT_NER,
|
||||
)),
|
||||
Resource::Remote(RemoteResource::from_pretrained(
|
||||
BertVocabResources::BERT_NER,
|
||||
)),
|
||||
None, //merges resource only relevant with ModelType::Roberta
|
||||
false, //lowercase
|
||||
LabelAggregationOption::Mode,
|
||||
@ -31,7 +39,7 @@ fn main() -> failure::Fallible<()> {
|
||||
let token_classification_model = TokenClassificationModel::new(config)?;
|
||||
let input = [
|
||||
"My name is Amélie. I live in Москва.",
|
||||
"Chongqing is a city in China."
|
||||
"Chongqing is a city in China.",
|
||||
];
|
||||
let token_outputs = token_classification_model.predict(&input, true, false); //ignore_first_label = true (only returns the NER parts, ignoring first label O)
|
||||
|
||||
|
@ -13,12 +13,12 @@
|
||||
|
||||
extern crate failure;
|
||||
|
||||
use rust_bert::pipelines::translation::{TranslationConfig, TranslationModel, Language};
|
||||
use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
|
||||
use tch::Device;
|
||||
|
||||
fn main() -> failure::Fallible<()> {
|
||||
|
||||
let translation_config = TranslationConfig::new(Language::EnglishToGerman, Device::cuda_if_available());
|
||||
let translation_config =
|
||||
TranslationConfig::new(Language::EnglishToGerman, Device::cuda_if_available());
|
||||
let model = TranslationModel::new(translation_config)?;
|
||||
|
||||
let input_context_1 = "The quick brown fox jumps over the lazy dog";
|
||||
|
1
rustfmt.toml
Normal file
1
rustfmt.toml
Normal file
@ -0,0 +1 @@
|
||||
format_code_in_doc_comments = true
|
@ -11,16 +11,15 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
|
||||
use std::collections::HashMap;
|
||||
use crate::Config;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use crate::albert::embeddings::AlbertEmbeddings;
|
||||
use crate::albert::encoder::AlbertTransformer;
|
||||
use tch::{nn, Tensor, Kind};
|
||||
use crate::common::activations::{_tanh, _gelu_new, _gelu, _relu, _mish};
|
||||
use tch::nn::Module;
|
||||
use crate::common::activations::{_gelu, _gelu_new, _mish, _relu, _tanh};
|
||||
use crate::common::dropout::Dropout;
|
||||
use crate::Config;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use tch::nn::Module;
|
||||
use tch::{nn, Kind, Tensor};
|
||||
|
||||
/// # ALBERT Pretrained model weight files
|
||||
pub struct AlbertModelResources;
|
||||
@ -33,20 +32,28 @@ pub struct AlbertVocabResources;
|
||||
|
||||
impl AlbertModelResources {
|
||||
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/ALBERT. Modified with conversion to C-array format.
|
||||
pub const ALBERT_BASE_V2: (&'static str, &'static str) = ("albert-base-v2/model.ot", "https://cdn.huggingface.co/albert-base-v2/rust_model.ot");
|
||||
pub const ALBERT_BASE_V2: (&'static str, &'static str) = (
|
||||
"albert-base-v2/model.ot",
|
||||
"https://cdn.huggingface.co/albert-base-v2/rust_model.ot",
|
||||
);
|
||||
}
|
||||
|
||||
impl AlbertConfigResources {
|
||||
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/ALBERT. Modified with conversion to C-array format.
|
||||
pub const ALBERT_BASE_V2: (&'static str, &'static str) = ("albert-base-v2/config.json", "https://cdn.huggingface.co/albert-base-v2-config.json");
|
||||
pub const ALBERT_BASE_V2: (&'static str, &'static str) = (
|
||||
"albert-base-v2/config.json",
|
||||
"https://cdn.huggingface.co/albert-base-v2-config.json",
|
||||
);
|
||||
}
|
||||
|
||||
impl AlbertVocabResources {
|
||||
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/ALBERT. Modified with conversion to C-array format.
|
||||
pub const ALBERT_BASE_V2: (&'static str, &'static str) = ("albert-base-v2/spiece.model", "https://cdn.huggingface.co/albert-base-v2-spiece.model");
|
||||
pub const ALBERT_BASE_V2: (&'static str, &'static str) = (
|
||||
"albert-base-v2/spiece.model",
|
||||
"https://cdn.huggingface.co/albert-base-v2-spiece.model",
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
#[allow(non_camel_case_types)]
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
/// # Activation function used in the attention layer and masked language model head
|
||||
@ -61,7 +68,6 @@ pub enum Activation {
|
||||
mish,
|
||||
}
|
||||
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
/// # ALBERT model configuration
|
||||
/// Defines the ALBERT model architecture (e.g. number of layers, hidden layer size, label mapping...)
|
||||
@ -123,10 +129,10 @@ impl AlbertModel {
|
||||
/// # Example
|
||||
///
|
||||
/// ```no_run
|
||||
/// use tch::{nn, Device};
|
||||
/// use rust_bert::albert::{AlbertConfig, AlbertModel};
|
||||
/// use rust_bert::Config;
|
||||
/// use std::path::Path;
|
||||
/// use rust_bert::albert::{AlbertConfig, AlbertModel};
|
||||
/// use tch::{nn, Device};
|
||||
///
|
||||
/// let config_path = Path::new("path/to/config.json");
|
||||
/// let device = Device::Cpu;
|
||||
@ -134,14 +140,23 @@ impl AlbertModel {
|
||||
/// let config = AlbertConfig::from_file(config_path);
|
||||
/// let albert: AlbertModel = AlbertModel::new(&(&p.root() / "albert"), &config);
|
||||
/// ```
|
||||
///
|
||||
pub fn new(p: &nn::Path, config: &AlbertConfig) -> AlbertModel {
|
||||
let embeddings = AlbertEmbeddings::new(&(p / "embeddings"), config);
|
||||
let encoder = AlbertTransformer::new(&(p / "encoder"), config);
|
||||
let pooler = nn::linear(&(p / "pooler"), config.hidden_size, config.hidden_size, Default::default());
|
||||
let pooler = nn::linear(
|
||||
&(p / "pooler"),
|
||||
config.hidden_size,
|
||||
config.hidden_size,
|
||||
Default::default(),
|
||||
);
|
||||
let pooler_activation = Box::new(_tanh);
|
||||
|
||||
AlbertModel { embeddings, encoder, pooler, pooler_activation }
|
||||
AlbertModel {
|
||||
embeddings,
|
||||
encoder,
|
||||
pooler,
|
||||
pooler_activation,
|
||||
}
|
||||
}
|
||||
|
||||
/// Forward pass through the model
|
||||
@ -179,61 +194,89 @@ impl AlbertModel {
|
||||
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
|
||||
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
||||
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
||||
/// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true);
|
||||
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
|
||||
/// .expand(&[batch_size, sequence_length], true);
|
||||
///
|
||||
/// let (output, pooled_output, all_hidden_states, all_attentions) = no_grad(|| {
|
||||
/// albert_model
|
||||
/// .forward_t(Some(input_tensor),
|
||||
/// .forward_t(
|
||||
/// Some(input_tensor),
|
||||
/// Some(mask),
|
||||
/// Some(token_type_ids),
|
||||
/// Some(position_ids),
|
||||
/// None,
|
||||
/// false).unwrap()
|
||||
/// false,
|
||||
/// )
|
||||
/// .unwrap()
|
||||
/// });
|
||||
///
|
||||
/// ```
|
||||
///
|
||||
pub fn forward_t(&self,
|
||||
pub fn forward_t(
|
||||
&self,
|
||||
input_ids: Option<Tensor>,
|
||||
mask: Option<Tensor>,
|
||||
token_type_ids: Option<Tensor>,
|
||||
position_ids: Option<Tensor>,
|
||||
input_embeds: Option<Tensor>,
|
||||
train: bool)
|
||||
-> Result<(Tensor, Tensor, Option<Vec<Tensor>>, Option<Vec<Vec<Tensor>>>), &'static str> {
|
||||
train: bool,
|
||||
) -> Result<
|
||||
(
|
||||
Tensor,
|
||||
Tensor,
|
||||
Option<Vec<Tensor>>,
|
||||
Option<Vec<Vec<Tensor>>>,
|
||||
),
|
||||
&'static str,
|
||||
> {
|
||||
let (input_shape, device) = match &input_ids {
|
||||
Some(input_value) => match &input_embeds {
|
||||
Some(_) => { return Err("Only one of input ids or input embeddings may be set"); }
|
||||
None => (input_value.size(), input_value.device())
|
||||
Some(_) => {
|
||||
return Err("Only one of input ids or input embeddings may be set");
|
||||
}
|
||||
None => (input_value.size(), input_value.device()),
|
||||
},
|
||||
None => match &input_embeds {
|
||||
Some(embeds) => (vec!(embeds.size()[0], embeds.size()[1]), embeds.device()),
|
||||
None => { return Err("At least one of input ids or input embeddings must be set"); }
|
||||
Some(embeds) => (vec![embeds.size()[0], embeds.size()[1]], embeds.device()),
|
||||
None => {
|
||||
return Err("At least one of input ids or input embeddings must be set");
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
let mask = match mask {
|
||||
Some(value) => value,
|
||||
None => Tensor::ones(&input_shape, (Kind::Int64, device))
|
||||
None => Tensor::ones(&input_shape, (Kind::Int64, device)),
|
||||
};
|
||||
|
||||
let extended_attention_mask = mask.unsqueeze(1).unsqueeze(2);
|
||||
let extended_attention_mask: Tensor = (extended_attention_mask.ones_like() - extended_attention_mask) * -10000.0;
|
||||
let extended_attention_mask: Tensor =
|
||||
(extended_attention_mask.ones_like() - extended_attention_mask) * -10000.0;
|
||||
|
||||
let embedding_output = match self.embeddings.forward_t(input_ids, token_type_ids, position_ids, input_embeds, train) {
|
||||
let embedding_output = match self.embeddings.forward_t(
|
||||
input_ids,
|
||||
token_type_ids,
|
||||
position_ids,
|
||||
input_embeds,
|
||||
train,
|
||||
) {
|
||||
Ok(value) => value,
|
||||
Err(e) => { return Err(e); }
|
||||
Err(e) => {
|
||||
return Err(e);
|
||||
}
|
||||
};
|
||||
|
||||
let (hidden_state, all_hidden_states, all_attentions) =
|
||||
self.encoder.forward_t(&embedding_output,
|
||||
Some(extended_attention_mask),
|
||||
train);
|
||||
self.encoder
|
||||
.forward_t(&embedding_output, Some(extended_attention_mask), train);
|
||||
|
||||
let pooled_output = self.pooler.forward(&hidden_state.select(1, 0));
|
||||
let pooled_output = (self.pooler_activation)(&pooled_output);
|
||||
|
||||
Ok((hidden_state, pooled_output, all_hidden_states, all_attentions))
|
||||
Ok((
|
||||
hidden_state,
|
||||
pooled_output,
|
||||
all_hidden_states,
|
||||
all_attentions,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
@ -248,21 +291,43 @@ impl AlbertMLMHead {
|
||||
pub fn new(p: &nn::Path, config: &AlbertConfig) -> AlbertMLMHead {
|
||||
let layer_norm_eps = match config.layer_norm_eps {
|
||||
Some(value) => value,
|
||||
None => 1e-12
|
||||
None => 1e-12,
|
||||
};
|
||||
let layer_norm_config = nn::LayerNormConfig { eps: layer_norm_eps, ..Default::default() };
|
||||
let layer_norm = nn::layer_norm(&(p / "LayerNorm"), vec![config.embedding_size], layer_norm_config);
|
||||
let dense = nn::linear(&(p / "dense"), config.hidden_size, config.embedding_size, Default::default());
|
||||
let decoder = nn::linear(&(p / "decoder"), config.embedding_size, config.vocab_size, Default::default());
|
||||
let layer_norm_config = nn::LayerNormConfig {
|
||||
eps: layer_norm_eps,
|
||||
..Default::default()
|
||||
};
|
||||
let layer_norm = nn::layer_norm(
|
||||
&(p / "LayerNorm"),
|
||||
vec![config.embedding_size],
|
||||
layer_norm_config,
|
||||
);
|
||||
let dense = nn::linear(
|
||||
&(p / "dense"),
|
||||
config.hidden_size,
|
||||
config.embedding_size,
|
||||
Default::default(),
|
||||
);
|
||||
let decoder = nn::linear(
|
||||
&(p / "decoder"),
|
||||
config.embedding_size,
|
||||
config.vocab_size,
|
||||
Default::default(),
|
||||
);
|
||||
|
||||
let activation = Box::new(match &config.hidden_act {
|
||||
Activation::gelu_new => _gelu_new,
|
||||
Activation::gelu => _gelu,
|
||||
Activation::relu => _relu,
|
||||
Activation::mish => _mish
|
||||
Activation::mish => _mish,
|
||||
});
|
||||
|
||||
AlbertMLMHead { layer_norm, dense, decoder, activation }
|
||||
AlbertMLMHead {
|
||||
layer_norm,
|
||||
dense,
|
||||
decoder,
|
||||
activation,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn forward(&self, hidden_states: &Tensor) -> Tensor {
|
||||
@ -292,10 +357,10 @@ impl AlbertForMaskedLM {
|
||||
/// # Example
|
||||
///
|
||||
/// ```no_run
|
||||
/// use tch::{nn, Device};
|
||||
/// use rust_bert::albert::{AlbertConfig, AlbertForMaskedLM};
|
||||
/// use rust_bert::Config;
|
||||
/// use std::path::Path;
|
||||
/// use rust_bert::albert::{AlbertConfig, AlbertForMaskedLM};
|
||||
/// use tch::{nn, Device};
|
||||
///
|
||||
/// let config_path = Path::new("path/to/config.json");
|
||||
/// let device = Device::Cpu;
|
||||
@ -303,12 +368,14 @@ impl AlbertForMaskedLM {
|
||||
/// let config = AlbertConfig::from_file(config_path);
|
||||
/// let albert: AlbertForMaskedLM = AlbertForMaskedLM::new(&p.root(), &config);
|
||||
/// ```
|
||||
///
|
||||
pub fn new(p: &nn::Path, config: &AlbertConfig) -> AlbertForMaskedLM {
|
||||
let albert = AlbertModel::new(&(p / "albert"), config);
|
||||
let predictions = AlbertMLMHead::new(&(p / "predictions"), config);
|
||||
|
||||
AlbertForMaskedLM { albert, predictions }
|
||||
AlbertForMaskedLM {
|
||||
albert,
|
||||
predictions,
|
||||
}
|
||||
}
|
||||
|
||||
/// Forward pass through the model
|
||||
@ -345,28 +412,40 @@ impl AlbertForMaskedLM {
|
||||
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
|
||||
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
||||
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
||||
/// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true);
|
||||
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
|
||||
/// .expand(&[batch_size, sequence_length], true);
|
||||
///
|
||||
/// let (output, all_hidden_states, all_attentions) = no_grad(|| {
|
||||
/// albert_model
|
||||
/// .forward_t(Some(input_tensor),
|
||||
/// albert_model.forward_t(
|
||||
/// Some(input_tensor),
|
||||
/// Some(mask),
|
||||
/// Some(token_type_ids),
|
||||
/// Some(position_ids),
|
||||
/// None,
|
||||
/// false)
|
||||
/// false,
|
||||
/// )
|
||||
/// });
|
||||
///
|
||||
/// ```
|
||||
///
|
||||
pub fn forward_t(&self,
|
||||
pub fn forward_t(
|
||||
&self,
|
||||
input_ids: Option<Tensor>,
|
||||
mask: Option<Tensor>,
|
||||
token_type_ids: Option<Tensor>,
|
||||
position_ids: Option<Tensor>,
|
||||
input_embeds: Option<Tensor>,
|
||||
train: bool) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Vec<Tensor>>>) {
|
||||
let (hidden_state, _, all_hidden_states, all_attentions) = self.albert.forward_t(input_ids, mask, token_type_ids, position_ids, input_embeds, train).unwrap();
|
||||
train: bool,
|
||||
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Vec<Tensor>>>) {
|
||||
let (hidden_state, _, all_hidden_states, all_attentions) = self
|
||||
.albert
|
||||
.forward_t(
|
||||
input_ids,
|
||||
mask,
|
||||
token_type_ids,
|
||||
position_ids,
|
||||
input_embeds,
|
||||
train,
|
||||
)
|
||||
.unwrap();
|
||||
let prediction_scores = self.predictions.forward(&hidden_state);
|
||||
(prediction_scores, all_hidden_states, all_attentions)
|
||||
}
|
||||
@ -395,29 +474,42 @@ impl AlbertForSequenceClassification {
|
||||
/// # Example
|
||||
///
|
||||
/// ```no_run
|
||||
/// use tch::{nn, Device};
|
||||
/// use rust_bert::albert::{AlbertConfig, AlbertForSequenceClassification};
|
||||
/// use rust_bert::Config;
|
||||
/// use std::path::Path;
|
||||
/// use rust_bert::albert::{AlbertConfig, AlbertForSequenceClassification};
|
||||
/// use tch::{nn, Device};
|
||||
///
|
||||
/// let config_path = Path::new("path/to/config.json");
|
||||
/// let device = Device::Cpu;
|
||||
/// let p = nn::VarStore::new(device);
|
||||
/// let config = AlbertConfig::from_file(config_path);
|
||||
/// let albert: AlbertForSequenceClassification = AlbertForSequenceClassification::new(&p.root(), &config);
|
||||
/// let albert: AlbertForSequenceClassification =
|
||||
/// AlbertForSequenceClassification::new(&p.root(), &config);
|
||||
/// ```
|
||||
///
|
||||
pub fn new(p: &nn::Path, config: &AlbertConfig) -> AlbertForSequenceClassification {
|
||||
let albert = AlbertModel::new(&(p / "albert"), config);
|
||||
let classifier_dropout_prob = match config.classifier_dropout_prob {
|
||||
Some(value) => value,
|
||||
None => 0.1
|
||||
None => 0.1,
|
||||
};
|
||||
let dropout = Dropout::new(classifier_dropout_prob);
|
||||
let num_labels = config.id2label.as_ref().expect("num_labels not provided in configuration").len() as i64;
|
||||
let classifier = nn::linear(&(p / "classifier"), config.hidden_size, num_labels, Default::default());
|
||||
let num_labels = config
|
||||
.id2label
|
||||
.as_ref()
|
||||
.expect("num_labels not provided in configuration")
|
||||
.len() as i64;
|
||||
let classifier = nn::linear(
|
||||
&(p / "classifier"),
|
||||
config.hidden_size,
|
||||
num_labels,
|
||||
Default::default(),
|
||||
);
|
||||
|
||||
AlbertForSequenceClassification { albert, dropout, classifier }
|
||||
AlbertForSequenceClassification {
|
||||
albert,
|
||||
dropout,
|
||||
classifier,
|
||||
}
|
||||
}
|
||||
|
||||
/// Forward pass through the model
|
||||
@ -465,18 +557,30 @@ impl AlbertForSequenceClassification {
|
||||
/// None,
|
||||
/// false)
|
||||
/// });
|
||||
///
|
||||
/// ```
|
||||
///
|
||||
pub fn forward_t(&self,
|
||||
pub fn forward_t(
|
||||
&self,
|
||||
input_ids: Option<Tensor>,
|
||||
mask: Option<Tensor>,
|
||||
token_type_ids: Option<Tensor>,
|
||||
position_ids: Option<Tensor>,
|
||||
input_embeds: Option<Tensor>,
|
||||
train: bool) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Vec<Tensor>>>) {
|
||||
let (_, pooled_output, all_hidden_states, all_attentions) = self.albert.forward_t(input_ids, mask, token_type_ids, position_ids, input_embeds, train).unwrap();
|
||||
let logits = pooled_output.apply_t(&self.dropout, train).apply(&self.classifier);
|
||||
train: bool,
|
||||
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Vec<Tensor>>>) {
|
||||
let (_, pooled_output, all_hidden_states, all_attentions) = self
|
||||
.albert
|
||||
.forward_t(
|
||||
input_ids,
|
||||
mask,
|
||||
token_type_ids,
|
||||
position_ids,
|
||||
input_embeds,
|
||||
train,
|
||||
)
|
||||
.unwrap();
|
||||
let logits = pooled_output
|
||||
.apply_t(&self.dropout, train)
|
||||
.apply(&self.classifier);
|
||||
(logits, all_hidden_states, all_attentions)
|
||||
}
|
||||
}
|
||||
@ -505,25 +609,38 @@ impl AlbertForTokenClassification {
|
||||
/// # Example
|
||||
///
|
||||
/// ```no_run
|
||||
/// use tch::{nn, Device};
|
||||
/// use rust_bert::albert::{AlbertConfig, AlbertForTokenClassification};
|
||||
/// use rust_bert::Config;
|
||||
/// use std::path::Path;
|
||||
/// use rust_bert::albert::{AlbertConfig, AlbertForTokenClassification};
|
||||
/// use tch::{nn, Device};
|
||||
///
|
||||
/// let config_path = Path::new("path/to/config.json");
|
||||
/// let device = Device::Cpu;
|
||||
/// let p = nn::VarStore::new(device);
|
||||
/// let config = AlbertConfig::from_file(config_path);
|
||||
/// let albert: AlbertForTokenClassification = AlbertForTokenClassification::new(&p.root(), &config);
|
||||
/// let albert: AlbertForTokenClassification =
|
||||
/// AlbertForTokenClassification::new(&p.root(), &config);
|
||||
/// ```
|
||||
///
|
||||
pub fn new(p: &nn::Path, config: &AlbertConfig) -> AlbertForTokenClassification {
|
||||
let albert = AlbertModel::new(&(p / "albert"), config);
|
||||
let dropout = Dropout::new(config.hidden_dropout_prob);
|
||||
let num_labels = config.id2label.as_ref().expect("num_labels not provided in configuration").len() as i64;
|
||||
let classifier = nn::linear(&(p / "classifier"), config.hidden_size, num_labels, Default::default());
|
||||
let num_labels = config
|
||||
.id2label
|
||||
.as_ref()
|
||||
.expect("num_labels not provided in configuration")
|
||||
.len() as i64;
|
||||
let classifier = nn::linear(
|
||||
&(p / "classifier"),
|
||||
config.hidden_size,
|
||||
num_labels,
|
||||
Default::default(),
|
||||
);
|
||||
|
||||
AlbertForTokenClassification { albert, dropout, classifier }
|
||||
AlbertForTokenClassification {
|
||||
albert,
|
||||
dropout,
|
||||
classifier,
|
||||
}
|
||||
}
|
||||
|
||||
/// Forward pass through the model
|
||||
@ -571,18 +688,30 @@ impl AlbertForTokenClassification {
|
||||
/// None,
|
||||
/// false)
|
||||
/// });
|
||||
///
|
||||
/// ```
|
||||
///
|
||||
pub fn forward_t(&self,
|
||||
pub fn forward_t(
|
||||
&self,
|
||||
input_ids: Option<Tensor>,
|
||||
mask: Option<Tensor>,
|
||||
token_type_ids: Option<Tensor>,
|
||||
position_ids: Option<Tensor>,
|
||||
input_embeds: Option<Tensor>,
|
||||
train: bool) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Vec<Tensor>>>) {
|
||||
let (sequence_output, _, all_hidden_states, all_attentions) = self.albert.forward_t(input_ids, mask, token_type_ids, position_ids, input_embeds, train).unwrap();
|
||||
let logits = sequence_output.apply_t(&self.dropout, train).apply(&self.classifier);
|
||||
train: bool,
|
||||
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Vec<Tensor>>>) {
|
||||
let (sequence_output, _, all_hidden_states, all_attentions) = self
|
||||
.albert
|
||||
.forward_t(
|
||||
input_ids,
|
||||
mask,
|
||||
token_type_ids,
|
||||
position_ids,
|
||||
input_embeds,
|
||||
train,
|
||||
)
|
||||
.unwrap();
|
||||
let logits = sequence_output
|
||||
.apply_t(&self.dropout, train)
|
||||
.apply(&self.classifier);
|
||||
(logits, all_hidden_states, all_attentions)
|
||||
}
|
||||
}
|
||||
@ -610,10 +739,10 @@ impl AlbertForQuestionAnswering {
|
||||
/// # Example
|
||||
///
|
||||
/// ```no_run
|
||||
/// use tch::{nn, Device};
|
||||
/// use rust_bert::albert::{AlbertConfig, AlbertForQuestionAnswering};
|
||||
/// use rust_bert::Config;
|
||||
/// use std::path::Path;
|
||||
/// use rust_bert::albert::{AlbertConfig, AlbertForQuestionAnswering};
|
||||
/// use tch::{nn, Device};
|
||||
///
|
||||
/// let config_path = Path::new("path/to/config.json");
|
||||
/// let device = Device::Cpu;
|
||||
@ -621,11 +750,15 @@ impl AlbertForQuestionAnswering {
|
||||
/// let config = AlbertConfig::from_file(config_path);
|
||||
/// let albert: AlbertForQuestionAnswering = AlbertForQuestionAnswering::new(&p.root(), &config);
|
||||
/// ```
|
||||
///
|
||||
pub fn new(p: &nn::Path, config: &AlbertConfig) -> AlbertForQuestionAnswering {
|
||||
let albert = AlbertModel::new(&(p / "albert"), config);
|
||||
let num_labels = 2;
|
||||
let qa_outputs = nn::linear(&(p / "qa_outputs"), config.hidden_size, num_labels, Default::default());
|
||||
let qa_outputs = nn::linear(
|
||||
&(p / "qa_outputs"),
|
||||
config.hidden_size,
|
||||
num_labels,
|
||||
Default::default(),
|
||||
);
|
||||
|
||||
AlbertForQuestionAnswering { albert, qa_outputs }
|
||||
}
|
||||
@ -676,17 +809,32 @@ impl AlbertForQuestionAnswering {
|
||||
/// None,
|
||||
/// false)
|
||||
/// });
|
||||
///
|
||||
/// ```
|
||||
///
|
||||
pub fn forward_t(&self,
|
||||
pub fn forward_t(
|
||||
&self,
|
||||
input_ids: Option<Tensor>,
|
||||
mask: Option<Tensor>,
|
||||
token_type_ids: Option<Tensor>,
|
||||
position_ids: Option<Tensor>,
|
||||
input_embeds: Option<Tensor>,
|
||||
train: bool) -> (Tensor, Tensor, Option<Vec<Tensor>>, Option<Vec<Vec<Tensor>>>) {
|
||||
let (sequence_output, _, all_hidden_states, all_attentions) = self.albert.forward_t(input_ids, mask, token_type_ids, position_ids, input_embeds, train).unwrap();
|
||||
train: bool,
|
||||
) -> (
|
||||
Tensor,
|
||||
Tensor,
|
||||
Option<Vec<Tensor>>,
|
||||
Option<Vec<Vec<Tensor>>>,
|
||||
) {
|
||||
let (sequence_output, _, all_hidden_states, all_attentions) = self
|
||||
.albert
|
||||
.forward_t(
|
||||
input_ids,
|
||||
mask,
|
||||
token_type_ids,
|
||||
position_ids,
|
||||
input_embeds,
|
||||
train,
|
||||
)
|
||||
.unwrap();
|
||||
let logits = sequence_output.apply(&self.qa_outputs).split(1, -1);
|
||||
let (start_logits, end_logits) = (&logits[0], &logits[1]);
|
||||
let start_logits = start_logits.squeeze1(-1);
|
||||
@ -721,10 +869,10 @@ impl AlbertForMultipleChoice {
|
||||
/// # Example
|
||||
///
|
||||
/// ```no_run
|
||||
/// use tch::{nn, Device};
|
||||
/// use rust_bert::albert::{AlbertConfig, AlbertForMultipleChoice};
|
||||
/// use rust_bert::Config;
|
||||
/// use std::path::Path;
|
||||
/// use rust_bert::albert::{AlbertConfig, AlbertForMultipleChoice};
|
||||
/// use tch::{nn, Device};
|
||||
///
|
||||
/// let config_path = Path::new("path/to/config.json");
|
||||
/// let device = Device::Cpu;
|
||||
@ -732,14 +880,22 @@ impl AlbertForMultipleChoice {
|
||||
/// let config = AlbertConfig::from_file(config_path);
|
||||
/// let albert: AlbertForMultipleChoice = AlbertForMultipleChoice::new(&p.root(), &config);
|
||||
/// ```
|
||||
///
|
||||
pub fn new(p: &nn::Path, config: &AlbertConfig) -> AlbertForMultipleChoice {
|
||||
let albert = AlbertModel::new(&(p / "albert"), config);
|
||||
let dropout = Dropout::new(config.hidden_dropout_prob);
|
||||
let num_labels = 1;
|
||||
let classifier = nn::linear(&(p / "classifier"), config.hidden_size, num_labels, Default::default());
|
||||
let classifier = nn::linear(
|
||||
&(p / "classifier"),
|
||||
config.hidden_size,
|
||||
num_labels,
|
||||
Default::default(),
|
||||
);
|
||||
|
||||
AlbertForMultipleChoice { albert, dropout, classifier }
|
||||
AlbertForMultipleChoice {
|
||||
albert,
|
||||
dropout,
|
||||
classifier,
|
||||
}
|
||||
}
|
||||
|
||||
/// Forward pass through the model
|
||||
@ -787,43 +943,67 @@ impl AlbertForMultipleChoice {
|
||||
/// None,
|
||||
/// false).unwrap()
|
||||
/// });
|
||||
///
|
||||
/// ```
|
||||
///
|
||||
pub fn forward_t(&self,
|
||||
pub fn forward_t(
|
||||
&self,
|
||||
input_ids: Option<Tensor>,
|
||||
mask: Option<Tensor>,
|
||||
token_type_ids: Option<Tensor>,
|
||||
position_ids: Option<Tensor>,
|
||||
input_embeds: Option<Tensor>,
|
||||
train: bool) -> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Vec<Tensor>>>), &'static str> {
|
||||
train: bool,
|
||||
) -> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Vec<Tensor>>>), &'static str> {
|
||||
let (input_ids, input_embeds, num_choices) = match &input_ids {
|
||||
Some(input_value) => match &input_embeds {
|
||||
Some(_) => { return Err("Only one of input ids or input embeddings may be set"); }
|
||||
None => (Some(input_value.view((-1, *input_value.size().last().unwrap()))), None, input_value.size()[1])
|
||||
Some(_) => {
|
||||
return Err("Only one of input ids or input embeddings may be set");
|
||||
}
|
||||
None => (
|
||||
Some(input_value.view((-1, *input_value.size().last().unwrap()))),
|
||||
None,
|
||||
input_value.size()[1],
|
||||
),
|
||||
},
|
||||
None => match &input_embeds {
|
||||
Some(embeds) => (None, Some(embeds.view((-1, embeds.size()[1], embeds.size()[2]))), embeds.size()[1]),
|
||||
None => { return Err("At least one of input ids or input embeddings must be set"); }
|
||||
Some(embeds) => (
|
||||
None,
|
||||
Some(embeds.view((-1, embeds.size()[1], embeds.size()[2]))),
|
||||
embeds.size()[1],
|
||||
),
|
||||
None => {
|
||||
return Err("At least one of input ids or input embeddings must be set");
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
let mask = match mask {
|
||||
Some(value) => Some(value.view((-1, *value.size().last().unwrap()))),
|
||||
None => None
|
||||
None => None,
|
||||
};
|
||||
let token_type_ids = match token_type_ids {
|
||||
Some(value) => Some(value.view((-1, *value.size().last().unwrap()))),
|
||||
None => None
|
||||
None => None,
|
||||
};
|
||||
let position_ids = match position_ids {
|
||||
Some(value) => Some(value.view((-1, *value.size().last().unwrap()))),
|
||||
None => None
|
||||
None => None,
|
||||
};
|
||||
|
||||
|
||||
let (_, pooled_output, all_hidden_states, all_attentions) = self.albert.forward_t(input_ids, mask, token_type_ids, position_ids, input_embeds, train).unwrap();
|
||||
let logits = pooled_output.apply_t(&self.dropout, train).apply(&self.classifier).view((-1, num_choices));
|
||||
let (_, pooled_output, all_hidden_states, all_attentions) = self
|
||||
.albert
|
||||
.forward_t(
|
||||
input_ids,
|
||||
mask,
|
||||
token_type_ids,
|
||||
position_ids,
|
||||
input_embeds,
|
||||
train,
|
||||
)
|
||||
.unwrap();
|
||||
let logits = pooled_output
|
||||
.apply_t(&self.dropout, train)
|
||||
.apply(&self.classifier)
|
||||
.view((-1, num_choices));
|
||||
|
||||
Ok((logits, all_hidden_states, all_attentions))
|
||||
}
|
||||
|
@ -11,10 +11,10 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use crate::common::dropout::Dropout;
|
||||
use tch::{nn, Tensor};
|
||||
use crate::albert::AlbertConfig;
|
||||
use crate::common::dropout::Dropout;
|
||||
use tch::kind::Kind::Float;
|
||||
use tch::{nn, Tensor};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct AlbertSelfAttention {
|
||||
@ -32,24 +32,55 @@ pub struct AlbertSelfAttention {
|
||||
|
||||
impl AlbertSelfAttention {
|
||||
pub fn new(p: nn::Path, config: &AlbertConfig) -> AlbertSelfAttention {
|
||||
assert_eq!(config.hidden_size % config.num_attention_heads, 0, "Hidden size not a multiple of the number of attention heads");
|
||||
assert_eq!(
|
||||
config.hidden_size % config.num_attention_heads,
|
||||
0,
|
||||
"Hidden size not a multiple of the number of attention heads"
|
||||
);
|
||||
|
||||
let query = nn::linear(&p / "query", config.hidden_size, config.hidden_size, Default::default());
|
||||
let key = nn::linear(&p / "key", config.hidden_size, config.hidden_size, Default::default());
|
||||
let value = nn::linear(&p / "value", config.hidden_size, config.hidden_size, Default::default());
|
||||
let dense = nn::linear(&p / "dense", config.hidden_size, config.hidden_size, Default::default());
|
||||
let query = nn::linear(
|
||||
&p / "query",
|
||||
config.hidden_size,
|
||||
config.hidden_size,
|
||||
Default::default(),
|
||||
);
|
||||
let key = nn::linear(
|
||||
&p / "key",
|
||||
config.hidden_size,
|
||||
config.hidden_size,
|
||||
Default::default(),
|
||||
);
|
||||
let value = nn::linear(
|
||||
&p / "value",
|
||||
config.hidden_size,
|
||||
config.hidden_size,
|
||||
Default::default(),
|
||||
);
|
||||
let dense = nn::linear(
|
||||
&p / "dense",
|
||||
config.hidden_size,
|
||||
config.hidden_size,
|
||||
Default::default(),
|
||||
);
|
||||
let dropout = Dropout::new(config.attention_probs_dropout_prob);
|
||||
let attention_head_size = config.hidden_size / config.num_attention_heads;
|
||||
let output_attentions = match config.output_attentions {
|
||||
Some(value) => value,
|
||||
None => false
|
||||
None => false,
|
||||
};
|
||||
let layer_norm_eps = match config.layer_norm_eps {
|
||||
Some(value) => value,
|
||||
None => 1e-12
|
||||
None => 1e-12,
|
||||
};
|
||||
let layer_norm_config = nn::LayerNormConfig { eps: layer_norm_eps, ..Default::default() };
|
||||
let layer_norm = nn::layer_norm(&p / "LayerNorm", vec![config.hidden_size], layer_norm_config);
|
||||
let layer_norm_config = nn::LayerNormConfig {
|
||||
eps: layer_norm_eps,
|
||||
..Default::default()
|
||||
};
|
||||
let layer_norm = nn::layer_norm(
|
||||
&p / "LayerNorm",
|
||||
vec![config.hidden_size],
|
||||
layer_norm_config,
|
||||
);
|
||||
|
||||
AlbertSelfAttention {
|
||||
num_attention_heads: config.num_attention_heads,
|
||||
@ -66,19 +97,23 @@ impl AlbertSelfAttention {
|
||||
}
|
||||
|
||||
fn split_heads(&self, x: Tensor, bs: i64, dim_per_head: i64) -> Tensor {
|
||||
x.view((bs, -1, self.num_attention_heads, dim_per_head)).transpose(1, 2)
|
||||
x.view((bs, -1, self.num_attention_heads, dim_per_head))
|
||||
.transpose(1, 2)
|
||||
}
|
||||
|
||||
pub fn forward_t(&self,
|
||||
pub fn forward_t(
|
||||
&self,
|
||||
input_ids: &Tensor,
|
||||
mask: &Option<Tensor>,
|
||||
train: bool) -> (Tensor, Option<Tensor>) {
|
||||
|
||||
train: bool,
|
||||
) -> (Tensor, Option<Tensor>) {
|
||||
let bs = *input_ids.size().first().unwrap();
|
||||
|
||||
let key_layer = self.split_heads(input_ids.apply(&self.key), bs, self.attention_head_size);
|
||||
let value_layer = self.split_heads(input_ids.apply(&self.value), bs, self.attention_head_size);
|
||||
let query_layer = self.split_heads(input_ids.apply(&self.query), bs, self.attention_head_size);
|
||||
let value_layer =
|
||||
self.split_heads(input_ids.apply(&self.value), bs, self.attention_head_size);
|
||||
let query_layer =
|
||||
self.split_heads(input_ids.apply(&self.query), bs, self.attention_head_size);
|
||||
|
||||
let query_layer: Tensor = query_layer / (self.attention_head_size as f64).sqrt();
|
||||
|
||||
@ -91,9 +126,11 @@ impl AlbertSelfAttention {
|
||||
let weights = scores.softmax(-1, Float).apply_t(&self.dropout, train);
|
||||
let context = weights.matmul(&value_layer).transpose(1, 2).contiguous();
|
||||
|
||||
let w = self.dense.ws
|
||||
.transpose(0, 1)
|
||||
.view((self.num_attention_heads, self.attention_head_size, self.hidden_size));
|
||||
let w = self.dense.ws.transpose(0, 1).view((
|
||||
self.num_attention_heads,
|
||||
self.attention_head_size,
|
||||
self.hidden_size,
|
||||
));
|
||||
|
||||
let context: Tensor = Tensor::einsum("bfnd,ndh->bfh", &[context, w]) + &self.dense.bs;
|
||||
let context = (input_ids + context.apply_t(&self.dropout, train)).apply(&self.layer_norm);
|
||||
|
@ -11,10 +11,10 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use tch::{nn, Tensor, Kind};
|
||||
use crate::common::dropout::Dropout;
|
||||
use crate::albert::AlbertConfig;
|
||||
use tch::nn::{EmbeddingConfig, embedding};
|
||||
use crate::common::dropout::Dropout;
|
||||
use tch::nn::{embedding, EmbeddingConfig};
|
||||
use tch::{nn, Kind, Tensor};
|
||||
|
||||
/// # Embeddings implementation for Albert model
|
||||
#[derive(Debug)]
|
||||
@ -34,49 +34,77 @@ impl AlbertEmbeddings {
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let word_embeddings: nn::Embedding = embedding(p / "word_embeddings",
|
||||
let word_embeddings: nn::Embedding = embedding(
|
||||
p / "word_embeddings",
|
||||
config.vocab_size,
|
||||
config.embedding_size,
|
||||
embedding_config);
|
||||
embedding_config,
|
||||
);
|
||||
|
||||
let position_embeddings: nn::Embedding = embedding(p / "position_embeddings",
|
||||
let position_embeddings: nn::Embedding = embedding(
|
||||
p / "position_embeddings",
|
||||
config.max_position_embeddings,
|
||||
config.embedding_size,
|
||||
Default::default());
|
||||
Default::default(),
|
||||
);
|
||||
|
||||
let token_type_embeddings: nn::Embedding = embedding(p / "token_type_embeddings",
|
||||
let token_type_embeddings: nn::Embedding = embedding(
|
||||
p / "token_type_embeddings",
|
||||
config.type_vocab_size,
|
||||
config.embedding_size,
|
||||
Default::default());
|
||||
Default::default(),
|
||||
);
|
||||
|
||||
let layer_norm_eps = match config.layer_norm_eps {
|
||||
Some(value) => value,
|
||||
None => 1e-12
|
||||
None => 1e-12,
|
||||
};
|
||||
let layer_norm_config = nn::LayerNormConfig { eps: layer_norm_eps, ..Default::default() };
|
||||
let layer_norm: nn::LayerNorm = nn::layer_norm(p / "LayerNorm", vec![config.embedding_size], layer_norm_config);
|
||||
let layer_norm_config = nn::LayerNormConfig {
|
||||
eps: layer_norm_eps,
|
||||
..Default::default()
|
||||
};
|
||||
let layer_norm: nn::LayerNorm = nn::layer_norm(
|
||||
p / "LayerNorm",
|
||||
vec![config.embedding_size],
|
||||
layer_norm_config,
|
||||
);
|
||||
let dropout: Dropout = Dropout::new(config.hidden_dropout_prob);
|
||||
AlbertEmbeddings { word_embeddings, position_embeddings, token_type_embeddings, layer_norm, dropout}
|
||||
AlbertEmbeddings {
|
||||
word_embeddings,
|
||||
position_embeddings,
|
||||
token_type_embeddings,
|
||||
layer_norm,
|
||||
dropout,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn forward_t(&self,
|
||||
pub fn forward_t(
|
||||
&self,
|
||||
input_ids: Option<Tensor>,
|
||||
token_type_ids: Option<Tensor>,
|
||||
position_ids: Option<Tensor>,
|
||||
input_embeds: Option<Tensor>,
|
||||
train: bool) -> Result<Tensor, &'static str> {
|
||||
train: bool,
|
||||
) -> Result<Tensor, &'static str> {
|
||||
let (input_embeddings, input_shape) = match input_ids {
|
||||
Some(input_value) => match input_embeds {
|
||||
Some(_) => { return Err("Only one of input ids or input embeddings may be set"); }
|
||||
None => (input_value.apply_t(&self.word_embeddings, train), input_value.size())
|
||||
Some(_) => {
|
||||
return Err("Only one of input ids or input embeddings may be set");
|
||||
}
|
||||
None => (
|
||||
input_value.apply_t(&self.word_embeddings, train),
|
||||
input_value.size(),
|
||||
),
|
||||
},
|
||||
None => match input_embeds {
|
||||
Some(embeds) => {
|
||||
let size = vec!(embeds.size()[0], embeds.size()[1]);
|
||||
let size = vec![embeds.size()[0], embeds.size()[1]];
|
||||
(embeds, size)
|
||||
},
|
||||
None => { return Err("Only one of input ids or input embeddings may be set"); }
|
||||
}
|
||||
None => {
|
||||
return Err("Only one of input ids or input embeddings may be set");
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
let seq_length = input_embeddings.as_ref().size()[1].to_owned();
|
||||
@ -84,19 +112,22 @@ impl AlbertEmbeddings {
|
||||
let position_ids = match position_ids {
|
||||
Some(value) => value,
|
||||
None => Tensor::arange(seq_length, (Kind::Int64, input_embeddings.device()))
|
||||
.unsqueeze(0).
|
||||
expand(&input_shape, true)
|
||||
.unsqueeze(0)
|
||||
.expand(&input_shape, true),
|
||||
};
|
||||
|
||||
let token_type_ids = match token_type_ids {
|
||||
Some(value) => value,
|
||||
None => Tensor::zeros(&input_shape, (Kind::Int64, input_embeddings.device()))
|
||||
None => Tensor::zeros(&input_shape, (Kind::Int64, input_embeddings.device())),
|
||||
};
|
||||
|
||||
let position_embeddings = position_ids.apply(&self.position_embeddings);
|
||||
let token_type_embeddings = token_type_ids.apply(&self.token_type_embeddings);
|
||||
|
||||
let input_embeddings: Tensor = input_embeddings + position_embeddings + token_type_embeddings;
|
||||
Ok(input_embeddings.apply(&self.layer_norm).apply_t(&self.dropout, train))
|
||||
let input_embeddings: Tensor =
|
||||
input_embeddings + position_embeddings + token_type_embeddings;
|
||||
Ok(input_embeddings
|
||||
.apply(&self.layer_norm)
|
||||
.apply_t(&self.dropout, train))
|
||||
}
|
||||
}
|
@ -11,12 +11,12 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use crate::albert::attention::AlbertSelfAttention;
|
||||
use tch::{nn, Tensor};
|
||||
use crate::albert::AlbertConfig;
|
||||
use crate::albert::albert::Activation;
|
||||
use crate::common::activations::{_gelu_new, _gelu, _relu, _mish};
|
||||
use crate::albert::attention::AlbertSelfAttention;
|
||||
use crate::albert::AlbertConfig;
|
||||
use crate::common::activations::{_gelu, _gelu_new, _mish, _relu};
|
||||
use std::borrow::BorrowMut;
|
||||
use tch::{nn, Tensor};
|
||||
|
||||
pub struct AlbertLayer {
|
||||
attention: AlbertSelfAttention,
|
||||
@ -32,29 +32,55 @@ impl AlbertLayer {
|
||||
|
||||
let layer_norm_eps = match config.layer_norm_eps {
|
||||
Some(value) => value,
|
||||
None => 1e-12
|
||||
None => 1e-12,
|
||||
};
|
||||
let layer_norm_config = nn::LayerNormConfig { eps: layer_norm_eps, ..Default::default() };
|
||||
let full_layer_layer_norm = nn::layer_norm(&(p / "full_layer_layer_norm"), vec![config.hidden_size], layer_norm_config);
|
||||
let layer_norm_config = nn::LayerNormConfig {
|
||||
eps: layer_norm_eps,
|
||||
..Default::default()
|
||||
};
|
||||
let full_layer_layer_norm = nn::layer_norm(
|
||||
&(p / "full_layer_layer_norm"),
|
||||
vec![config.hidden_size],
|
||||
layer_norm_config,
|
||||
);
|
||||
|
||||
let ffn = nn::linear(&(p / "ffn"), config.hidden_size, config.intermediate_size, Default::default());
|
||||
let ffn_output = nn::linear(&(p / "ffn_output"), config.intermediate_size, config.hidden_size, Default::default());
|
||||
let ffn = nn::linear(
|
||||
&(p / "ffn"),
|
||||
config.hidden_size,
|
||||
config.intermediate_size,
|
||||
Default::default(),
|
||||
);
|
||||
let ffn_output = nn::linear(
|
||||
&(p / "ffn_output"),
|
||||
config.intermediate_size,
|
||||
config.hidden_size,
|
||||
Default::default(),
|
||||
);
|
||||
|
||||
let activation = Box::new(match &config.hidden_act {
|
||||
Activation::gelu_new => _gelu_new,
|
||||
Activation::gelu => _gelu,
|
||||
Activation::relu => _relu,
|
||||
Activation::mish => _mish
|
||||
Activation::mish => _mish,
|
||||
});
|
||||
|
||||
AlbertLayer { attention, full_layer_layer_norm, ffn, ffn_output, activation }
|
||||
AlbertLayer {
|
||||
attention,
|
||||
full_layer_layer_norm,
|
||||
ffn,
|
||||
ffn_output,
|
||||
activation,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn forward_t(&self,
|
||||
pub fn forward_t(
|
||||
&self,
|
||||
hidden_states: &Tensor,
|
||||
mask: &Option<Tensor>,
|
||||
train: bool) -> (Tensor, Option<Tensor>) {
|
||||
let (attention_output, attention_weights) = self.attention.forward_t(hidden_states, mask, train);
|
||||
train: bool,
|
||||
) -> (Tensor, Option<Tensor>) {
|
||||
let (attention_output, attention_weights) =
|
||||
self.attention.forward_t(hidden_states, mask, train);
|
||||
let ffn_output = attention_output.apply(&self.ffn);
|
||||
let ffn_output: Tensor = (self.activation)(&ffn_output);
|
||||
let ffn_output = ffn_output.apply(&self.ffn_output);
|
||||
@ -76,29 +102,42 @@ impl AlbertLayerGroup {
|
||||
|
||||
let output_attentions = match config.output_attentions {
|
||||
Some(value) => value,
|
||||
None => false
|
||||
None => false,
|
||||
};
|
||||
|
||||
let output_hidden_states = match config.output_hidden_states {
|
||||
Some(value) => value,
|
||||
None => false
|
||||
None => false,
|
||||
};
|
||||
|
||||
let mut layers: Vec<AlbertLayer> = vec!();
|
||||
let mut layers: Vec<AlbertLayer> = vec![];
|
||||
for layer_index in 0..config.inner_group_num {
|
||||
layers.push(AlbertLayer::new(&(p / layer_index), config));
|
||||
};
|
||||
|
||||
AlbertLayerGroup { output_hidden_states, output_attentions, layers }
|
||||
}
|
||||
|
||||
pub fn forward_t(&self,
|
||||
AlbertLayerGroup {
|
||||
output_hidden_states,
|
||||
output_attentions,
|
||||
layers,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn forward_t(
|
||||
&self,
|
||||
hidden_states: &Tensor,
|
||||
mask: &Option<Tensor>,
|
||||
train: bool)
|
||||
-> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
|
||||
let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states { Some(vec!()) } else { None };
|
||||
let mut all_attentions: Option<Vec<Tensor>> = if self.output_attentions { Some(vec!()) } else { None };
|
||||
train: bool,
|
||||
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
|
||||
let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states {
|
||||
Some(vec![])
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let mut all_attentions: Option<Vec<Tensor>> = if self.output_attentions {
|
||||
Some(vec![])
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let mut hidden_state = hidden_states.copy();
|
||||
let mut attention_weights: Option<Tensor>;
|
||||
@ -117,9 +156,9 @@ impl AlbertLayerGroup {
|
||||
attentions.push(attention_weights.as_ref().unwrap().copy());
|
||||
};
|
||||
}
|
||||
None => break
|
||||
};
|
||||
None => break,
|
||||
};
|
||||
}
|
||||
|
||||
(hidden_state, all_hidden_states, all_attentions)
|
||||
}
|
||||
@ -140,20 +179,25 @@ impl AlbertTransformer {
|
||||
|
||||
let output_attentions = match config.output_attentions {
|
||||
Some(value) => value,
|
||||
None => false
|
||||
None => false,
|
||||
};
|
||||
|
||||
let output_hidden_states = match config.output_hidden_states {
|
||||
Some(value) => value,
|
||||
None => false
|
||||
None => false,
|
||||
};
|
||||
|
||||
let embedding_hidden_mapping_in = nn::linear(&(p / "embedding_hidden_mapping_in"), config.embedding_size, config.hidden_size, Default::default());
|
||||
let embedding_hidden_mapping_in = nn::linear(
|
||||
&(p / "embedding_hidden_mapping_in"),
|
||||
config.embedding_size,
|
||||
config.hidden_size,
|
||||
Default::default(),
|
||||
);
|
||||
|
||||
let mut layers: Vec<AlbertLayerGroup> = vec!();
|
||||
let mut layers: Vec<AlbertLayerGroup> = vec![];
|
||||
for layer_index in 0..config.inner_group_num {
|
||||
layers.push(AlbertLayerGroup::new(&(p_layers / layer_index), config));
|
||||
};
|
||||
}
|
||||
|
||||
AlbertTransformer {
|
||||
output_hidden_states,
|
||||
@ -165,16 +209,24 @@ impl AlbertTransformer {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn forward_t(&self,
|
||||
pub fn forward_t(
|
||||
&self,
|
||||
hidden_states: &Tensor,
|
||||
mask: Option<Tensor>,
|
||||
train: bool)
|
||||
-> (Tensor, Option<Vec<Tensor>>, Option<Vec<Vec<Tensor>>>) {
|
||||
train: bool,
|
||||
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Vec<Tensor>>>) {
|
||||
let mut hidden_state = hidden_states.apply(&self.embedding_hidden_mapping_in);
|
||||
|
||||
let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states { Some(vec!()) } else { None };
|
||||
let mut all_attentions: Option<Vec<Vec<Tensor>>> = if self.output_attentions { Some(vec!()) } else { None };
|
||||
|
||||
let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states {
|
||||
Some(vec![])
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let mut all_attentions: Option<Vec<Vec<Tensor>>> = if self.output_attentions {
|
||||
Some(vec![])
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
for i in 0..self.num_hidden_layers {
|
||||
let group_idx = i / (self.num_hidden_layers / self.num_hidden_groups);
|
||||
@ -190,9 +242,8 @@ impl AlbertTransformer {
|
||||
if let Some(attentions) = all_attentions.borrow_mut() {
|
||||
attentions.push(attention_weights.unwrap());
|
||||
};
|
||||
};
|
||||
}
|
||||
|
||||
(hidden_state, all_hidden_states, all_attentions)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -25,19 +25,26 @@
|
||||
//! use rust_tokenizers::AlbertTokenizer;
|
||||
//! use tch::{nn, Device};
|
||||
//! # use std::path::PathBuf;
|
||||
//! use rust_bert::albert::{AlbertForMaskedLM, AlbertConfig};
|
||||
//! use rust_bert::albert::{AlbertConfig, AlbertForMaskedLM};
|
||||
//! use rust_bert::resources::{download_resource, LocalResource, Resource};
|
||||
//! use rust_bert::Config;
|
||||
//! use rust_bert::resources::{Resource, download_resource, LocalResource};
|
||||
//!
|
||||
//! let config_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/config.json")});
|
||||
//! let vocab_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/vocab.txt")});
|
||||
//! let weights_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/model.ot")});
|
||||
//! let config_resource = Resource::Local(LocalResource {
|
||||
//! local_path: PathBuf::from("path/to/config.json"),
|
||||
//! });
|
||||
//! let vocab_resource = Resource::Local(LocalResource {
|
||||
//! local_path: PathBuf::from("path/to/vocab.txt"),
|
||||
//! });
|
||||
//! let weights_resource = Resource::Local(LocalResource {
|
||||
//! local_path: PathBuf::from("path/to/model.ot"),
|
||||
//! });
|
||||
//! let config_path = download_resource(&config_resource)?;
|
||||
//! let vocab_path = download_resource(&vocab_resource)?;
|
||||
//! let weights_path = download_resource(&weights_resource)?;
|
||||
//! let device = Device::cuda_if_available();
|
||||
//! let mut vs = nn::VarStore::new(device);
|
||||
//! let tokenizer: AlbertTokenizer = AlbertTokenizer::from_file(vocab_path.to_str().unwrap(), true, true);
|
||||
//! let tokenizer: AlbertTokenizer =
|
||||
//! AlbertTokenizer::from_file(vocab_path.to_str().unwrap(), true, true);
|
||||
//! let config = AlbertConfig::from_file(config_path);
|
||||
//! let bert_model = AlbertForMaskedLM::new(&vs.root(), &config);
|
||||
//! vs.load(weights_path)?;
|
||||
@ -46,11 +53,13 @@
|
||||
//! # }
|
||||
//! ```
|
||||
|
||||
|
||||
|
||||
mod encoder;
|
||||
mod albert;
|
||||
mod attention;
|
||||
mod embeddings;
|
||||
mod albert;
|
||||
mod encoder;
|
||||
|
||||
pub use albert::{AlbertConfig, AlbertModelResources, AlbertConfigResources, AlbertVocabResources, AlbertModel, AlbertForMaskedLM, AlbertForSequenceClassification, AlbertForTokenClassification, AlbertForQuestionAnswering, AlbertForMultipleChoice};
|
||||
pub use albert::{
|
||||
AlbertConfig, AlbertConfigResources, AlbertForMaskedLM, AlbertForMultipleChoice,
|
||||
AlbertForQuestionAnswering, AlbertForSequenceClassification, AlbertForTokenClassification,
|
||||
AlbertModel, AlbertModelResources, AlbertVocabResources,
|
||||
};
|
||||
|
@ -12,8 +12,8 @@
|
||||
// limitations under the License.
|
||||
|
||||
use crate::common::dropout::Dropout;
|
||||
use tch::{nn, Tensor};
|
||||
use tch::kind::Kind::Float;
|
||||
use tch::{nn, Tensor};
|
||||
|
||||
#[derive(Debug)]
|
||||
/// # Cache for BART attention layers
|
||||
@ -31,7 +31,7 @@ impl Clone for LayerState {
|
||||
fn clone(&self) -> Self {
|
||||
let prev_key_padding_mask = match &self.prev_key_padding_mask {
|
||||
Some(key_padding_mask) => Some(key_padding_mask.copy()),
|
||||
None => None
|
||||
None => None,
|
||||
};
|
||||
LayerState {
|
||||
prev_key: self.prev_key.copy(),
|
||||
@ -46,12 +46,16 @@ impl LayerState {
|
||||
self.prev_key = self.prev_key.index_select(0, new_indices);
|
||||
self.prev_value = self.prev_value.index_select(0, new_indices);
|
||||
if self.prev_key_padding_mask.is_some() {
|
||||
self.prev_key_padding_mask = Some(self.prev_key_padding_mask.as_ref().unwrap().index_select(0, new_indices));
|
||||
self.prev_key_padding_mask = Some(
|
||||
self.prev_key_padding_mask
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.index_select(0, new_indices),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct SelfAttention {
|
||||
num_heads: i64,
|
||||
@ -68,8 +72,15 @@ pub struct SelfAttention {
|
||||
}
|
||||
|
||||
impl SelfAttention {
|
||||
pub fn new(p: nn::Path, embed_dim: i64, num_heads: i64, dropout: f64,
|
||||
encoder_decoder_attention: bool, store_cache: bool, output_attentions: bool) -> SelfAttention {
|
||||
pub fn new(
|
||||
p: nn::Path,
|
||||
embed_dim: i64,
|
||||
num_heads: i64,
|
||||
dropout: f64,
|
||||
encoder_decoder_attention: bool,
|
||||
store_cache: bool,
|
||||
output_attentions: bool,
|
||||
) -> SelfAttention {
|
||||
let k_proj = nn::linear(&p / "k_proj", embed_dim, embed_dim, Default::default());
|
||||
let v_proj = nn::linear(&p / "v_proj", embed_dim, embed_dim, Default::default());
|
||||
let q_proj = nn::linear(&p / "q_proj", embed_dim, embed_dim, Default::default());
|
||||
@ -95,58 +106,90 @@ impl SelfAttention {
|
||||
}
|
||||
|
||||
fn flatten(&self, x: Tensor, dim_0: i64, bs: i64) -> Tensor {
|
||||
x.contiguous().view((dim_0, bs * self.num_heads, self.head_dim)).transpose(0, 1)
|
||||
x.contiguous()
|
||||
.view((dim_0, bs * self.num_heads, self.head_dim))
|
||||
.transpose(0, 1)
|
||||
}
|
||||
|
||||
pub fn forward_t(&self, query: &Tensor,
|
||||
pub fn forward_t(
|
||||
&self,
|
||||
query: &Tensor,
|
||||
key: Option<&Tensor>,
|
||||
key_padding_mask: Option<&Tensor>,
|
||||
attention_mask: Option<&Tensor>,
|
||||
mut layer_state: Option<LayerState>,
|
||||
train: bool) -> (Tensor, Option<Tensor>, Option<LayerState>) {
|
||||
train: bool,
|
||||
) -> (Tensor, Option<Tensor>, Option<LayerState>) {
|
||||
let query_size = query.size();
|
||||
let (target_sequence_length, bs) = (query_size[0], query_size[1]);
|
||||
let q: Tensor = self.flatten(query.as_ref().apply(&self.q_proj) * self.scaling, target_sequence_length, bs);
|
||||
let q: Tensor = self.flatten(
|
||||
query.as_ref().apply(&self.q_proj) * self.scaling,
|
||||
target_sequence_length,
|
||||
bs,
|
||||
);
|
||||
let key = match &layer_state {
|
||||
Some(_) => { if self.encoder_decoder_attention { None } else { key } }
|
||||
None => key
|
||||
Some(_) => {
|
||||
if self.encoder_decoder_attention {
|
||||
None
|
||||
} else {
|
||||
key
|
||||
}
|
||||
}
|
||||
None => key,
|
||||
};
|
||||
|
||||
let (k, v) = if self.encoder_decoder_attention {
|
||||
match key {
|
||||
Some(key) => {
|
||||
(Some(self.flatten(key.apply(&self.k_proj), -1, bs)),
|
||||
Some(self.flatten(key.apply(&self.v_proj), -1, bs))
|
||||
)
|
||||
}
|
||||
None => (None, None)
|
||||
Some(key) => (
|
||||
Some(self.flatten(key.apply(&self.k_proj), -1, bs)),
|
||||
Some(self.flatten(key.apply(&self.v_proj), -1, bs)),
|
||||
),
|
||||
None => (None, None),
|
||||
}
|
||||
} else {
|
||||
(Some(self.flatten(query.apply(&self.k_proj), -1, bs)),
|
||||
Some(self.flatten(query.apply(&self.v_proj), -1, bs))
|
||||
(
|
||||
Some(self.flatten(query.apply(&self.k_proj), -1, bs)),
|
||||
Some(self.flatten(query.apply(&self.v_proj), -1, bs)),
|
||||
)
|
||||
};
|
||||
|
||||
let (k, v, key_padding_mask) = self.use_saved_state(&layer_state, k, v, key_padding_mask, bs);
|
||||
let (k, v, key_padding_mask) =
|
||||
self.use_saved_state(&layer_state, k, v, key_padding_mask, bs);
|
||||
|
||||
let source_sequence_length = k.size()[1];
|
||||
let attention_weights = q.bmm(&k.transpose(1, 2));
|
||||
let attention_weights = match attention_mask {
|
||||
Some(mask) => {
|
||||
let attention_weights = attention_weights.view((bs, self.num_heads, target_sequence_length, source_sequence_length)) + mask;
|
||||
attention_weights.view((bs * self.num_heads, target_sequence_length, source_sequence_length))
|
||||
let attention_weights = attention_weights.view((
|
||||
bs,
|
||||
self.num_heads,
|
||||
target_sequence_length,
|
||||
source_sequence_length,
|
||||
)) + mask;
|
||||
attention_weights.view((
|
||||
bs * self.num_heads,
|
||||
target_sequence_length,
|
||||
source_sequence_length,
|
||||
))
|
||||
}
|
||||
None => attention_weights
|
||||
None => attention_weights,
|
||||
};
|
||||
|
||||
let attention_weights = match key_padding_mask.as_ref() {
|
||||
Some(mask) => {
|
||||
attention_weights
|
||||
.view((bs, self.num_heads, target_sequence_length, source_sequence_length))
|
||||
Some(mask) => attention_weights
|
||||
.view((
|
||||
bs,
|
||||
self.num_heads,
|
||||
target_sequence_length,
|
||||
source_sequence_length,
|
||||
))
|
||||
.masked_fill(&mask.unsqueeze(1).unsqueeze(2), std::f64::NEG_INFINITY)
|
||||
.view((bs * self.num_heads, target_sequence_length, source_sequence_length))
|
||||
}
|
||||
None => attention_weights
|
||||
.view((
|
||||
bs * self.num_heads,
|
||||
target_sequence_length,
|
||||
source_sequence_length,
|
||||
)),
|
||||
None => attention_weights,
|
||||
};
|
||||
|
||||
let attention_weights = attention_weights.softmax(-1, Float);
|
||||
@ -159,16 +202,25 @@ impl SelfAttention {
|
||||
.apply(&self.out_proj);
|
||||
|
||||
let attention_weights = if self.output_attentions {
|
||||
Some(attention_weights.view((bs, self.num_heads, target_sequence_length, source_sequence_length)))
|
||||
} else { None };
|
||||
Some(attention_weights.view((
|
||||
bs,
|
||||
self.num_heads,
|
||||
target_sequence_length,
|
||||
source_sequence_length,
|
||||
)))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
if self.store_cache {
|
||||
if layer_state.is_some() {
|
||||
layer_state.as_mut().unwrap().prev_key = k.view((bs, self.num_heads, -1, self.head_dim));
|
||||
layer_state.as_mut().unwrap().prev_value = v.view((bs, self.num_heads, -1, self.head_dim));
|
||||
layer_state.as_mut().unwrap().prev_key =
|
||||
k.view((bs, self.num_heads, -1, self.head_dim));
|
||||
layer_state.as_mut().unwrap().prev_value =
|
||||
v.view((bs, self.num_heads, -1, self.head_dim));
|
||||
layer_state.as_mut().unwrap().prev_key_padding_mask = match key_padding_mask {
|
||||
Some(tensor) => Some(tensor),
|
||||
None => None
|
||||
None => None,
|
||||
};
|
||||
} else {
|
||||
layer_state = Some(LayerState {
|
||||
@ -176,7 +228,7 @@ impl SelfAttention {
|
||||
prev_value: v.view((bs, self.num_heads, -1, self.head_dim)),
|
||||
prev_key_padding_mask: match key_padding_mask {
|
||||
Some(tensor) => Some(tensor),
|
||||
None => None
|
||||
None => None,
|
||||
},
|
||||
})
|
||||
};
|
||||
@ -185,17 +237,23 @@ impl SelfAttention {
|
||||
(output, attention_weights, layer_state)
|
||||
}
|
||||
|
||||
fn use_saved_state(&self,
|
||||
fn use_saved_state(
|
||||
&self,
|
||||
layer_state: &Option<LayerState>,
|
||||
k: Option<Tensor>,
|
||||
v: Option<Tensor>,
|
||||
key_padding_mask: Option<&Tensor>,
|
||||
bs: i64)
|
||||
-> (Tensor, Tensor, Option<Tensor>) {
|
||||
bs: i64,
|
||||
) -> (Tensor, Tensor, Option<Tensor>) {
|
||||
match &layer_state {
|
||||
Some(prev_state) => {
|
||||
let prev_key = prev_state.prev_key.view((bs * self.num_heads, -1, self.head_dim));
|
||||
let prev_value = prev_state.prev_value.view((bs * self.num_heads, -1, self.head_dim));
|
||||
let prev_key = prev_state
|
||||
.prev_key
|
||||
.view((bs * self.num_heads, -1, self.head_dim));
|
||||
let prev_value =
|
||||
prev_state
|
||||
.prev_value
|
||||
.view((bs * self.num_heads, -1, self.head_dim));
|
||||
let k = if self.encoder_decoder_attention {
|
||||
prev_key
|
||||
} else {
|
||||
@ -207,38 +265,53 @@ impl SelfAttention {
|
||||
Tensor::cat(&[prev_value, v.unwrap()], 1)
|
||||
};
|
||||
|
||||
let key_padding_mask = self.use_saved_key_padding_mask(key_padding_mask,
|
||||
let key_padding_mask = self.use_saved_key_padding_mask(
|
||||
key_padding_mask,
|
||||
&prev_state.prev_key_padding_mask,
|
||||
bs,
|
||||
k.size()[1]);
|
||||
k.size()[1],
|
||||
);
|
||||
(k, v, key_padding_mask)
|
||||
}
|
||||
None => {
|
||||
let key_padding_mask = match key_padding_mask {
|
||||
Some(value) => Some(value.copy()),
|
||||
None => None
|
||||
None => None,
|
||||
};
|
||||
(k.unwrap(), v.unwrap(), key_padding_mask)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn use_saved_key_padding_mask(&self, key_padding_mask: Option<&Tensor>, prev_key_padding_mask: &Option<Tensor>,
|
||||
bs: i64, sequence_length: i64) -> Option<Tensor> {
|
||||
fn use_saved_key_padding_mask(
|
||||
&self,
|
||||
key_padding_mask: Option<&Tensor>,
|
||||
prev_key_padding_mask: &Option<Tensor>,
|
||||
bs: i64,
|
||||
sequence_length: i64,
|
||||
) -> Option<Tensor> {
|
||||
if prev_key_padding_mask.is_some() {
|
||||
if self.encoder_decoder_attention {
|
||||
Some(prev_key_padding_mask.as_ref().unwrap().copy())
|
||||
} else {
|
||||
Some(Tensor::cat(&[prev_key_padding_mask.as_ref().unwrap(), key_padding_mask.as_ref().unwrap()], 1))
|
||||
Some(Tensor::cat(
|
||||
&[
|
||||
prev_key_padding_mask.as_ref().unwrap(),
|
||||
key_padding_mask.as_ref().unwrap(),
|
||||
],
|
||||
1,
|
||||
))
|
||||
}
|
||||
} else {
|
||||
match key_padding_mask {
|
||||
Some(key_padding_mask) => {
|
||||
let filler = Tensor::zeros(&[bs, sequence_length - key_padding_mask.size()[1]],
|
||||
(key_padding_mask.kind(), key_padding_mask.device()));
|
||||
let filler = Tensor::zeros(
|
||||
&[bs, sequence_length - key_padding_mask.size()[1]],
|
||||
(key_padding_mask.kind(), key_padding_mask.device()),
|
||||
);
|
||||
Some(Tensor::cat(&[filler, key_padding_mask.copy()], 1))
|
||||
}
|
||||
None => None
|
||||
None => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
425
src/bart/bart.rs
425
src/bart/bart.rs
@ -11,18 +11,18 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use crate::Config;
|
||||
use tch::{Tensor, nn};
|
||||
use tch::kind::Kind::{Int64, Float};
|
||||
use crate::bart::encoder::BartEncoder;
|
||||
use crate::bart::decoder::BartDecoder;
|
||||
use tch::nn::{embedding, EmbeddingConfig};
|
||||
use crate::bart::attention::LayerState;
|
||||
use std::borrow::BorrowMut;
|
||||
use crate::bart::decoder::BartDecoder;
|
||||
use crate::bart::encoder::BartEncoder;
|
||||
use crate::common::dropout::Dropout;
|
||||
use crate::pipelines::generation::{Cache, LMHeadModel};
|
||||
use crate::Config;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::borrow::BorrowMut;
|
||||
use std::collections::HashMap;
|
||||
use tch::kind::Kind::{Float, Int64};
|
||||
use tch::nn::{embedding, EmbeddingConfig};
|
||||
use tch::{nn, Tensor};
|
||||
|
||||
/// # BART Pretrained model weight files
|
||||
pub struct BartModelResources;
|
||||
@ -38,38 +38,74 @@ pub struct BartMergesResources;
|
||||
|
||||
impl BartModelResources {
|
||||
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
|
||||
pub const BART: (&'static str, &'static str) = ("bart/model.ot", "https://cdn.huggingface.co/facebook/bart-large/rust_model.ot");
|
||||
pub const BART: (&'static str, &'static str) = (
|
||||
"bart/model.ot",
|
||||
"https://cdn.huggingface.co/facebook/bart-large/rust_model.ot",
|
||||
);
|
||||
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
|
||||
pub const BART_CNN: (&'static str, &'static str) = ("bart-cnn/model.ot", "https://cdn.huggingface.co/facebook/bart-large-cnn/rust_model.ot");
|
||||
pub const BART_CNN: (&'static str, &'static str) = (
|
||||
"bart-cnn/model.ot",
|
||||
"https://cdn.huggingface.co/facebook/bart-large-cnn/rust_model.ot",
|
||||
);
|
||||
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
|
||||
pub const BART_XSUM: (&'static str, &'static str) = ("bart-xsum/model.ot", "https://cdn.huggingface.co/facebook/bart-large-xsum/rust_model.ot");
|
||||
pub const BART_XSUM: (&'static str, &'static str) = (
|
||||
"bart-xsum/model.ot",
|
||||
"https://cdn.huggingface.co/facebook/bart-large-xsum/rust_model.ot",
|
||||
);
|
||||
}
|
||||
|
||||
impl BartConfigResources {
|
||||
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
|
||||
pub const BART: (&'static str, &'static str) = ("bart/config.json", "https://cdn.huggingface.co/facebook/bart-large/config.json");
|
||||
pub const BART: (&'static str, &'static str) = (
|
||||
"bart/config.json",
|
||||
"https://cdn.huggingface.co/facebook/bart-large/config.json",
|
||||
);
|
||||
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
|
||||
pub const BART_CNN: (&'static str, &'static str) = ("bart-cnn/config.json", "https://cdn.huggingface.co/facebook/bart-large-cnn/config.json");
|
||||
pub const BART_CNN: (&'static str, &'static str) = (
|
||||
"bart-cnn/config.json",
|
||||
"https://cdn.huggingface.co/facebook/bart-large-cnn/config.json",
|
||||
);
|
||||
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
|
||||
pub const BART_XSUM: (&'static str, &'static str) = ("bart-xsum/config.json", "https://cdn.huggingface.co/facebook/bart-large-xsum/config.json");
|
||||
pub const BART_XSUM: (&'static str, &'static str) = (
|
||||
"bart-xsum/config.json",
|
||||
"https://cdn.huggingface.co/facebook/bart-large-xsum/config.json",
|
||||
);
|
||||
}
|
||||
|
||||
impl BartVocabResources {
|
||||
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
|
||||
pub const BART: (&'static str, &'static str) = ("bart/vocab.txt", "https://cdn.huggingface.co/roberta-large-vocab.json");
|
||||
pub const BART: (&'static str, &'static str) = (
|
||||
"bart/vocab.txt",
|
||||
"https://cdn.huggingface.co/roberta-large-vocab.json",
|
||||
);
|
||||
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
|
||||
pub const BART_CNN: (&'static str, &'static str) = ("bart-cnn/vocab.txt", "https://cdn.huggingface.co/roberta-large-vocab.json");
|
||||
pub const BART_CNN: (&'static str, &'static str) = (
|
||||
"bart-cnn/vocab.txt",
|
||||
"https://cdn.huggingface.co/roberta-large-vocab.json",
|
||||
);
|
||||
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
|
||||
pub const BART_XSUM: (&'static str, &'static str) = ("bart-xsum/vocab.txt", "https://cdn.huggingface.co/roberta-large-vocab.json");
|
||||
pub const BART_XSUM: (&'static str, &'static str) = (
|
||||
"bart-xsum/vocab.txt",
|
||||
"https://cdn.huggingface.co/roberta-large-vocab.json",
|
||||
);
|
||||
}
|
||||
|
||||
impl BartMergesResources {
|
||||
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
|
||||
pub const BART: (&'static str, &'static str) = ("bart/merges.txt", "https://cdn.huggingface.co/roberta-large-merges.txt");
|
||||
pub const BART: (&'static str, &'static str) = (
|
||||
"bart/merges.txt",
|
||||
"https://cdn.huggingface.co/roberta-large-merges.txt",
|
||||
);
|
||||
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
|
||||
pub const BART_CNN: (&'static str, &'static str) = ("bart-cnn/merges.txt", "https://cdn.huggingface.co/roberta-large-merges.txt");
|
||||
pub const BART_CNN: (&'static str, &'static str) = (
|
||||
"bart-cnn/merges.txt",
|
||||
"https://cdn.huggingface.co/roberta-large-merges.txt",
|
||||
);
|
||||
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
|
||||
pub const BART_XSUM: (&'static str, &'static str) = ("bart-xsum/merges.txt", "https://cdn.huggingface.co/roberta-large-merges.txt");
|
||||
pub const BART_XSUM: (&'static str, &'static str) = (
|
||||
"bart-xsum/merges.txt",
|
||||
"https://cdn.huggingface.co/roberta-large-merges.txt",
|
||||
);
|
||||
}
|
||||
|
||||
#[allow(non_camel_case_types)]
|
||||
@ -130,14 +166,15 @@ pub struct BartConfig {
|
||||
|
||||
impl Config<BartConfig> for BartConfig {}
|
||||
|
||||
fn _prepare_bart_decoder_inputs(pad_token_id: i64,
|
||||
fn _prepare_bart_decoder_inputs(
|
||||
pad_token_id: i64,
|
||||
input_ids: &Tensor,
|
||||
decoder_input_ids: Option<&Tensor>,
|
||||
decoder_padding_mask: Option<&Tensor>)
|
||||
-> (Tensor, Option<Tensor>, Option<Tensor>) {
|
||||
decoder_padding_mask: Option<&Tensor>,
|
||||
) -> (Tensor, Option<Tensor>, Option<Tensor>) {
|
||||
let decoder_input_ids = match decoder_input_ids {
|
||||
Some(value) => value.copy(),
|
||||
None => _shift_tokens_right(input_ids, pad_token_id)
|
||||
None => _shift_tokens_right(input_ids, pad_token_id),
|
||||
};
|
||||
|
||||
let decoder_padding_mask = match decoder_padding_mask {
|
||||
@ -159,7 +196,6 @@ fn _prepare_bart_decoder_inputs(pad_token_id: i64,
|
||||
(decoder_input_ids, decoder_padding_mask, Some(causal_mask))
|
||||
}
|
||||
|
||||
|
||||
fn _shift_tokens_right(input_ids: &Tensor, pad_token_id: i64) -> Tensor {
|
||||
let index_eos: Tensor = input_ids.ne(pad_token_id).sum1(&[-1], true, Int64) - 1;
|
||||
let output = input_ids.empty_like().to_kind(Int64);
|
||||
@ -200,10 +236,10 @@ impl BartModel {
|
||||
/// # Example
|
||||
///
|
||||
/// ```no_run
|
||||
/// use tch::{nn, Device};
|
||||
/// use rust_bert::bart::{BartConfig, BartModel};
|
||||
/// use rust_bert::Config;
|
||||
/// use std::path::Path;
|
||||
/// use rust_bert::bart::{BartConfig, BartModel};
|
||||
/// use tch::{nn, Device};
|
||||
///
|
||||
/// let config_path = Path::new("path/to/config.json");
|
||||
/// let device = Device::Cpu;
|
||||
@ -212,22 +248,32 @@ impl BartModel {
|
||||
/// let generation_mode = true;
|
||||
/// let bart: BartModel = BartModel::new(&(&p.root() / "bart"), &config, generation_mode);
|
||||
/// ```
|
||||
///
|
||||
pub fn new(p: &nn::Path, config: &BartConfig, generation_mode: bool) -> BartModel {
|
||||
let pad_token_id = match config.pad_token_id {
|
||||
Some(value) => value,
|
||||
None => 1
|
||||
None => 1,
|
||||
};
|
||||
let embedding_config = EmbeddingConfig { padding_idx: pad_token_id, ..Default::default() };
|
||||
let embeddings: nn::Embedding = embedding(p / "shared",
|
||||
let embedding_config = EmbeddingConfig {
|
||||
padding_idx: pad_token_id,
|
||||
..Default::default()
|
||||
};
|
||||
let embeddings: nn::Embedding = embedding(
|
||||
p / "shared",
|
||||
config.vocab_size,
|
||||
config.d_model,
|
||||
embedding_config);
|
||||
embedding_config,
|
||||
);
|
||||
|
||||
let encoder = BartEncoder::new(p / "encoder", config);
|
||||
let decoder = BartDecoder::new(p / "decoder", config, generation_mode);
|
||||
|
||||
BartModel { encoder, decoder, generation_mode, pad_token_id, embeddings }
|
||||
BartModel {
|
||||
encoder,
|
||||
decoder,
|
||||
generation_mode,
|
||||
pad_token_id,
|
||||
embeddings,
|
||||
}
|
||||
}
|
||||
|
||||
/// Forward pass through the model
|
||||
@ -271,68 +317,102 @@ impl BartModel {
|
||||
/// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56);
|
||||
/// let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device));
|
||||
/// let target_tensor = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device));
|
||||
/// let encoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
|
||||
/// let decoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
|
||||
/// let encoder_attention_mask =
|
||||
/// Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
|
||||
/// let decoder_attention_mask =
|
||||
/// Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
|
||||
///
|
||||
/// let (decoder_output, encoder_hidden_states, decoder_cache,
|
||||
/// all_encoder_hidden_states, all_encoder_attentions,
|
||||
/// all_decoder_hidden_states, all_decoder_attentions) = no_grad(|| {
|
||||
/// bart_model
|
||||
/// .forward_t(Some(&input_tensor),
|
||||
/// let (
|
||||
/// decoder_output,
|
||||
/// encoder_hidden_states,
|
||||
/// decoder_cache,
|
||||
/// all_encoder_hidden_states,
|
||||
/// all_encoder_attentions,
|
||||
/// all_decoder_hidden_states,
|
||||
/// all_decoder_attentions,
|
||||
/// ) = no_grad(|| {
|
||||
/// bart_model.forward_t(
|
||||
/// Some(&input_tensor),
|
||||
/// Some(&encoder_attention_mask),
|
||||
/// Some(&target_tensor),
|
||||
/// None,
|
||||
/// Some(&decoder_attention_mask),
|
||||
/// None,
|
||||
/// false)
|
||||
/// false,
|
||||
/// )
|
||||
/// });
|
||||
///
|
||||
/// ```
|
||||
///
|
||||
pub fn forward_t(&self,
|
||||
pub fn forward_t(
|
||||
&self,
|
||||
input_ids: Option<&Tensor>,
|
||||
attention_mask: Option<&Tensor>,
|
||||
decoder_input_ids: Option<&Tensor>,
|
||||
encoder_outputs: Option<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>)>,
|
||||
decoder_attention_mask: Option<&Tensor>,
|
||||
layer_states: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
|
||||
train: bool) ->
|
||||
(Tensor, Tensor, Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
|
||||
Option<Vec<Tensor>>, Option<Vec<Tensor>>,
|
||||
Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
|
||||
train: bool,
|
||||
) -> (
|
||||
Tensor,
|
||||
Tensor,
|
||||
Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
|
||||
Option<Vec<Tensor>>,
|
||||
Option<Vec<Tensor>>,
|
||||
Option<Vec<Tensor>>,
|
||||
Option<Vec<Tensor>>,
|
||||
) {
|
||||
let (decoder_input_ids, decoder_padding_mask, causal_mask) = if self.generation_mode {
|
||||
(decoder_input_ids.unwrap().copy(), None, None)
|
||||
} else {
|
||||
assert!(input_ids.is_some(), "input_ids must be provided when not in generation mode");
|
||||
_prepare_bart_decoder_inputs(self.pad_token_id, input_ids.unwrap(), decoder_input_ids, decoder_attention_mask)
|
||||
assert!(
|
||||
input_ids.is_some(),
|
||||
"input_ids must be provided when not in generation mode"
|
||||
);
|
||||
_prepare_bart_decoder_inputs(
|
||||
self.pad_token_id,
|
||||
input_ids.unwrap(),
|
||||
decoder_input_ids,
|
||||
decoder_attention_mask,
|
||||
)
|
||||
};
|
||||
|
||||
let (encoder_hidden_states,
|
||||
all_encoder_hidden_states,
|
||||
all_encoder_attentions) = match encoder_outputs {
|
||||
let (encoder_hidden_states, all_encoder_hidden_states, all_encoder_attentions) =
|
||||
match encoder_outputs {
|
||||
Some(value) => value,
|
||||
None => {
|
||||
assert!(input_ids.is_some(), "input_ids must be provided when encoder output is not pre-computed");
|
||||
self.encoder.forward_t(input_ids.unwrap(), attention_mask, &self.embeddings, train)
|
||||
assert!(
|
||||
input_ids.is_some(),
|
||||
"input_ids must be provided when encoder output is not pre-computed"
|
||||
);
|
||||
self.encoder.forward_t(
|
||||
input_ids.unwrap(),
|
||||
attention_mask,
|
||||
&self.embeddings,
|
||||
train,
|
||||
)
|
||||
}
|
||||
};
|
||||
|
||||
let (decoder_outputs,
|
||||
decoder_cache,
|
||||
all_decoder_hidden_states,
|
||||
all_decoder_attentions) = self.decoder.forward_t(&decoder_input_ids,
|
||||
let (decoder_outputs, decoder_cache, all_decoder_hidden_states, all_decoder_attentions) =
|
||||
self.decoder.forward_t(
|
||||
&decoder_input_ids,
|
||||
&encoder_hidden_states,
|
||||
attention_mask,
|
||||
decoder_padding_mask.as_ref(),
|
||||
causal_mask.as_ref(),
|
||||
&self.embeddings,
|
||||
layer_states,
|
||||
train);
|
||||
(decoder_outputs, encoder_hidden_states, decoder_cache.1,
|
||||
all_decoder_hidden_states, all_decoder_attentions,
|
||||
all_encoder_hidden_states, all_encoder_attentions)
|
||||
train,
|
||||
);
|
||||
(
|
||||
decoder_outputs,
|
||||
encoder_hidden_states,
|
||||
decoder_cache.1,
|
||||
all_decoder_hidden_states,
|
||||
all_decoder_attentions,
|
||||
all_encoder_hidden_states,
|
||||
all_encoder_attentions,
|
||||
)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
/// # BART Model for conditional generation
|
||||
@ -356,20 +436,24 @@ impl BartForConditionalGeneration {
|
||||
/// # Example
|
||||
///
|
||||
/// ```no_run
|
||||
/// use tch::{nn, Device};
|
||||
/// use rust_bert::bart::{BartConfig, BartForConditionalGeneration};
|
||||
/// use rust_bert::Config;
|
||||
/// use std::path::Path;
|
||||
/// use rust_bert::bart::{BartConfig, BartForConditionalGeneration};
|
||||
/// use tch::{nn, Device};
|
||||
///
|
||||
/// let config_path = Path::new("path/to/config.json");
|
||||
/// let device = Device::Cpu;
|
||||
/// let p = nn::VarStore::new(device);
|
||||
/// let config = BartConfig::from_file(config_path);
|
||||
/// let generation_mode = true;
|
||||
/// let bart: BartForConditionalGeneration = BartForConditionalGeneration::new(&(&p.root() / "bart"), &config, generation_mode);
|
||||
/// let bart: BartForConditionalGeneration =
|
||||
/// BartForConditionalGeneration::new(&(&p.root() / "bart"), &config, generation_mode);
|
||||
/// ```
|
||||
///
|
||||
pub fn new(p: &nn::Path, config: &BartConfig, generation_mode: bool) -> BartForConditionalGeneration {
|
||||
pub fn new(
|
||||
p: &nn::Path,
|
||||
config: &BartConfig,
|
||||
generation_mode: bool,
|
||||
) -> BartForConditionalGeneration {
|
||||
let base_model = BartModel::new(&(p / "model"), config, generation_mode);
|
||||
BartForConditionalGeneration { base_model }
|
||||
}
|
||||
@ -427,37 +511,64 @@ impl BartForConditionalGeneration {
|
||||
/// None,
|
||||
/// false)
|
||||
/// });
|
||||
///
|
||||
/// ```
|
||||
///
|
||||
pub fn forward_t(&self,
|
||||
pub fn forward_t(
|
||||
&self,
|
||||
input_ids: Option<&Tensor>,
|
||||
attention_mask: Option<&Tensor>,
|
||||
encoder_outputs: Option<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>)>,
|
||||
decoder_input_ids: Option<&Tensor>,
|
||||
decoder_attention_mask: Option<&Tensor>,
|
||||
old_layer_states: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
|
||||
train: bool)
|
||||
-> (Tensor, Tensor, Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
|
||||
Option<Vec<Tensor>>, Option<Vec<Tensor>>,
|
||||
Option<Vec<Tensor>>, Option<Vec<Tensor>>)
|
||||
{
|
||||
let (decoder_outputs, encoder_hidden_states, decoder_cache,
|
||||
all_decoder_hidden_states, all_decoder_attentions,
|
||||
all_encoder_hidden_states, all_encoder_attentions) =
|
||||
self.base_model.forward_t(input_ids, attention_mask, decoder_input_ids, encoder_outputs, decoder_attention_mask, old_layer_states, train);
|
||||
train: bool,
|
||||
) -> (
|
||||
Tensor,
|
||||
Tensor,
|
||||
Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
|
||||
Option<Vec<Tensor>>,
|
||||
Option<Vec<Tensor>>,
|
||||
Option<Vec<Tensor>>,
|
||||
Option<Vec<Tensor>>,
|
||||
) {
|
||||
let (
|
||||
decoder_outputs,
|
||||
encoder_hidden_states,
|
||||
decoder_cache,
|
||||
all_decoder_hidden_states,
|
||||
all_decoder_attentions,
|
||||
all_encoder_hidden_states,
|
||||
all_encoder_attentions,
|
||||
) = self.base_model.forward_t(
|
||||
input_ids,
|
||||
attention_mask,
|
||||
decoder_input_ids,
|
||||
encoder_outputs,
|
||||
decoder_attention_mask,
|
||||
old_layer_states,
|
||||
train,
|
||||
);
|
||||
|
||||
let lm_logits = decoder_outputs.linear::<Tensor>(&self.base_model.embeddings.ws, None);
|
||||
(lm_logits, encoder_hidden_states, decoder_cache,
|
||||
all_decoder_hidden_states, all_decoder_attentions,
|
||||
all_encoder_hidden_states, all_encoder_attentions)
|
||||
(
|
||||
lm_logits,
|
||||
encoder_hidden_states,
|
||||
decoder_cache,
|
||||
all_decoder_hidden_states,
|
||||
all_decoder_attentions,
|
||||
all_encoder_hidden_states,
|
||||
all_encoder_attentions,
|
||||
)
|
||||
}
|
||||
|
||||
pub fn encode(&self, input_ids: &Tensor, attention_mask: Option<&Tensor>) -> Tensor {
|
||||
let (encoder_hidden_states, _, _) = self.base_model.encoder.forward_t(input_ids, attention_mask, &self.base_model.embeddings, false);
|
||||
let (encoder_hidden_states, _, _) = self.base_model.encoder.forward_t(
|
||||
input_ids,
|
||||
attention_mask,
|
||||
&self.base_model.embeddings,
|
||||
false,
|
||||
);
|
||||
encoder_hidden_states
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
pub struct BartClassificationHead {
|
||||
@ -468,16 +579,29 @@ pub struct BartClassificationHead {
|
||||
|
||||
impl BartClassificationHead {
|
||||
pub fn new(p: &nn::Path, config: &BartConfig) -> BartClassificationHead {
|
||||
let dense = nn::linear(&(p / "dense"), config.d_model, config.d_model, Default::default());
|
||||
let dense = nn::linear(
|
||||
&(p / "dense"),
|
||||
config.d_model,
|
||||
config.d_model,
|
||||
Default::default(),
|
||||
);
|
||||
let dropout = Dropout::new(config.classif_dropout);
|
||||
let out_proj = nn::linear(&(p / "out_proj"), config.d_model, config.num_labels.unwrap(), Default::default());
|
||||
let out_proj = nn::linear(
|
||||
&(p / "out_proj"),
|
||||
config.d_model,
|
||||
config.num_labels.unwrap(),
|
||||
Default::default(),
|
||||
);
|
||||
|
||||
BartClassificationHead { dense, dropout, out_proj }
|
||||
BartClassificationHead {
|
||||
dense,
|
||||
dropout,
|
||||
out_proj,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn forward_t(&self, x: &Tensor, train: bool) -> Tensor {
|
||||
x
|
||||
.apply_t(&self.dropout, train)
|
||||
x.apply_t(&self.dropout, train)
|
||||
.apply(&self.dense)
|
||||
.tanh()
|
||||
.apply_t(&self.dropout, train)
|
||||
@ -497,7 +621,6 @@ pub struct BartForSequenceClassification {
|
||||
eos_token_id: i64,
|
||||
}
|
||||
|
||||
|
||||
impl BartForSequenceClassification {
|
||||
/// Build a new `BartForSequenceClassification`
|
||||
///
|
||||
@ -509,27 +632,31 @@ impl BartForSequenceClassification {
|
||||
/// # Example
|
||||
///
|
||||
/// ```no_run
|
||||
/// use tch::{nn, Device};
|
||||
/// use rust_bert::bart::{BartConfig, BartForSequenceClassification};
|
||||
/// use rust_bert::Config;
|
||||
/// use std::path::Path;
|
||||
/// use rust_bert::bart::{BartConfig, BartForSequenceClassification};
|
||||
/// use tch::{nn, Device};
|
||||
///
|
||||
/// let config_path = Path::new("path/to/config.json");
|
||||
/// let device = Device::Cpu;
|
||||
/// let p = nn::VarStore::new(device);
|
||||
/// let config = BartConfig::from_file(config_path);
|
||||
/// let generation_mode = true;
|
||||
/// let bart: BartForSequenceClassification = BartForSequenceClassification::new(&(&p.root() / "bart"), &config);
|
||||
/// let bart: BartForSequenceClassification =
|
||||
/// BartForSequenceClassification::new(&(&p.root() / "bart"), &config);
|
||||
/// ```
|
||||
///
|
||||
pub fn new(p: &nn::Path, config: &BartConfig) -> BartForSequenceClassification {
|
||||
let base_model = BartModel::new(&(p / "model"), config, false);
|
||||
let classification_head = BartClassificationHead::new(&(p / "classification_head"), config);
|
||||
let eos_token_id = match config.eos_token_id {
|
||||
Some(value) => value,
|
||||
None => 3
|
||||
None => 3,
|
||||
};
|
||||
BartForSequenceClassification { base_model, classification_head, eos_token_id }
|
||||
BartForSequenceClassification {
|
||||
base_model,
|
||||
classification_head,
|
||||
eos_token_id,
|
||||
}
|
||||
}
|
||||
|
||||
/// Forward pass through the model
|
||||
@ -585,36 +712,63 @@ impl BartForSequenceClassification {
|
||||
/// None,
|
||||
/// false)
|
||||
/// });
|
||||
///
|
||||
/// ```
|
||||
///
|
||||
pub fn forward_t(&mut self,
|
||||
pub fn forward_t(
|
||||
&mut self,
|
||||
input_ids: &Tensor,
|
||||
attention_mask: Option<&Tensor>,
|
||||
encoder_outputs: Option<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>)>,
|
||||
decoder_input_ids: Option<&Tensor>,
|
||||
decoder_attention_mask: Option<&Tensor>,
|
||||
train: bool)
|
||||
-> (Tensor, Tensor,
|
||||
Option<Vec<Tensor>>, Option<Vec<Tensor>>,
|
||||
Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
|
||||
let (decoder_outputs, encoder_hidden_states, _,
|
||||
all_decoder_hidden_states, all_decoder_attentions,
|
||||
all_encoder_hidden_states, all_encoder_attentions) =
|
||||
self.borrow_mut().base_model.forward_t(Some(input_ids), attention_mask, decoder_input_ids, encoder_outputs, decoder_attention_mask, None, train);
|
||||
train: bool,
|
||||
) -> (
|
||||
Tensor,
|
||||
Tensor,
|
||||
Option<Vec<Tensor>>,
|
||||
Option<Vec<Tensor>>,
|
||||
Option<Vec<Tensor>>,
|
||||
Option<Vec<Tensor>>,
|
||||
) {
|
||||
let (
|
||||
decoder_outputs,
|
||||
encoder_hidden_states,
|
||||
_,
|
||||
all_decoder_hidden_states,
|
||||
all_decoder_attentions,
|
||||
all_encoder_hidden_states,
|
||||
all_encoder_attentions,
|
||||
) = self.borrow_mut().base_model.forward_t(
|
||||
Some(input_ids),
|
||||
attention_mask,
|
||||
decoder_input_ids,
|
||||
encoder_outputs,
|
||||
decoder_attention_mask,
|
||||
None,
|
||||
train,
|
||||
);
|
||||
|
||||
let eos_mask = input_ids.eq(self.eos_token_id);
|
||||
let sentence_representation = decoder_outputs
|
||||
.index_select(0, &eos_mask)
|
||||
.view((decoder_outputs.size()[0], -1, *decoder_outputs.size().last().unwrap()))
|
||||
.view((
|
||||
decoder_outputs.size()[0],
|
||||
-1,
|
||||
*decoder_outputs.size().last().unwrap(),
|
||||
))
|
||||
.select(1, -1);
|
||||
|
||||
let logits = self.classification_head.forward_t(&sentence_representation, train);
|
||||
(logits, encoder_hidden_states,
|
||||
all_decoder_hidden_states, all_decoder_attentions,
|
||||
all_encoder_hidden_states, all_encoder_attentions)
|
||||
let logits = self
|
||||
.classification_head
|
||||
.forward_t(&sentence_representation, train);
|
||||
(
|
||||
logits,
|
||||
encoder_hidden_states,
|
||||
all_decoder_hidden_states,
|
||||
all_decoder_attentions,
|
||||
all_encoder_hidden_states,
|
||||
all_encoder_attentions,
|
||||
)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
impl LMHeadModel for BartForConditionalGeneration {
|
||||
@ -675,10 +829,9 @@ impl LMHeadModel for BartForConditionalGeneration {
|
||||
/// None,
|
||||
/// false)
|
||||
/// });
|
||||
///
|
||||
/// ```
|
||||
///
|
||||
fn forward_t(&self,
|
||||
fn forward_t(
|
||||
&self,
|
||||
input_ids: &Option<Tensor>,
|
||||
cache: Cache,
|
||||
attention_mask: &Option<Tensor>,
|
||||
@ -687,27 +840,47 @@ impl LMHeadModel for BartForConditionalGeneration {
|
||||
_input_embeds: &Option<Tensor>,
|
||||
encoder_outputs: Option<&Tensor>,
|
||||
decoder_input_ids: &Option<Tensor>,
|
||||
train: bool) -> Result<(Tensor, Option<Tensor>, Cache, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
|
||||
train: bool,
|
||||
) -> Result<
|
||||
(
|
||||
Tensor,
|
||||
Option<Tensor>,
|
||||
Cache,
|
||||
Option<Vec<Tensor>>,
|
||||
Option<Vec<Tensor>>,
|
||||
),
|
||||
&'static str,
|
||||
> {
|
||||
let (decoder_output, encoder_hidden_states, new_cache, _, _, _, _) = match cache {
|
||||
Cache::BARTCache(cached_layer_states) => self.base_model.forward_t(input_ids.as_ref(),
|
||||
Cache::BARTCache(cached_layer_states) => self.base_model.forward_t(
|
||||
input_ids.as_ref(),
|
||||
attention_mask.as_ref(),
|
||||
decoder_input_ids.as_ref(),
|
||||
Some((encoder_outputs.as_ref().unwrap().copy(), None, None)),
|
||||
None,
|
||||
cached_layer_states,
|
||||
train),
|
||||
train,
|
||||
),
|
||||
|
||||
Cache::None => self.base_model.forward_t(input_ids.as_ref(),
|
||||
Cache::None => self.base_model.forward_t(
|
||||
input_ids.as_ref(),
|
||||
attention_mask.as_ref(),
|
||||
decoder_input_ids.as_ref(),
|
||||
Some((encoder_outputs.as_ref().unwrap().copy(), None, None)),
|
||||
None,
|
||||
None,
|
||||
train),
|
||||
_ => Err("Cache not compatible with BART Model")?
|
||||
train,
|
||||
),
|
||||
_ => Err("Cache not compatible with BART Model")?,
|
||||
};
|
||||
|
||||
let lm_logits = decoder_output.linear::<Tensor>(&self.base_model.embeddings.ws, None);
|
||||
Ok((lm_logits, Some(encoder_hidden_states), Cache::BARTCache(new_cache), None, None))
|
||||
Ok((
|
||||
lm_logits,
|
||||
Some(encoder_hidden_states),
|
||||
Cache::BARTCache(new_cache),
|
||||
None,
|
||||
None,
|
||||
))
|
||||
}
|
||||
}
|
@ -11,15 +11,17 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use crate::bart::attention::{SelfAttention, LayerState};
|
||||
use tch::{nn, Tensor};
|
||||
use crate::common::dropout::Dropout;
|
||||
use crate::bart::BartConfig;
|
||||
use crate::bart::attention::{LayerState, SelfAttention};
|
||||
use crate::bart::bart::Activation;
|
||||
use crate::common::activations::{_gelu, _relu, _swish, _gelu_new, _tanh};
|
||||
use tch::kind::Kind::Int64;
|
||||
use crate::bart::embeddings::{
|
||||
EmbeddingOption, LearnedPositionalEmbedding, SinusoidalPositionalEmbedding,
|
||||
};
|
||||
use crate::bart::BartConfig;
|
||||
use crate::common::activations::{_gelu, _gelu_new, _relu, _swish, _tanh};
|
||||
use crate::common::dropout::Dropout;
|
||||
use std::borrow::BorrowMut;
|
||||
use crate::bart::embeddings::{EmbeddingOption, LearnedPositionalEmbedding, SinusoidalPositionalEmbedding};
|
||||
use tch::kind::Kind::Int64;
|
||||
use tch::{nn, Tensor};
|
||||
|
||||
pub struct DecoderLayer {
|
||||
self_attention: SelfAttention,
|
||||
@ -36,51 +38,74 @@ pub struct DecoderLayer {
|
||||
|
||||
impl DecoderLayer {
|
||||
pub fn new(p: nn::Path, config: &BartConfig) -> DecoderLayer {
|
||||
let layer_norm_config = nn::LayerNormConfig { eps: 1e-5, ..Default::default() };
|
||||
let layer_norm_config = nn::LayerNormConfig {
|
||||
eps: 1e-5,
|
||||
..Default::default()
|
||||
};
|
||||
let output_attention = match config.output_attentions {
|
||||
Some(value) => value,
|
||||
None => false
|
||||
None => false,
|
||||
};
|
||||
let self_attention = SelfAttention::new(&p / "self_attn",
|
||||
let self_attention = SelfAttention::new(
|
||||
&p / "self_attn",
|
||||
config.d_model,
|
||||
config.decoder_attention_heads,
|
||||
config.attention_dropout,
|
||||
false,
|
||||
true,
|
||||
output_attention);
|
||||
let encoder_attention = SelfAttention::new(&p / "encoder_attn",
|
||||
output_attention,
|
||||
);
|
||||
let encoder_attention = SelfAttention::new(
|
||||
&p / "encoder_attn",
|
||||
config.d_model,
|
||||
config.decoder_attention_heads,
|
||||
config.attention_dropout,
|
||||
true,
|
||||
true,
|
||||
output_attention);
|
||||
let self_attention_layer_norm = nn::layer_norm(&p / "self_attn_layer_norm",
|
||||
output_attention,
|
||||
);
|
||||
let self_attention_layer_norm = nn::layer_norm(
|
||||
&p / "self_attn_layer_norm",
|
||||
vec![config.d_model],
|
||||
layer_norm_config);
|
||||
let encoder_attention_layer_norm = nn::layer_norm(&p / "encoder_attn_layer_norm",
|
||||
layer_norm_config,
|
||||
);
|
||||
let encoder_attention_layer_norm = nn::layer_norm(
|
||||
&p / "encoder_attn_layer_norm",
|
||||
vec![config.d_model],
|
||||
layer_norm_config);
|
||||
layer_norm_config,
|
||||
);
|
||||
|
||||
let dropout = Dropout::new(config.dropout);
|
||||
let activation_dropout = Dropout::new(config.activation_dropout);
|
||||
let activation_function = match &config.activation_function {
|
||||
Some(act_function) => act_function,
|
||||
None => &Activation::gelu
|
||||
None => &Activation::gelu,
|
||||
};
|
||||
let activation = Box::new(match activation_function {
|
||||
Activation::gelu => _gelu,
|
||||
Activation::relu => _relu,
|
||||
Activation::swish => _swish,
|
||||
Activation::gelu_new => _gelu_new,
|
||||
Activation::tanh => _tanh
|
||||
Activation::tanh => _tanh,
|
||||
});
|
||||
let fc1 = nn::linear(&p / "fc1", config.d_model, config.decoder_ffn_dim, Default::default());
|
||||
let fc2 = nn::linear(&p / "fc2", config.decoder_ffn_dim, config.d_model, Default::default());
|
||||
let fc1 = nn::linear(
|
||||
&p / "fc1",
|
||||
config.d_model,
|
||||
config.decoder_ffn_dim,
|
||||
Default::default(),
|
||||
);
|
||||
let fc2 = nn::linear(
|
||||
&p / "fc2",
|
||||
config.decoder_ffn_dim,
|
||||
config.d_model,
|
||||
Default::default(),
|
||||
);
|
||||
|
||||
let final_layer_norm = nn::layer_norm(&p / "final_layer_norm",
|
||||
let final_layer_norm = nn::layer_norm(
|
||||
&p / "final_layer_norm",
|
||||
vec![config.d_model],
|
||||
layer_norm_config);
|
||||
layer_norm_config,
|
||||
);
|
||||
|
||||
DecoderLayer {
|
||||
self_attention,
|
||||
@ -96,18 +121,38 @@ impl DecoderLayer {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn forward_t(&self,
|
||||
pub fn forward_t(
|
||||
&self,
|
||||
x: &Tensor,
|
||||
encoder_hidden_states: &Tensor,
|
||||
encoder_attn_mask: Option<&Tensor>,
|
||||
causal_mask: Option<&Tensor>,
|
||||
decoder_padding_mask: Option<&Tensor>,
|
||||
layer_states: (Option<LayerState>, Option<LayerState>),
|
||||
train: bool) -> (Tensor, Option<Tensor>, (Option<LayerState>, Option<LayerState>)) {
|
||||
let (output, attention_weights, new_self_layer_states) = self.self_attention.forward_t(x, Some(x), decoder_padding_mask, causal_mask, layer_states.0, train);
|
||||
train: bool,
|
||||
) -> (
|
||||
Tensor,
|
||||
Option<Tensor>,
|
||||
(Option<LayerState>, Option<LayerState>),
|
||||
) {
|
||||
let (output, attention_weights, new_self_layer_states) = self.self_attention.forward_t(
|
||||
x,
|
||||
Some(x),
|
||||
decoder_padding_mask,
|
||||
causal_mask,
|
||||
layer_states.0,
|
||||
train,
|
||||
);
|
||||
let output: Tensor = output.apply_t(&self.dropout, train) + x;
|
||||
let output = output.apply(&self.self_attention_layer_norm);
|
||||
let (output1, _, new_encoder_layer_states) = self.encoder_attention.forward_t(&output, Some(encoder_hidden_states), encoder_attn_mask, None, layer_states.1, train);
|
||||
let (output1, _, new_encoder_layer_states) = self.encoder_attention.forward_t(
|
||||
&output,
|
||||
Some(encoder_hidden_states),
|
||||
encoder_attn_mask,
|
||||
None,
|
||||
layer_states.1,
|
||||
train,
|
||||
);
|
||||
let output1: Tensor = output1.apply_t(&self.dropout, train) + output;
|
||||
let output1 = output1.apply(&self.encoder_attention_layer_norm);
|
||||
let output2 = (self.activation)(&output1.apply(&self.fc1));
|
||||
@ -116,7 +161,11 @@ impl DecoderLayer {
|
||||
.apply(&self.fc2)
|
||||
.apply_t(&self.dropout, train);
|
||||
let output2: Tensor = output2 + output1;
|
||||
(output2.apply(&self.final_layer_norm), attention_weights, (new_self_layer_states, new_encoder_layer_states))
|
||||
(
|
||||
output2.apply(&self.final_layer_norm),
|
||||
attention_weights,
|
||||
(new_self_layer_states, new_encoder_layer_states),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@ -136,61 +185,76 @@ impl BartDecoder {
|
||||
pub fn new(p: nn::Path, config: &BartConfig, generation_mode: bool) -> BartDecoder {
|
||||
let output_past = match config.output_past {
|
||||
Some(value) => value,
|
||||
None => true
|
||||
None => true,
|
||||
};
|
||||
let output_attentions = match config.output_attentions {
|
||||
Some(value) => value,
|
||||
None => false
|
||||
None => false,
|
||||
};
|
||||
let output_hidden_states = match config.output_hidden_states {
|
||||
Some(value) => value,
|
||||
None => false
|
||||
None => false,
|
||||
};
|
||||
let normalize_embedding = match config.normalize_embedding {
|
||||
Some(value) => value,
|
||||
None => true
|
||||
None => true,
|
||||
};
|
||||
let static_position_embeddings = match config.static_position_embeddings {
|
||||
Some(value) => value,
|
||||
None => false
|
||||
None => false,
|
||||
};
|
||||
let scale_embedding = match config.scale_embedding {
|
||||
Some(value) => if value { (config.d_model as f64).sqrt() } else { 1.0 },
|
||||
None => 1.0
|
||||
Some(value) => {
|
||||
if value {
|
||||
(config.d_model as f64).sqrt()
|
||||
} else {
|
||||
1.0
|
||||
}
|
||||
}
|
||||
None => 1.0,
|
||||
};
|
||||
|
||||
let dropout = Dropout::new(config.dropout);
|
||||
|
||||
let layer_norm_embedding = if normalize_embedding {
|
||||
let layer_norm_config = nn::LayerNormConfig { eps: 1e-5, ..Default::default() };
|
||||
Some(nn::layer_norm(&p / "layernorm_embedding",
|
||||
let layer_norm_config = nn::LayerNormConfig {
|
||||
eps: 1e-5,
|
||||
..Default::default()
|
||||
};
|
||||
Some(nn::layer_norm(
|
||||
&p / "layernorm_embedding",
|
||||
vec![config.d_model],
|
||||
layer_norm_config))
|
||||
layer_norm_config,
|
||||
))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let pad_token_id = match config.pad_token_id {
|
||||
Some(value) => value,
|
||||
None => 1
|
||||
None => 1,
|
||||
};
|
||||
|
||||
let embed_positions = if static_position_embeddings {
|
||||
EmbeddingOption::SinusoidalPositionalEmbedding(SinusoidalPositionalEmbedding::new(&p / "embed_positions",
|
||||
config.max_position_embeddings,
|
||||
config.d_model))
|
||||
} else {
|
||||
EmbeddingOption::LearnedPositionalEmbedding(LearnedPositionalEmbedding::new(&p / "embed_positions",
|
||||
EmbeddingOption::SinusoidalPositionalEmbedding(SinusoidalPositionalEmbedding::new(
|
||||
&p / "embed_positions",
|
||||
config.max_position_embeddings,
|
||||
config.d_model,
|
||||
pad_token_id))
|
||||
))
|
||||
} else {
|
||||
EmbeddingOption::LearnedPositionalEmbedding(LearnedPositionalEmbedding::new(
|
||||
&p / "embed_positions",
|
||||
config.max_position_embeddings,
|
||||
config.d_model,
|
||||
pad_token_id,
|
||||
))
|
||||
};
|
||||
|
||||
let mut layers: Vec<DecoderLayer> = vec!();
|
||||
let mut layers: Vec<DecoderLayer> = vec![];
|
||||
let p_layers = &p / "layers";
|
||||
for layer_index in 0..config.decoder_layers {
|
||||
layers.push(DecoderLayer::new(&p_layers / layer_index, config));
|
||||
};
|
||||
}
|
||||
|
||||
BartDecoder {
|
||||
dropout,
|
||||
@ -205,7 +269,8 @@ impl BartDecoder {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn forward_t(&self,
|
||||
pub fn forward_t(
|
||||
&self,
|
||||
input_ids: &Tensor,
|
||||
encoder_hidden_states: &Tensor,
|
||||
encoder_padding_mask: Option<&Tensor>,
|
||||
@ -213,33 +278,56 @@ impl BartDecoder {
|
||||
decoder_causal_mask: Option<&Tensor>,
|
||||
embeddings: &nn::Embedding,
|
||||
old_layer_states: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
|
||||
train: bool)
|
||||
-> (Tensor,
|
||||
(Option<Tensor>, Option<Vec<(Option<LayerState>, Option<LayerState>)>>),
|
||||
train: bool,
|
||||
) -> (
|
||||
Tensor,
|
||||
(
|
||||
Option<Tensor>,
|
||||
Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
|
||||
),
|
||||
Option<Vec<Tensor>>,
|
||||
Option<Vec<Tensor>>) {
|
||||
Option<Vec<Tensor>>,
|
||||
) {
|
||||
let encoder_padding_mask = match encoder_padding_mask {
|
||||
Some(mask) => Some(mask.eq(0).to_kind(Int64)),
|
||||
None => None
|
||||
None => None,
|
||||
};
|
||||
|
||||
let positions = self.embed_positions.forward(input_ids, self.generation_mode);
|
||||
let positions = self
|
||||
.embed_positions
|
||||
.forward(input_ids, self.generation_mode);
|
||||
let x: Tensor = if self.generation_mode {
|
||||
let end_inputs = input_ids.size()[1];
|
||||
let end_positions = positions.size()[1];
|
||||
input_ids.narrow(1, end_inputs - 1, 1).apply(embeddings) * self.scale_embedding + positions.narrow(1, end_positions - 1, 1)
|
||||
input_ids.narrow(1, end_inputs - 1, 1).apply(embeddings) * self.scale_embedding
|
||||
+ positions.narrow(1, end_positions - 1, 1)
|
||||
} else {
|
||||
input_ids.apply(embeddings) * self.scale_embedding + positions
|
||||
};
|
||||
|
||||
let x = if let Some(layer_norm_embedding) = &self.layer_norm_embedding { x.apply(layer_norm_embedding) } else { x };
|
||||
let mut hidden_state = x
|
||||
.apply_t(&self.dropout, train)
|
||||
.transpose(0, 1);
|
||||
let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states { Some(Vec::with_capacity(self.layers.len())) } else { None };
|
||||
let mut all_attentions: Option<Vec<Tensor>> = if self.output_attentions { Some(Vec::with_capacity(self.layers.len())) } else { None };
|
||||
let mut next_decoder_cache: Option<Vec<(Option<LayerState>, Option<LayerState>)>> = if self.output_past {
|
||||
if old_layer_states.is_some() { old_layer_states } else { Some(vec!((None, None); self.layers.len())) }
|
||||
let x = if let Some(layer_norm_embedding) = &self.layer_norm_embedding {
|
||||
x.apply(layer_norm_embedding)
|
||||
} else {
|
||||
x
|
||||
};
|
||||
let mut hidden_state = x.apply_t(&self.dropout, train).transpose(0, 1);
|
||||
let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states {
|
||||
Some(Vec::with_capacity(self.layers.len()))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let mut all_attentions: Option<Vec<Tensor>> = if self.output_attentions {
|
||||
Some(Vec::with_capacity(self.layers.len()))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let mut next_decoder_cache: Option<Vec<(Option<LayerState>, Option<LayerState>)>> =
|
||||
if self.output_past {
|
||||
if old_layer_states.is_some() {
|
||||
old_layer_states
|
||||
} else {
|
||||
Some(vec![(None, None); self.layers.len()])
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
@ -252,15 +340,17 @@ impl BartDecoder {
|
||||
Some((layer_idx, layer)) => {
|
||||
let layer_state = match &next_decoder_cache {
|
||||
Some(values) => values[layer_idx].to_owned(),
|
||||
None => (None, None)
|
||||
None => (None, None),
|
||||
};
|
||||
let temp = layer.forward_t(&hidden_state,
|
||||
let temp = layer.forward_t(
|
||||
&hidden_state,
|
||||
&encoder_hidden_states,
|
||||
encoder_padding_mask.as_ref(),
|
||||
decoder_causal_mask,
|
||||
decoder_padding_mask,
|
||||
layer_state,
|
||||
train);
|
||||
train,
|
||||
);
|
||||
hidden_state = temp.0;
|
||||
attention_weights = temp.1;
|
||||
if let Some(hidden_states) = all_hidden_states.borrow_mut() {
|
||||
@ -269,15 +359,19 @@ impl BartDecoder {
|
||||
if let Some(attentions) = all_attentions.borrow_mut() {
|
||||
attentions.push(attention_weights.as_ref().unwrap().copy());
|
||||
};
|
||||
if let Some(value) = &mut next_decoder_cache { value[layer_idx] = temp.2 };
|
||||
if let Some(value) = &mut next_decoder_cache {
|
||||
value[layer_idx] = temp.2
|
||||
};
|
||||
}
|
||||
None => break
|
||||
};
|
||||
None => break,
|
||||
};
|
||||
}
|
||||
|
||||
(hidden_state.transpose(0, 1),
|
||||
(
|
||||
hidden_state.transpose(0, 1),
|
||||
(encoder_padding_mask, next_decoder_cache),
|
||||
all_hidden_states,
|
||||
all_attentions)
|
||||
all_attentions,
|
||||
)
|
||||
}
|
||||
}
|
@ -11,10 +11,9 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use tch::{nn, Tensor};
|
||||
use tch::nn::{EmbeddingConfig, embedding};
|
||||
use tch::kind::Kind::Int64;
|
||||
|
||||
use tch::nn::{embedding, EmbeddingConfig};
|
||||
use tch::{nn, Tensor};
|
||||
|
||||
/// # Abstraction that holds a embeddings configuration
|
||||
pub enum EmbeddingOption {
|
||||
@ -27,8 +26,12 @@ impl EmbeddingOption {
|
||||
/// Interface method to forward_t() of the particular models.
|
||||
pub fn forward(&self, input: &Tensor, generation_mode: bool) -> Tensor {
|
||||
match *self {
|
||||
Self::LearnedPositionalEmbedding(ref embeddings) => embeddings.forward(input, generation_mode),
|
||||
Self::SinusoidalPositionalEmbedding(ref embeddings) => embeddings.forward(input, generation_mode)
|
||||
Self::LearnedPositionalEmbedding(ref embeddings) => {
|
||||
embeddings.forward(input, generation_mode)
|
||||
}
|
||||
Self::SinusoidalPositionalEmbedding(ref embeddings) => {
|
||||
embeddings.forward(input, generation_mode)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -40,15 +43,24 @@ pub struct LearnedPositionalEmbedding {
|
||||
}
|
||||
|
||||
impl LearnedPositionalEmbedding {
|
||||
pub fn new(p: nn::Path, num_embeddings: i64, embedding_dim: i64, padding_index: i64) -> LearnedPositionalEmbedding {
|
||||
let embedding_config = EmbeddingConfig { padding_idx: padding_index, ..Default::default() };
|
||||
pub fn new(
|
||||
p: nn::Path,
|
||||
num_embeddings: i64,
|
||||
embedding_dim: i64,
|
||||
padding_index: i64,
|
||||
) -> LearnedPositionalEmbedding {
|
||||
let embedding_config = EmbeddingConfig {
|
||||
padding_idx: padding_index,
|
||||
..Default::default()
|
||||
};
|
||||
let num_embeddings = num_embeddings + padding_index + 1;
|
||||
|
||||
let embedding: nn::Embedding = embedding(p,
|
||||
num_embeddings,
|
||||
embedding_dim,
|
||||
embedding_config);
|
||||
LearnedPositionalEmbedding { embedding, padding_index }
|
||||
let embedding: nn::Embedding =
|
||||
embedding(p, num_embeddings, embedding_dim, embedding_config);
|
||||
LearnedPositionalEmbedding {
|
||||
embedding,
|
||||
padding_index,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn forward(&self, input: &Tensor, generation_mode: bool) -> Tensor {
|
||||
@ -74,11 +86,13 @@ pub struct SinusoidalPositionalEmbedding {
|
||||
}
|
||||
|
||||
impl SinusoidalPositionalEmbedding {
|
||||
pub fn new(p: nn::Path, num_embeddings: i64, embedding_dim: i64) -> SinusoidalPositionalEmbedding {
|
||||
let embedding: nn::Embedding = embedding(p,
|
||||
num_embeddings,
|
||||
embedding_dim,
|
||||
Default::default());
|
||||
pub fn new(
|
||||
p: nn::Path,
|
||||
num_embeddings: i64,
|
||||
embedding_dim: i64,
|
||||
) -> SinusoidalPositionalEmbedding {
|
||||
let embedding: nn::Embedding =
|
||||
embedding(p, num_embeddings, embedding_dim, Default::default());
|
||||
SinusoidalPositionalEmbedding { embedding }
|
||||
}
|
||||
|
||||
|
@ -12,14 +12,16 @@
|
||||
// limitations under the License.
|
||||
|
||||
use crate::bart::attention::SelfAttention;
|
||||
use tch::{nn, Tensor};
|
||||
use crate::common::dropout::Dropout;
|
||||
use crate::bart::BartConfig;
|
||||
use crate::bart::bart::Activation;
|
||||
use crate::common::activations::{_gelu, _relu, _swish, _gelu_new, _tanh};
|
||||
use crate::bart::embeddings::{EmbeddingOption, LearnedPositionalEmbedding, SinusoidalPositionalEmbedding};
|
||||
use tch::kind::Kind::Bool;
|
||||
use crate::bart::embeddings::{
|
||||
EmbeddingOption, LearnedPositionalEmbedding, SinusoidalPositionalEmbedding,
|
||||
};
|
||||
use crate::bart::BartConfig;
|
||||
use crate::common::activations::{_gelu, _gelu_new, _relu, _swish, _tanh};
|
||||
use crate::common::dropout::Dropout;
|
||||
use std::borrow::BorrowMut;
|
||||
use tch::kind::Kind::Bool;
|
||||
use tch::{nn, Tensor};
|
||||
|
||||
pub struct EncoderLayer {
|
||||
self_attention: SelfAttention,
|
||||
@ -34,46 +36,81 @@ pub struct EncoderLayer {
|
||||
|
||||
impl EncoderLayer {
|
||||
pub fn new(p: nn::Path, config: &BartConfig) -> EncoderLayer {
|
||||
let layer_norm_config = nn::LayerNormConfig { eps: 1e-5, ..Default::default() };
|
||||
let layer_norm_config = nn::LayerNormConfig {
|
||||
eps: 1e-5,
|
||||
..Default::default()
|
||||
};
|
||||
let output_attention = match config.output_attentions {
|
||||
Some(value) => value,
|
||||
None => false
|
||||
None => false,
|
||||
};
|
||||
let self_attention = SelfAttention::new(&p / "self_attn",
|
||||
let self_attention = SelfAttention::new(
|
||||
&p / "self_attn",
|
||||
config.d_model,
|
||||
config.encoder_attention_heads,
|
||||
config.attention_dropout,
|
||||
false,
|
||||
false,
|
||||
output_attention);
|
||||
let self_attention_layer_norm = nn::layer_norm(&p / "self_attn_layer_norm",
|
||||
output_attention,
|
||||
);
|
||||
let self_attention_layer_norm = nn::layer_norm(
|
||||
&p / "self_attn_layer_norm",
|
||||
vec![config.d_model],
|
||||
layer_norm_config);
|
||||
layer_norm_config,
|
||||
);
|
||||
let dropout = Dropout::new(config.dropout);
|
||||
let activation_dropout = Dropout::new(config.activation_dropout);
|
||||
let activation_function = match &config.activation_function {
|
||||
Some(act_function) => act_function,
|
||||
None => &Activation::gelu
|
||||
None => &Activation::gelu,
|
||||
};
|
||||
let activation = Box::new(match activation_function {
|
||||
Activation::gelu => _gelu,
|
||||
Activation::relu => _relu,
|
||||
Activation::swish => _swish,
|
||||
Activation::gelu_new => _gelu_new,
|
||||
Activation::tanh => _tanh
|
||||
Activation::tanh => _tanh,
|
||||
});
|
||||
let fc1 = nn::linear(&p / "fc1", config.d_model, config.encoder_ffn_dim, Default::default());
|
||||
let fc2 = nn::linear(&p / "fc2", config.encoder_ffn_dim, config.d_model, Default::default());
|
||||
let fc1 = nn::linear(
|
||||
&p / "fc1",
|
||||
config.d_model,
|
||||
config.encoder_ffn_dim,
|
||||
Default::default(),
|
||||
);
|
||||
let fc2 = nn::linear(
|
||||
&p / "fc2",
|
||||
config.encoder_ffn_dim,
|
||||
config.d_model,
|
||||
Default::default(),
|
||||
);
|
||||
|
||||
let final_layer_norm = nn::layer_norm(&p / "final_layer_norm",
|
||||
let final_layer_norm = nn::layer_norm(
|
||||
&p / "final_layer_norm",
|
||||
vec![config.d_model],
|
||||
layer_norm_config);
|
||||
layer_norm_config,
|
||||
);
|
||||
|
||||
EncoderLayer { self_attention, self_attention_layer_norm, dropout, activation_dropout, activation, fc1, fc2, final_layer_norm }
|
||||
EncoderLayer {
|
||||
self_attention,
|
||||
self_attention_layer_norm,
|
||||
dropout,
|
||||
activation_dropout,
|
||||
activation,
|
||||
fc1,
|
||||
fc2,
|
||||
final_layer_norm,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn forward_t(&self, x: &Tensor, encoder_padding_mask: Option<&Tensor>, train: bool) -> (Tensor, Option<Tensor>) {
|
||||
let (output, attention_weights, _) = self.self_attention.forward_t(x, None, encoder_padding_mask, None, None, train);
|
||||
pub fn forward_t(
|
||||
&self,
|
||||
x: &Tensor,
|
||||
encoder_padding_mask: Option<&Tensor>,
|
||||
train: bool,
|
||||
) -> (Tensor, Option<Tensor>) {
|
||||
let (output, attention_weights, _) =
|
||||
self.self_attention
|
||||
.forward_t(x, None, encoder_padding_mask, None, None, train);
|
||||
let output: Tensor = output.apply_t(&self.dropout, train) + x;
|
||||
let output = output.apply(&self.self_attention_layer_norm);
|
||||
|
||||
@ -102,57 +139,72 @@ impl BartEncoder {
|
||||
pub fn new(p: nn::Path, config: &BartConfig) -> BartEncoder {
|
||||
let output_attentions = match config.output_attentions {
|
||||
Some(value) => value,
|
||||
None => false
|
||||
None => false,
|
||||
};
|
||||
let output_hidden_states = match config.output_hidden_states {
|
||||
Some(value) => value,
|
||||
None => false
|
||||
None => false,
|
||||
};
|
||||
let normalize_embedding = match config.normalize_embedding {
|
||||
Some(value) => value,
|
||||
None => true
|
||||
None => true,
|
||||
};
|
||||
let static_position_embeddings = match config.static_position_embeddings {
|
||||
Some(value) => value,
|
||||
None => false
|
||||
None => false,
|
||||
};
|
||||
let scale_embedding = match config.scale_embedding {
|
||||
Some(value) => if value { (config.d_model as f64).sqrt() } else { 1.0 },
|
||||
None => 1.0
|
||||
Some(value) => {
|
||||
if value {
|
||||
(config.d_model as f64).sqrt()
|
||||
} else {
|
||||
1.0
|
||||
}
|
||||
}
|
||||
None => 1.0,
|
||||
};
|
||||
|
||||
let dropout = Dropout::new(config.dropout);
|
||||
|
||||
let layer_norm_embedding = if normalize_embedding {
|
||||
let layer_norm_config = nn::LayerNormConfig { eps: 1e-5, ..Default::default() };
|
||||
Some(nn::layer_norm(&p / "layernorm_embedding",
|
||||
let layer_norm_config = nn::LayerNormConfig {
|
||||
eps: 1e-5,
|
||||
..Default::default()
|
||||
};
|
||||
Some(nn::layer_norm(
|
||||
&p / "layernorm_embedding",
|
||||
vec![config.d_model],
|
||||
layer_norm_config))
|
||||
layer_norm_config,
|
||||
))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let pad_token_id = match config.pad_token_id {
|
||||
Some(value) => value,
|
||||
None => 1
|
||||
None => 1,
|
||||
};
|
||||
|
||||
let embed_positions = if static_position_embeddings {
|
||||
EmbeddingOption::SinusoidalPositionalEmbedding(SinusoidalPositionalEmbedding::new(&p / "embed_positions",
|
||||
config.max_position_embeddings,
|
||||
config.d_model))
|
||||
} else {
|
||||
EmbeddingOption::LearnedPositionalEmbedding(LearnedPositionalEmbedding::new(&p / "embed_positions",
|
||||
EmbeddingOption::SinusoidalPositionalEmbedding(SinusoidalPositionalEmbedding::new(
|
||||
&p / "embed_positions",
|
||||
config.max_position_embeddings,
|
||||
config.d_model,
|
||||
pad_token_id))
|
||||
))
|
||||
} else {
|
||||
EmbeddingOption::LearnedPositionalEmbedding(LearnedPositionalEmbedding::new(
|
||||
&p / "embed_positions",
|
||||
config.max_position_embeddings,
|
||||
config.d_model,
|
||||
pad_token_id,
|
||||
))
|
||||
};
|
||||
|
||||
let mut layers: Vec<EncoderLayer> = vec!();
|
||||
let mut layers: Vec<EncoderLayer> = vec![];
|
||||
let p_layers = &p / "layers";
|
||||
for layer_index in 0..config.encoder_layers {
|
||||
layers.push(EncoderLayer::new(&p_layers / layer_index, config));
|
||||
};
|
||||
}
|
||||
|
||||
BartEncoder {
|
||||
dropout,
|
||||
@ -165,26 +217,37 @@ impl BartEncoder {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn forward_t(&self,
|
||||
pub fn forward_t(
|
||||
&self,
|
||||
input_ids: &Tensor,
|
||||
attention_mask: Option<&Tensor>,
|
||||
embeddings: &nn::Embedding,
|
||||
train: bool)
|
||||
-> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
|
||||
train: bool,
|
||||
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
|
||||
let attention_mask = match attention_mask {
|
||||
Some(mask) => Some(mask.eq(0).to_kind(Bool)),
|
||||
None => None
|
||||
None => None,
|
||||
};
|
||||
|
||||
let x = input_ids.apply(embeddings) * self.scale_embedding;
|
||||
let x: Tensor = x + &self.embed_positions.forward(input_ids, false);
|
||||
let x = if let Some(layer_norm_embedding) = &self.layer_norm_embedding { x.apply(layer_norm_embedding) } else { x };
|
||||
let x = x
|
||||
.apply_t(&self.dropout, train)
|
||||
.transpose(0, 1);
|
||||
let x = if let Some(layer_norm_embedding) = &self.layer_norm_embedding {
|
||||
x.apply(layer_norm_embedding)
|
||||
} else {
|
||||
x
|
||||
};
|
||||
let x = x.apply_t(&self.dropout, train).transpose(0, 1);
|
||||
|
||||
let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states { Some(vec!()) } else { None };
|
||||
let mut all_attentions: Option<Vec<Tensor>> = if self.output_attentions { Some(vec!()) } else { None };
|
||||
let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states {
|
||||
Some(vec![])
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let mut all_attentions: Option<Vec<Tensor>> = if self.output_attentions {
|
||||
Some(vec![])
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let mut hidden_state = x.copy();
|
||||
let mut attention_weights: Option<Tensor>;
|
||||
@ -204,13 +267,17 @@ impl BartEncoder {
|
||||
attentions.push(attention_weights.as_ref().unwrap().copy());
|
||||
};
|
||||
}
|
||||
None => break
|
||||
};
|
||||
None => break,
|
||||
};
|
||||
}
|
||||
if let Some(hidden_states) = all_hidden_states.borrow_mut() {
|
||||
hidden_states.push(hidden_state.as_ref().copy().transpose(0, 1));
|
||||
};
|
||||
|
||||
(hidden_state.transpose(0, 1), all_hidden_states, all_attentions)
|
||||
(
|
||||
hidden_state.transpose(0, 1),
|
||||
all_hidden_states,
|
||||
all_attentions,
|
||||
)
|
||||
}
|
||||
}
|
@ -20,14 +20,22 @@
|
||||
//! use rust_tokenizers::RobertaTokenizer;
|
||||
//! use tch::{nn, Device};
|
||||
//! # use std::path::PathBuf;
|
||||
//! use rust_bert::Config;
|
||||
//! use rust_bert::bart::{BartConfig, BartModel};
|
||||
//! use rust_bert::resources::{Resource, download_resource, LocalResource};
|
||||
//! use rust_bert::resources::{download_resource, LocalResource, Resource};
|
||||
//! use rust_bert::Config;
|
||||
//!
|
||||
//! let config_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/config.json")});
|
||||
//! let vocab_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/vocab.txt")});
|
||||
//! let merges_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/vocab.txt")});
|
||||
//! let weights_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/model.ot")});
|
||||
//! let config_resource = Resource::Local(LocalResource {
|
||||
//! local_path: PathBuf::from("path/to/config.json"),
|
||||
//! });
|
||||
//! let vocab_resource = Resource::Local(LocalResource {
|
||||
//! local_path: PathBuf::from("path/to/vocab.txt"),
|
||||
//! });
|
||||
//! let merges_resource = Resource::Local(LocalResource {
|
||||
//! local_path: PathBuf::from("path/to/vocab.txt"),
|
||||
//! });
|
||||
//! let weights_resource = Resource::Local(LocalResource {
|
||||
//! local_path: PathBuf::from("path/to/model.ot"),
|
||||
//! });
|
||||
//! let config_path = download_resource(&config_resource)?;
|
||||
//! let vocab_path = download_resource(&vocab_resource)?;
|
||||
//! let merges_path = download_resource(&merges_resource)?;
|
||||
@ -35,7 +43,11 @@
|
||||
//!
|
||||
//! let device = Device::cuda_if_available();
|
||||
//! let mut vs = nn::VarStore::new(device);
|
||||
//! let tokenizer: RobertaTokenizer = RobertaTokenizer::from_file(vocab_path.to_str().unwrap(), merges_path.to_str().unwrap(), true);
|
||||
//! let tokenizer: RobertaTokenizer = RobertaTokenizer::from_file(
|
||||
//! vocab_path.to_str().unwrap(),
|
||||
//! merges_path.to_str().unwrap(),
|
||||
//! true,
|
||||
//! );
|
||||
//! let config = BartConfig::from_file(config_path);
|
||||
//! let bart_model = BartModel::new(&vs.root(), &config, false);
|
||||
//! vs.load(weights_path)?;
|
||||
@ -44,12 +56,15 @@
|
||||
//! # }
|
||||
//! ```
|
||||
|
||||
mod bart;
|
||||
mod attention;
|
||||
mod encoder;
|
||||
mod bart;
|
||||
mod decoder;
|
||||
mod embeddings;
|
||||
mod encoder;
|
||||
|
||||
pub use bart::{BartModelResources, BartConfigResources, BartVocabResources, BartMergesResources,
|
||||
BartConfig, Activation, BartModel, BartForSequenceClassification, BartForConditionalGeneration};
|
||||
pub use attention::LayerState;
|
||||
pub use bart::{
|
||||
Activation, BartConfig, BartConfigResources, BartForConditionalGeneration,
|
||||
BartForSequenceClassification, BartMergesResources, BartModel, BartModelResources,
|
||||
BartVocabResources,
|
||||
};
|
||||
|
@ -11,11 +11,11 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use crate::common::dropout::Dropout;
|
||||
use tch::{nn, Tensor};
|
||||
use tch::kind::Kind::Float;
|
||||
use crate::bert::bert::{Activation, BertConfig};
|
||||
use crate::common::activations::{_gelu, _relu, _mish};
|
||||
use crate::common::activations::{_gelu, _mish, _relu};
|
||||
use crate::common::dropout::Dropout;
|
||||
use tch::kind::Kind::Float;
|
||||
use tch::{nn, Tensor};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct BertSelfAttention {
|
||||
@ -30,17 +30,36 @@ pub struct BertSelfAttention {
|
||||
|
||||
impl BertSelfAttention {
|
||||
pub fn new(p: nn::Path, config: &BertConfig) -> BertSelfAttention {
|
||||
assert_eq!(config.hidden_size % config.num_attention_heads, 0, "Hidden size not a multiple of the number of attention heads");
|
||||
assert_eq!(
|
||||
config.hidden_size % config.num_attention_heads,
|
||||
0,
|
||||
"Hidden size not a multiple of the number of attention heads"
|
||||
);
|
||||
|
||||
let query = nn::linear(&p / "query", config.hidden_size, config.hidden_size, Default::default());
|
||||
let key = nn::linear(&p / "key", config.hidden_size, config.hidden_size, Default::default());
|
||||
let value = nn::linear(&p / "value", config.hidden_size, config.hidden_size, Default::default());
|
||||
let query = nn::linear(
|
||||
&p / "query",
|
||||
config.hidden_size,
|
||||
config.hidden_size,
|
||||
Default::default(),
|
||||
);
|
||||
let key = nn::linear(
|
||||
&p / "key",
|
||||
config.hidden_size,
|
||||
config.hidden_size,
|
||||
Default::default(),
|
||||
);
|
||||
let value = nn::linear(
|
||||
&p / "value",
|
||||
config.hidden_size,
|
||||
config.hidden_size,
|
||||
Default::default(),
|
||||
);
|
||||
|
||||
let dropout = Dropout::new(config.attention_probs_dropout_prob);
|
||||
let attention_head_size = config.hidden_size / config.num_attention_heads;
|
||||
let output_attentions = match config.output_attentions {
|
||||
Some(value) => value,
|
||||
None => false
|
||||
None => false,
|
||||
};
|
||||
|
||||
BertSelfAttention {
|
||||
@ -55,35 +74,44 @@ impl BertSelfAttention {
|
||||
}
|
||||
|
||||
fn split_heads(&self, x: Tensor, bs: i64, dim_per_head: i64) -> Tensor {
|
||||
x.view((bs, -1, self.num_attention_heads, dim_per_head)).transpose(1, 2)
|
||||
x.view((bs, -1, self.num_attention_heads, dim_per_head))
|
||||
.transpose(1, 2)
|
||||
}
|
||||
|
||||
fn flatten(&self, x: Tensor, bs: i64, dim_per_head: i64) -> Tensor {
|
||||
x.transpose(1, 2).contiguous().view((bs, -1, &self.num_attention_heads * dim_per_head))
|
||||
x.transpose(1, 2)
|
||||
.contiguous()
|
||||
.view((bs, -1, &self.num_attention_heads * dim_per_head))
|
||||
}
|
||||
|
||||
pub fn forward_t(&self,
|
||||
pub fn forward_t(
|
||||
&self,
|
||||
hidden_states: &Tensor,
|
||||
mask: &Option<Tensor>,
|
||||
encoder_hidden_states: &Option<Tensor>,
|
||||
encoder_mask: &Option<Tensor>,
|
||||
train: bool) -> (Tensor, Option<Tensor>) {
|
||||
train: bool,
|
||||
) -> (Tensor, Option<Tensor>) {
|
||||
let (key_layer, value_layer, mask) = match encoder_hidden_states {
|
||||
Some(encoder_hidden_state_values) => {
|
||||
(encoder_hidden_state_values.apply(&self.key),
|
||||
Some(encoder_hidden_state_values) => (
|
||||
encoder_hidden_state_values.apply(&self.key),
|
||||
encoder_hidden_state_values.apply(&self.value),
|
||||
encoder_mask)
|
||||
}
|
||||
None => {
|
||||
(hidden_states.apply(&self.key),
|
||||
encoder_mask,
|
||||
),
|
||||
None => (
|
||||
hidden_states.apply(&self.key),
|
||||
hidden_states.apply(&self.value),
|
||||
mask)
|
||||
}
|
||||
mask,
|
||||
),
|
||||
};
|
||||
|
||||
let bs = hidden_states.size()[0];
|
||||
|
||||
let query_layer = self.split_heads(hidden_states.apply(&self.query), bs, self.attention_head_size);
|
||||
let query_layer = self.split_heads(
|
||||
hidden_states.apply(&self.query),
|
||||
bs,
|
||||
self.attention_head_size,
|
||||
);
|
||||
let key_layer = self.split_heads(key_layer, bs, self.attention_head_size);
|
||||
let value_layer = self.split_heads(value_layer, bs, self.attention_head_size);
|
||||
let query_layer: Tensor = query_layer / (self.attention_head_size as f64).sqrt();
|
||||
@ -114,16 +142,32 @@ pub struct BertSelfOutput {
|
||||
|
||||
impl BertSelfOutput {
|
||||
pub fn new(p: &nn::Path, config: &BertConfig) -> BertSelfOutput {
|
||||
let linear = nn::linear(p / "dense", config.hidden_size, config.hidden_size, Default::default());
|
||||
let layer_norm_config = nn::LayerNormConfig { eps: 1e-12, ..Default::default() };
|
||||
let layer_norm = nn::layer_norm(p / "LayerNorm", vec![config.hidden_size], layer_norm_config);
|
||||
let linear = nn::linear(
|
||||
p / "dense",
|
||||
config.hidden_size,
|
||||
config.hidden_size,
|
||||
Default::default(),
|
||||
);
|
||||
let layer_norm_config = nn::LayerNormConfig {
|
||||
eps: 1e-12,
|
||||
..Default::default()
|
||||
};
|
||||
let layer_norm =
|
||||
nn::layer_norm(p / "LayerNorm", vec![config.hidden_size], layer_norm_config);
|
||||
let dropout = Dropout::new(config.hidden_dropout_prob);
|
||||
|
||||
BertSelfOutput { linear, layer_norm, dropout }
|
||||
BertSelfOutput {
|
||||
linear,
|
||||
layer_norm,
|
||||
dropout,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn forward_t(&self, hidden_states: &Tensor, input_tensor: &Tensor, train: bool) -> Tensor {
|
||||
let hidden_states: Tensor = input_tensor + hidden_states.apply(&self.linear).apply_t(&self.dropout, train);
|
||||
let hidden_states: Tensor = input_tensor
|
||||
+ hidden_states
|
||||
.apply(&self.linear)
|
||||
.apply_t(&self.dropout, train);
|
||||
hidden_states.apply(&self.layer_norm)
|
||||
}
|
||||
}
|
||||
@ -141,14 +185,21 @@ impl BertAttention {
|
||||
BertAttention { _self, output }
|
||||
}
|
||||
|
||||
pub fn forward_t(&self,
|
||||
pub fn forward_t(
|
||||
&self,
|
||||
hidden_states: &Tensor,
|
||||
mask: &Option<Tensor>,
|
||||
encoder_hidden_states: &Option<Tensor>,
|
||||
encoder_mask: &Option<Tensor>,
|
||||
train: bool) -> (Tensor, Option<Tensor>) {
|
||||
let (self_output, attention_weights) = self._self.
|
||||
forward_t(hidden_states, mask, encoder_hidden_states, encoder_mask, train);
|
||||
train: bool,
|
||||
) -> (Tensor, Option<Tensor>) {
|
||||
let (self_output, attention_weights) = self._self.forward_t(
|
||||
hidden_states,
|
||||
mask,
|
||||
encoder_hidden_states,
|
||||
encoder_mask,
|
||||
train,
|
||||
);
|
||||
|
||||
let self_output = self.output.forward_t(&self_output, hidden_states, train);
|
||||
(self_output, attention_weights)
|
||||
@ -162,11 +213,16 @@ pub struct BertIntermediate {
|
||||
|
||||
impl BertIntermediate {
|
||||
pub fn new(p: &nn::Path, config: &BertConfig) -> BertIntermediate {
|
||||
let lin = nn::linear(p / "dense", config.hidden_size, config.intermediate_size, Default::default());
|
||||
let lin = nn::linear(
|
||||
p / "dense",
|
||||
config.hidden_size,
|
||||
config.intermediate_size,
|
||||
Default::default(),
|
||||
);
|
||||
let activation = Box::new(match &config.hidden_act {
|
||||
Activation::gelu => _gelu,
|
||||
Activation::relu => _relu,
|
||||
Activation::mish => _mish
|
||||
Activation::mish => _mish,
|
||||
});
|
||||
BertIntermediate { lin, activation }
|
||||
}
|
||||
@ -184,17 +240,30 @@ pub struct BertOutput {
|
||||
|
||||
impl BertOutput {
|
||||
pub fn new(p: &nn::Path, config: &BertConfig) -> BertOutput {
|
||||
let lin = nn::linear(p / "dense", config.intermediate_size, config.hidden_size, Default::default());
|
||||
let layer_norm_config = nn::LayerNormConfig { eps: 1e-12, ..Default::default() };
|
||||
let layer_norm = nn::layer_norm(p / "LayerNorm", vec![config.hidden_size], layer_norm_config);
|
||||
let lin = nn::linear(
|
||||
p / "dense",
|
||||
config.intermediate_size,
|
||||
config.hidden_size,
|
||||
Default::default(),
|
||||
);
|
||||
let layer_norm_config = nn::LayerNormConfig {
|
||||
eps: 1e-12,
|
||||
..Default::default()
|
||||
};
|
||||
let layer_norm =
|
||||
nn::layer_norm(p / "LayerNorm", vec![config.hidden_size], layer_norm_config);
|
||||
let dropout = Dropout::new(config.hidden_dropout_prob);
|
||||
|
||||
BertOutput { lin, layer_norm, dropout }
|
||||
BertOutput {
|
||||
lin,
|
||||
layer_norm,
|
||||
dropout,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn forward_t(&self, hidden_states: &Tensor, input_tensor: &Tensor, train: bool) -> Tensor {
|
||||
let hidden_states: Tensor = input_tensor + hidden_states.apply(&self.lin).apply_t(&self.dropout, train);
|
||||
let hidden_states: Tensor =
|
||||
input_tensor + hidden_states.apply(&self.lin).apply_t(&self.dropout, train);
|
||||
hidden_states.apply(&self.layer_norm)
|
||||
}
|
||||
}
|
||||
|
||||
|
437
src/bert/bert.rs
437
src/bert/bert.rs
@ -11,17 +11,17 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use crate::bert::embeddings::{BertEmbeddings, BertEmbedding};
|
||||
use crate::bert::embeddings::{BertEmbedding, BertEmbeddings};
|
||||
use crate::bert::encoder::{BertEncoder, BertPooler};
|
||||
use tch::{nn, Tensor, Kind};
|
||||
use tch::kind::Kind::Float;
|
||||
use crate::common::activations::{_gelu, _relu, _mish};
|
||||
use crate::common::linear::{LinearNoBias, linear_no_bias};
|
||||
use tch::nn::Init;
|
||||
use crate::common::activations::{_gelu, _mish, _relu};
|
||||
use crate::common::dropout::Dropout;
|
||||
use std::collections::HashMap;
|
||||
use crate::common::linear::{linear_no_bias, LinearNoBias};
|
||||
use crate::Config;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use tch::kind::Kind::Float;
|
||||
use tch::nn::Init;
|
||||
use tch::{nn, Kind, Tensor};
|
||||
|
||||
/// # BERT Pretrained model weight files
|
||||
pub struct BertModelResources;
|
||||
@ -34,23 +34,41 @@ pub struct BertVocabResources;
|
||||
|
||||
impl BertModelResources {
|
||||
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/bert. Modified with conversion to C-array format.
|
||||
pub const BERT: (&'static str, &'static str) = ("bert/model.ot", "https://cdn.huggingface.co/bert-base-uncased-rust_model.ot");
|
||||
pub const BERT: (&'static str, &'static str) = (
|
||||
"bert/model.ot",
|
||||
"https://cdn.huggingface.co/bert-base-uncased-rust_model.ot",
|
||||
);
|
||||
/// Shared under MIT license by the MDZ Digital Library team at the Bavarian State Library at https://github.com/dbmdz/berts. Modified with conversion to C-array format.
|
||||
pub const BERT_NER: (&'static str, &'static str) = ("bert-ner/model.ot", "https://cdn.huggingface.co/dbmdz/bert-large-cased-finetuned-conll03-english/rust_model.ot");
|
||||
pub const BERT_NER: (&'static str, &'static str) = (
|
||||
"bert-ner/model.ot",
|
||||
"https://cdn.huggingface.co/dbmdz/bert-large-cased-finetuned-conll03-english/rust_model.ot",
|
||||
);
|
||||
}
|
||||
|
||||
impl BertConfigResources {
|
||||
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/bert. Modified with conversion to C-array format.
|
||||
pub const BERT: (&'static str, &'static str) = ("bert/config.json", "https://cdn.huggingface.co/bert-base-uncased-config.json");
|
||||
pub const BERT: (&'static str, &'static str) = (
|
||||
"bert/config.json",
|
||||
"https://cdn.huggingface.co/bert-base-uncased-config.json",
|
||||
);
|
||||
/// Shared under MIT license by the MDZ Digital Library team at the Bavarian State Library at https://github.com/dbmdz/berts. Modified with conversion to C-array format.
|
||||
pub const BERT_NER: (&'static str, &'static str) = ("bert-ner/config.json", "https://cdn.huggingface.co/dbmdz/bert-large-cased-finetuned-conll03-english/config.json");
|
||||
pub const BERT_NER: (&'static str, &'static str) = (
|
||||
"bert-ner/config.json",
|
||||
"https://cdn.huggingface.co/dbmdz/bert-large-cased-finetuned-conll03-english/config.json",
|
||||
);
|
||||
}
|
||||
|
||||
impl BertVocabResources {
|
||||
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/bert. Modified with conversion to C-array format.
|
||||
pub const BERT: (&'static str, &'static str) = ("bert/vocab.txt", "https://cdn.huggingface.co/bert-base-uncased-vocab.txt");
|
||||
pub const BERT: (&'static str, &'static str) = (
|
||||
"bert/vocab.txt",
|
||||
"https://cdn.huggingface.co/bert-base-uncased-vocab.txt",
|
||||
);
|
||||
/// Shared under MIT license by the MDZ Digital Library team at the Bavarian State Library at https://github.com/dbmdz/berts. Modified with conversion to C-array format.
|
||||
pub const BERT_NER: (&'static str, &'static str) = ("bert-ner/vocab.txt", "https://cdn.huggingface.co/dbmdz/bert-large-cased-finetuned-conll03-english/vocab.txt");
|
||||
pub const BERT_NER: (&'static str, &'static str) = (
|
||||
"bert-ner/vocab.txt",
|
||||
"https://cdn.huggingface.co/dbmdz/bert-large-cased-finetuned-conll03-english/vocab.txt",
|
||||
);
|
||||
}
|
||||
|
||||
#[allow(non_camel_case_types)]
|
||||
@ -117,10 +135,10 @@ impl<T: BertEmbedding> BertModel<T> {
|
||||
/// # Example
|
||||
///
|
||||
/// ```no_run
|
||||
/// use rust_bert::bert::{BertModel, BertConfig, BertEmbeddings};
|
||||
/// use tch::{nn, Device};
|
||||
/// use rust_bert::bert::{BertConfig, BertEmbeddings, BertModel};
|
||||
/// use rust_bert::Config;
|
||||
/// use std::path::Path;
|
||||
/// use tch::{nn, Device};
|
||||
///
|
||||
/// let config_path = Path::new("path/to/config.json");
|
||||
/// let device = Device::Cpu;
|
||||
@ -128,17 +146,21 @@ impl<T: BertEmbedding> BertModel<T> {
|
||||
/// let config = BertConfig::from_file(config_path);
|
||||
/// let bert: BertModel<BertEmbeddings> = BertModel::new(&(&p.root() / "bert"), &config);
|
||||
/// ```
|
||||
///
|
||||
pub fn new(p: &nn::Path, config: &BertConfig) -> BertModel<T> {
|
||||
let is_decoder = match config.is_decoder {
|
||||
Some(value) => value,
|
||||
None => false
|
||||
None => false,
|
||||
};
|
||||
let embeddings = T::new(&(p / "embeddings"), config);
|
||||
let encoder = BertEncoder::new(&(p / "encoder"), config);
|
||||
let pooler = BertPooler::new(&(p / "pooler"), config);
|
||||
|
||||
BertModel { embeddings, encoder, pooler, is_decoder }
|
||||
BertModel {
|
||||
embeddings,
|
||||
encoder,
|
||||
pooler,
|
||||
is_decoder,
|
||||
}
|
||||
}
|
||||
|
||||
/// Forward pass through the model
|
||||
@ -178,23 +200,26 @@ impl<T: BertEmbedding> BertModel<T> {
|
||||
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
|
||||
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
||||
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
||||
/// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true);
|
||||
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
|
||||
/// .expand(&[batch_size, sequence_length], true);
|
||||
///
|
||||
/// let (output, pooled_output, all_hidden_states, all_attentions) = no_grad(|| {
|
||||
/// bert_model
|
||||
/// .forward_t(Some(input_tensor),
|
||||
/// .forward_t(
|
||||
/// Some(input_tensor),
|
||||
/// Some(mask),
|
||||
/// Some(token_type_ids),
|
||||
/// Some(position_ids),
|
||||
/// None,
|
||||
/// &None,
|
||||
/// &None,
|
||||
/// false).unwrap()
|
||||
/// false,
|
||||
/// )
|
||||
/// .unwrap()
|
||||
/// });
|
||||
///
|
||||
/// ```
|
||||
///
|
||||
pub fn forward_t(&self,
|
||||
pub fn forward_t(
|
||||
&self,
|
||||
input_ids: Option<Tensor>,
|
||||
mask: Option<Tensor>,
|
||||
token_type_ids: Option<Tensor>,
|
||||
@ -202,74 +227,109 @@ impl<T: BertEmbedding> BertModel<T> {
|
||||
input_embeds: Option<Tensor>,
|
||||
encoder_hidden_states: &Option<Tensor>,
|
||||
encoder_mask: &Option<Tensor>,
|
||||
train: bool)
|
||||
-> Result<(Tensor, Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
|
||||
train: bool,
|
||||
) -> Result<(Tensor, Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
|
||||
let (input_shape, device) = match &input_ids {
|
||||
Some(input_value) => match &input_embeds {
|
||||
Some(_) => { return Err("Only one of input ids or input embeddings may be set"); }
|
||||
None => (input_value.size(), input_value.device())
|
||||
Some(_) => {
|
||||
return Err("Only one of input ids or input embeddings may be set");
|
||||
}
|
||||
None => (input_value.size(), input_value.device()),
|
||||
},
|
||||
None => match &input_embeds {
|
||||
Some(embeds) => (vec!(embeds.size()[0], embeds.size()[1]), embeds.device()),
|
||||
None => { return Err("At least one of input ids or input embeddings must be set"); }
|
||||
Some(embeds) => (vec![embeds.size()[0], embeds.size()[1]], embeds.device()),
|
||||
None => {
|
||||
return Err("At least one of input ids or input embeddings must be set");
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
let mask = match mask {
|
||||
Some(value) => value,
|
||||
None => Tensor::ones(&input_shape, (Kind::Int64, device))
|
||||
None => Tensor::ones(&input_shape, (Kind::Int64, device)),
|
||||
};
|
||||
|
||||
let extended_attention_mask = match mask.dim() {
|
||||
3 => mask.unsqueeze(1),
|
||||
2 => if self.is_decoder {
|
||||
2 => {
|
||||
if self.is_decoder {
|
||||
let seq_ids = Tensor::arange(input_shape[1], (Float, device));
|
||||
let causal_mask = seq_ids.unsqueeze(0).unsqueeze(0).repeat(&vec!(input_shape[0], input_shape[1], 1));
|
||||
let causal_mask = seq_ids.unsqueeze(0).unsqueeze(0).repeat(&vec![
|
||||
input_shape[0],
|
||||
input_shape[1],
|
||||
1,
|
||||
]);
|
||||
let causal_mask = causal_mask.le1(&seq_ids.unsqueeze(0).unsqueeze(-1));
|
||||
causal_mask * mask.unsqueeze(1).unsqueeze(1)
|
||||
} else {
|
||||
mask.unsqueeze(1).unsqueeze(1)
|
||||
},
|
||||
_ => { return Err("Invalid attention mask dimension, must be 2 or 3"); }
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
return Err("Invalid attention mask dimension, must be 2 or 3");
|
||||
}
|
||||
};
|
||||
|
||||
let extended_attention_mask: Tensor = (extended_attention_mask.ones_like() - extended_attention_mask) * -10000.0;
|
||||
let extended_attention_mask: Tensor =
|
||||
(extended_attention_mask.ones_like() - extended_attention_mask) * -10000.0;
|
||||
|
||||
let encoder_extended_attention_mask: Option<Tensor> = if self.is_decoder & encoder_hidden_states.is_some() {
|
||||
let encoder_extended_attention_mask: Option<Tensor> =
|
||||
if self.is_decoder & encoder_hidden_states.is_some() {
|
||||
let encoder_hidden_states = encoder_hidden_states.as_ref().unwrap();
|
||||
let encoder_hidden_states_shape = encoder_hidden_states.size();
|
||||
let encoder_mask = match encoder_mask {
|
||||
Some(value) => value.copy(),
|
||||
None => Tensor::ones(&[encoder_hidden_states_shape[0], encoder_hidden_states_shape[1]], (Kind::Int64, device))
|
||||
None => Tensor::ones(
|
||||
&[
|
||||
encoder_hidden_states_shape[0],
|
||||
encoder_hidden_states_shape[1],
|
||||
],
|
||||
(Kind::Int64, device),
|
||||
),
|
||||
};
|
||||
match encoder_mask.dim() {
|
||||
2 => Some(encoder_mask.unsqueeze(1).unsqueeze(1)),
|
||||
3 => Some(encoder_mask.unsqueeze(1)),
|
||||
_ => { return Err("Invalid encoder attention mask dimension, must be 2 or 3"); }
|
||||
_ => {
|
||||
return Err("Invalid encoder attention mask dimension, must be 2 or 3");
|
||||
}
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let embedding_output = match self.embeddings.forward_t(input_ids, token_type_ids, position_ids, input_embeds, train) {
|
||||
let embedding_output = match self.embeddings.forward_t(
|
||||
input_ids,
|
||||
token_type_ids,
|
||||
position_ids,
|
||||
input_embeds,
|
||||
train,
|
||||
) {
|
||||
Ok(value) => value,
|
||||
Err(e) => { return Err(e); }
|
||||
Err(e) => {
|
||||
return Err(e);
|
||||
}
|
||||
};
|
||||
|
||||
let (hidden_state, all_hidden_states, all_attentions) =
|
||||
self.encoder.forward_t(&embedding_output,
|
||||
let (hidden_state, all_hidden_states, all_attentions) = self.encoder.forward_t(
|
||||
&embedding_output,
|
||||
&Some(extended_attention_mask),
|
||||
encoder_hidden_states,
|
||||
&encoder_extended_attention_mask,
|
||||
train);
|
||||
train,
|
||||
);
|
||||
|
||||
let pooled_output = self.pooler.forward(&hidden_state);
|
||||
|
||||
Ok((hidden_state, pooled_output, all_hidden_states, all_attentions))
|
||||
Ok((
|
||||
hidden_state,
|
||||
pooled_output,
|
||||
all_hidden_states,
|
||||
all_attentions,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
pub struct BertPredictionHeadTransform {
|
||||
dense: nn::Linear,
|
||||
activation: Box<dyn Fn(&Tensor) -> Tensor>,
|
||||
@ -278,16 +338,29 @@ pub struct BertPredictionHeadTransform {
|
||||
|
||||
impl BertPredictionHeadTransform {
|
||||
pub fn new(p: &nn::Path, config: &BertConfig) -> BertPredictionHeadTransform {
|
||||
let dense = nn::linear(p / "dense", config.hidden_size, config.hidden_size, Default::default());
|
||||
let dense = nn::linear(
|
||||
p / "dense",
|
||||
config.hidden_size,
|
||||
config.hidden_size,
|
||||
Default::default(),
|
||||
);
|
||||
let activation = Box::new(match &config.hidden_act {
|
||||
Activation::gelu => _gelu,
|
||||
Activation::relu => _relu,
|
||||
Activation::mish => _mish
|
||||
Activation::mish => _mish,
|
||||
});
|
||||
let layer_norm_config = nn::LayerNormConfig { eps: 1e-12, ..Default::default() };
|
||||
let layer_norm = nn::layer_norm(p / "LayerNorm", vec![config.hidden_size], layer_norm_config);
|
||||
let layer_norm_config = nn::LayerNormConfig {
|
||||
eps: 1e-12,
|
||||
..Default::default()
|
||||
};
|
||||
let layer_norm =
|
||||
nn::layer_norm(p / "LayerNorm", vec![config.hidden_size], layer_norm_config);
|
||||
|
||||
BertPredictionHeadTransform { dense, activation, layer_norm }
|
||||
BertPredictionHeadTransform {
|
||||
dense,
|
||||
activation,
|
||||
layer_norm,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn forward(&self, hidden_states: &Tensor) -> Tensor {
|
||||
@ -305,10 +378,19 @@ impl BertLMPredictionHead {
|
||||
pub fn new(p: &nn::Path, config: &BertConfig) -> BertLMPredictionHead {
|
||||
let p = &(p / "predictions");
|
||||
let transform = BertPredictionHeadTransform::new(&(p / "transform"), config);
|
||||
let decoder = linear_no_bias(&(p / "decoder"), config.hidden_size, config.vocab_size, Default::default());
|
||||
let decoder = linear_no_bias(
|
||||
&(p / "decoder"),
|
||||
config.hidden_size,
|
||||
config.vocab_size,
|
||||
Default::default(),
|
||||
);
|
||||
let bias = p.var("bias", &[config.vocab_size], Init::KaimingUniform);
|
||||
|
||||
BertLMPredictionHead { transform, decoder, bias }
|
||||
BertLMPredictionHead {
|
||||
transform,
|
||||
decoder,
|
||||
bias,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn forward(&self, hidden_states: &Tensor) -> Tensor {
|
||||
@ -338,9 +420,9 @@ impl BertForMaskedLM {
|
||||
///
|
||||
/// ```no_run
|
||||
/// use rust_bert::bert::{BertConfig, BertForMaskedLM};
|
||||
/// use tch::{nn, Device};
|
||||
/// use rust_bert::Config;
|
||||
/// use std::path::Path;
|
||||
/// use tch::{nn, Device};
|
||||
///
|
||||
/// let config_path = Path::new("path/to/config.json");
|
||||
/// let device = Device::Cpu;
|
||||
@ -348,7 +430,6 @@ impl BertForMaskedLM {
|
||||
/// let config = BertConfig::from_file(config_path);
|
||||
/// let bert = BertForMaskedLM::new(&(&p.root() / "bert"), &config);
|
||||
/// ```
|
||||
///
|
||||
pub fn new(p: &nn::Path, config: &BertConfig) -> BertForMaskedLM {
|
||||
let bert = BertModel::new(&(p / "bert"), config);
|
||||
let cls = BertLMPredictionHead::new(&(p / "cls"), config);
|
||||
@ -392,23 +473,24 @@ impl BertForMaskedLM {
|
||||
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
|
||||
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
||||
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
||||
/// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true);
|
||||
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
|
||||
/// .expand(&[batch_size, sequence_length], true);
|
||||
///
|
||||
/// let (output, all_hidden_states, all_attentions) = no_grad(|| {
|
||||
/// bert_model
|
||||
/// .forward_t(Some(input_tensor),
|
||||
/// bert_model.forward_t(
|
||||
/// Some(input_tensor),
|
||||
/// Some(mask),
|
||||
/// Some(token_type_ids),
|
||||
/// Some(position_ids),
|
||||
/// None,
|
||||
/// &None,
|
||||
/// &None,
|
||||
/// false)
|
||||
/// false,
|
||||
/// )
|
||||
/// });
|
||||
///
|
||||
/// ```
|
||||
///
|
||||
pub fn forward_t(&self,
|
||||
pub fn forward_t(
|
||||
&self,
|
||||
input_ids: Option<Tensor>,
|
||||
mask: Option<Tensor>,
|
||||
token_type_ids: Option<Tensor>,
|
||||
@ -416,9 +498,21 @@ impl BertForMaskedLM {
|
||||
input_embeds: Option<Tensor>,
|
||||
encoder_hidden_states: &Option<Tensor>,
|
||||
encoder_mask: &Option<Tensor>,
|
||||
train: bool) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
|
||||
let (hidden_state, _, all_hidden_states, all_attentions) = self.bert.forward_t(input_ids, mask, token_type_ids, position_ids,
|
||||
input_embeds, encoder_hidden_states, encoder_mask, train).unwrap();
|
||||
train: bool,
|
||||
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
|
||||
let (hidden_state, _, all_hidden_states, all_attentions) = self
|
||||
.bert
|
||||
.forward_t(
|
||||
input_ids,
|
||||
mask,
|
||||
token_type_ids,
|
||||
position_ids,
|
||||
input_embeds,
|
||||
encoder_hidden_states,
|
||||
encoder_mask,
|
||||
train,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let prediction_scores = self.cls.forward(&hidden_state);
|
||||
(prediction_scores, all_hidden_states, all_attentions)
|
||||
@ -448,9 +542,9 @@ impl BertForSequenceClassification {
|
||||
///
|
||||
/// ```no_run
|
||||
/// use rust_bert::bert::{BertConfig, BertForSequenceClassification};
|
||||
/// use tch::{nn, Device};
|
||||
/// use rust_bert::Config;
|
||||
/// use std::path::Path;
|
||||
/// use tch::{nn, Device};
|
||||
///
|
||||
/// let config_path = Path::new("path/to/config.json");
|
||||
/// let device = Device::Cpu;
|
||||
@ -461,10 +555,23 @@ impl BertForSequenceClassification {
|
||||
pub fn new(p: &nn::Path, config: &BertConfig) -> BertForSequenceClassification {
|
||||
let bert = BertModel::new(&(p / "bert"), config);
|
||||
let dropout = Dropout::new(config.hidden_dropout_prob);
|
||||
let num_labels = config.id2label.as_ref().expect("num_labels not provided in configuration").len() as i64;
|
||||
let classifier = nn::linear(p / "classifier", config.hidden_size, num_labels, Default::default());
|
||||
let num_labels = config
|
||||
.id2label
|
||||
.as_ref()
|
||||
.expect("num_labels not provided in configuration")
|
||||
.len() as i64;
|
||||
let classifier = nn::linear(
|
||||
p / "classifier",
|
||||
config.hidden_size,
|
||||
num_labels,
|
||||
Default::default(),
|
||||
);
|
||||
|
||||
BertForSequenceClassification { bert, dropout, classifier }
|
||||
BertForSequenceClassification {
|
||||
bert,
|
||||
dropout,
|
||||
classifier,
|
||||
}
|
||||
}
|
||||
|
||||
/// Forward pass through the model
|
||||
@ -501,31 +608,46 @@ impl BertForSequenceClassification {
|
||||
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
|
||||
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
||||
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
||||
/// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true);
|
||||
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
|
||||
/// .expand(&[batch_size, sequence_length], true);
|
||||
///
|
||||
/// let (labels, all_hidden_states, all_attentions) = no_grad(|| {
|
||||
/// bert_model
|
||||
/// .forward_t(Some(input_tensor),
|
||||
/// bert_model.forward_t(
|
||||
/// Some(input_tensor),
|
||||
/// Some(mask),
|
||||
/// Some(token_type_ids),
|
||||
/// Some(position_ids),
|
||||
/// None,
|
||||
/// false)
|
||||
/// false,
|
||||
/// )
|
||||
/// });
|
||||
///
|
||||
/// ```
|
||||
///
|
||||
pub fn forward_t(&self,
|
||||
pub fn forward_t(
|
||||
&self,
|
||||
input_ids: Option<Tensor>,
|
||||
mask: Option<Tensor>,
|
||||
token_type_ids: Option<Tensor>,
|
||||
position_ids: Option<Tensor>,
|
||||
input_embeds: Option<Tensor>,
|
||||
train: bool) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
|
||||
let (_, pooled_output, all_hidden_states, all_attentions) = self.bert.forward_t(input_ids, mask, token_type_ids, position_ids,
|
||||
input_embeds, &None, &None, train).unwrap();
|
||||
train: bool,
|
||||
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
|
||||
let (_, pooled_output, all_hidden_states, all_attentions) = self
|
||||
.bert
|
||||
.forward_t(
|
||||
input_ids,
|
||||
mask,
|
||||
token_type_ids,
|
||||
position_ids,
|
||||
input_embeds,
|
||||
&None,
|
||||
&None,
|
||||
train,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let output = pooled_output.apply_t(&self.dropout, train).apply(&self.classifier);
|
||||
let output = pooled_output
|
||||
.apply_t(&self.dropout, train)
|
||||
.apply(&self.classifier);
|
||||
(output, all_hidden_states, all_attentions)
|
||||
}
|
||||
}
|
||||
@ -555,9 +677,9 @@ impl BertForMultipleChoice {
|
||||
///
|
||||
/// ```no_run
|
||||
/// use rust_bert::bert::{BertConfig, BertForMultipleChoice};
|
||||
/// use tch::{nn, Device};
|
||||
/// use rust_bert::Config;
|
||||
/// use std::path::Path;
|
||||
/// use tch::{nn, Device};
|
||||
///
|
||||
/// let config_path = Path::new("path/to/config.json");
|
||||
/// let device = Device::Cpu;
|
||||
@ -570,7 +692,11 @@ impl BertForMultipleChoice {
|
||||
let dropout = Dropout::new(config.hidden_dropout_prob);
|
||||
let classifier = nn::linear(p / "classifier", config.hidden_size, 1, Default::default());
|
||||
|
||||
BertForMultipleChoice { bert, dropout, classifier }
|
||||
BertForMultipleChoice {
|
||||
bert,
|
||||
dropout,
|
||||
classifier,
|
||||
}
|
||||
}
|
||||
|
||||
/// Forward pass through the model
|
||||
@ -606,46 +732,61 @@ impl BertForMultipleChoice {
|
||||
/// let input_tensor = Tensor::rand(&[num_choices, sequence_length], (Int64, device));
|
||||
/// let mask = Tensor::zeros(&[num_choices, sequence_length], (Int64, device));
|
||||
/// let token_type_ids = Tensor::zeros(&[num_choices, sequence_length], (Int64, device));
|
||||
/// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[num_choices, sequence_length], true);
|
||||
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
|
||||
/// .expand(&[num_choices, sequence_length], true);
|
||||
///
|
||||
/// let (choices, all_hidden_states, all_attentions) = no_grad(|| {
|
||||
/// bert_model
|
||||
/// .forward_t(input_tensor,
|
||||
/// bert_model.forward_t(
|
||||
/// input_tensor,
|
||||
/// Some(mask),
|
||||
/// Some(token_type_ids),
|
||||
/// Some(position_ids),
|
||||
/// false)
|
||||
/// false,
|
||||
/// )
|
||||
/// });
|
||||
///
|
||||
/// ```
|
||||
///
|
||||
pub fn forward_t(&self,
|
||||
pub fn forward_t(
|
||||
&self,
|
||||
input_ids: Tensor,
|
||||
mask: Option<Tensor>,
|
||||
token_type_ids: Option<Tensor>,
|
||||
position_ids: Option<Tensor>,
|
||||
train: bool) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
|
||||
train: bool,
|
||||
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
|
||||
let num_choices = input_ids.size()[1];
|
||||
|
||||
let input_ids = input_ids.view((-1, *input_ids.size().last().unwrap()));
|
||||
let mask = match mask {
|
||||
Some(value) => Some(value.view((-1, *value.size().last().unwrap()))),
|
||||
None => None
|
||||
None => None,
|
||||
};
|
||||
let token_type_ids = match token_type_ids {
|
||||
Some(value) => Some(value.view((-1, *value.size().last().unwrap()))),
|
||||
None => None
|
||||
None => None,
|
||||
};
|
||||
let position_ids = match position_ids {
|
||||
Some(value) => Some(value.view((-1, *value.size().last().unwrap()))),
|
||||
None => None
|
||||
None => None,
|
||||
};
|
||||
|
||||
let (_, pooled_output, all_hidden_states, all_attentions) = self
|
||||
.bert
|
||||
.forward_t(
|
||||
Some(input_ids),
|
||||
mask,
|
||||
token_type_ids,
|
||||
position_ids,
|
||||
None,
|
||||
&None,
|
||||
&None,
|
||||
train,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let (_, pooled_output, all_hidden_states, all_attentions) = self.bert.forward_t(Some(input_ids), mask, token_type_ids, position_ids,
|
||||
None, &None, &None, train).unwrap();
|
||||
|
||||
let output = pooled_output.apply_t(&self.dropout, train).apply(&self.classifier).view((-1, num_choices));
|
||||
let output = pooled_output
|
||||
.apply_t(&self.dropout, train)
|
||||
.apply(&self.classifier)
|
||||
.view((-1, num_choices));
|
||||
(output, all_hidden_states, all_attentions)
|
||||
}
|
||||
}
|
||||
@ -674,9 +815,9 @@ impl BertForTokenClassification {
|
||||
///
|
||||
/// ```no_run
|
||||
/// use rust_bert::bert::{BertConfig, BertForTokenClassification};
|
||||
/// use tch::{nn, Device};
|
||||
/// use rust_bert::Config;
|
||||
/// use std::path::Path;
|
||||
/// use tch::{nn, Device};
|
||||
///
|
||||
/// let config_path = Path::new("path/to/config.json");
|
||||
/// let device = Device::Cpu;
|
||||
@ -687,10 +828,23 @@ impl BertForTokenClassification {
|
||||
pub fn new(p: &nn::Path, config: &BertConfig) -> BertForTokenClassification {
|
||||
let bert = BertModel::new(&(p / "bert"), config);
|
||||
let dropout = Dropout::new(config.hidden_dropout_prob);
|
||||
let num_labels = config.id2label.as_ref().expect("num_labels not provided in configuration").len() as i64;
|
||||
let classifier = nn::linear(p / "classifier", config.hidden_size, num_labels, Default::default());
|
||||
let num_labels = config
|
||||
.id2label
|
||||
.as_ref()
|
||||
.expect("num_labels not provided in configuration")
|
||||
.len() as i64;
|
||||
let classifier = nn::linear(
|
||||
p / "classifier",
|
||||
config.hidden_size,
|
||||
num_labels,
|
||||
Default::default(),
|
||||
);
|
||||
|
||||
BertForTokenClassification { bert, dropout, classifier }
|
||||
BertForTokenClassification {
|
||||
bert,
|
||||
dropout,
|
||||
classifier,
|
||||
}
|
||||
}
|
||||
|
||||
/// Forward pass through the model
|
||||
@ -727,31 +881,46 @@ impl BertForTokenClassification {
|
||||
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
|
||||
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
||||
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
||||
/// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true);
|
||||
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
|
||||
/// .expand(&[batch_size, sequence_length], true);
|
||||
///
|
||||
/// let (token_labels, all_hidden_states, all_attentions) = no_grad(|| {
|
||||
/// bert_model
|
||||
/// .forward_t(Some(input_tensor),
|
||||
/// bert_model.forward_t(
|
||||
/// Some(input_tensor),
|
||||
/// Some(mask),
|
||||
/// Some(token_type_ids),
|
||||
/// Some(position_ids),
|
||||
/// None,
|
||||
/// false)
|
||||
/// false,
|
||||
/// )
|
||||
/// });
|
||||
///
|
||||
/// ```
|
||||
///
|
||||
pub fn forward_t(&self,
|
||||
pub fn forward_t(
|
||||
&self,
|
||||
input_ids: Option<Tensor>,
|
||||
mask: Option<Tensor>,
|
||||
token_type_ids: Option<Tensor>,
|
||||
position_ids: Option<Tensor>,
|
||||
input_embeds: Option<Tensor>,
|
||||
train: bool) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
|
||||
let (hidden_state, _, all_hidden_states, all_attentions) = self.bert.forward_t(input_ids, mask, token_type_ids, position_ids,
|
||||
input_embeds, &None, &None, train).unwrap();
|
||||
train: bool,
|
||||
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
|
||||
let (hidden_state, _, all_hidden_states, all_attentions) = self
|
||||
.bert
|
||||
.forward_t(
|
||||
input_ids,
|
||||
mask,
|
||||
token_type_ids,
|
||||
position_ids,
|
||||
input_embeds,
|
||||
&None,
|
||||
&None,
|
||||
train,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let sequence_output = hidden_state.apply_t(&self.dropout, train).apply(&self.classifier);
|
||||
let sequence_output = hidden_state
|
||||
.apply_t(&self.dropout, train)
|
||||
.apply(&self.classifier);
|
||||
(sequence_output, all_hidden_states, all_attentions)
|
||||
}
|
||||
}
|
||||
@ -780,9 +949,9 @@ impl BertForQuestionAnswering {
|
||||
///
|
||||
/// ```no_run
|
||||
/// use rust_bert::bert::{BertConfig, BertForQuestionAnswering};
|
||||
/// use tch::{nn, Device};
|
||||
/// use rust_bert::Config;
|
||||
/// use std::path::Path;
|
||||
/// use tch::{nn, Device};
|
||||
///
|
||||
/// let config_path = Path::new("path/to/config.json");
|
||||
/// let device = Device::Cpu;
|
||||
@ -793,7 +962,12 @@ impl BertForQuestionAnswering {
|
||||
pub fn new(p: &nn::Path, config: &BertConfig) -> BertForQuestionAnswering {
|
||||
let bert = BertModel::new(&(p / "bert"), config);
|
||||
let num_labels = 2;
|
||||
let qa_outputs = nn::linear(p / "qa_outputs", config.hidden_size, num_labels, Default::default());
|
||||
let qa_outputs = nn::linear(
|
||||
p / "qa_outputs",
|
||||
config.hidden_size,
|
||||
num_labels,
|
||||
Default::default(),
|
||||
);
|
||||
|
||||
BertForQuestionAnswering { bert, qa_outputs }
|
||||
}
|
||||
@ -833,29 +1007,42 @@ impl BertForQuestionAnswering {
|
||||
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
|
||||
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
||||
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
||||
/// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true);
|
||||
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
|
||||
/// .expand(&[batch_size, sequence_length], true);
|
||||
///
|
||||
/// let (start_scores, end_scores, all_hidden_states, all_attentions) = no_grad(|| {
|
||||
/// bert_model
|
||||
/// .forward_t(Some(input_tensor),
|
||||
/// bert_model.forward_t(
|
||||
/// Some(input_tensor),
|
||||
/// Some(mask),
|
||||
/// Some(token_type_ids),
|
||||
/// Some(position_ids),
|
||||
/// None,
|
||||
/// false)
|
||||
/// false,
|
||||
/// )
|
||||
/// });
|
||||
///
|
||||
/// ```
|
||||
///
|
||||
pub fn forward_t(&self,
|
||||
pub fn forward_t(
|
||||
&self,
|
||||
input_ids: Option<Tensor>,
|
||||
mask: Option<Tensor>,
|
||||
token_type_ids: Option<Tensor>,
|
||||
position_ids: Option<Tensor>,
|
||||
input_embeds: Option<Tensor>,
|
||||
train: bool) -> (Tensor, Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
|
||||
let (hidden_state, _, all_hidden_states, all_attentions) = self.bert.forward_t(input_ids, mask, token_type_ids, position_ids,
|
||||
input_embeds, &None, &None, train).unwrap();
|
||||
train: bool,
|
||||
) -> (Tensor, Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
|
||||
let (hidden_state, _, all_hidden_states, all_attentions) = self
|
||||
.bert
|
||||
.forward_t(
|
||||
input_ids,
|
||||
mask,
|
||||
token_type_ids,
|
||||
position_ids,
|
||||
input_embeds,
|
||||
&None,
|
||||
&None,
|
||||
train,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let sequence_output = hidden_state.apply(&self.qa_outputs);
|
||||
let logits = sequence_output.split(1, -1);
|
||||
|
@ -11,22 +11,24 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use tch::{nn, Tensor, Kind};
|
||||
use tch::nn::{EmbeddingConfig, embedding};
|
||||
use crate::common::dropout::Dropout;
|
||||
use crate::bert::bert::BertConfig;
|
||||
use crate::common::dropout::Dropout;
|
||||
use tch::nn::{embedding, EmbeddingConfig};
|
||||
use tch::{nn, Kind, Tensor};
|
||||
|
||||
/// # BertEmbedding trait (for use in BertModel or RoBERTaModel)
|
||||
/// Defines an interface for the embedding layers in BERT-based models
|
||||
pub trait BertEmbedding {
|
||||
fn new(p: &nn::Path, config: &BertConfig) -> Self;
|
||||
|
||||
fn forward_t(&self,
|
||||
fn forward_t(
|
||||
&self,
|
||||
input_ids: Option<Tensor>,
|
||||
token_type_ids: Option<Tensor>,
|
||||
position_ids: Option<Tensor>,
|
||||
input_embeds: Option<Tensor>,
|
||||
train: bool) -> Result<Tensor, &'static str>;
|
||||
train: bool,
|
||||
) -> Result<Tensor, &'static str>;
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
@ -51,10 +53,10 @@ impl BertEmbedding for BertEmbeddings {
|
||||
/// # Example
|
||||
///
|
||||
/// ```no_run
|
||||
/// use rust_bert::bert::{BertConfig, BertEmbeddings, BertEmbedding};
|
||||
/// use tch::{nn, Device};
|
||||
/// use rust_bert::bert::{BertConfig, BertEmbedding, BertEmbeddings};
|
||||
/// use rust_bert::Config;
|
||||
/// use std::path::Path;
|
||||
/// use tch::{nn, Device};
|
||||
///
|
||||
/// let config_path = Path::new("path/to/config.json");
|
||||
/// let device = Device::Cpu;
|
||||
@ -62,29 +64,47 @@ impl BertEmbedding for BertEmbeddings {
|
||||
/// let config = BertConfig::from_file(config_path);
|
||||
/// let bert_embeddings = BertEmbeddings::new(&(&p.root() / "bert_embeddings"), &config);
|
||||
/// ```
|
||||
///
|
||||
fn new(p: &nn::Path, config: &BertConfig) -> BertEmbeddings {
|
||||
let embedding_config = EmbeddingConfig { padding_idx: 0, ..Default::default() };
|
||||
let embedding_config = EmbeddingConfig {
|
||||
padding_idx: 0,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let word_embeddings: nn::Embedding = embedding(p / "word_embeddings",
|
||||
let word_embeddings: nn::Embedding = embedding(
|
||||
p / "word_embeddings",
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
embedding_config);
|
||||
embedding_config,
|
||||
);
|
||||
|
||||
let position_embeddings: nn::Embedding = embedding(p / "position_embeddings",
|
||||
let position_embeddings: nn::Embedding = embedding(
|
||||
p / "position_embeddings",
|
||||
config.max_position_embeddings,
|
||||
config.hidden_size,
|
||||
Default::default());
|
||||
Default::default(),
|
||||
);
|
||||
|
||||
let token_type_embeddings: nn::Embedding = embedding(p / "token_type_embeddings",
|
||||
let token_type_embeddings: nn::Embedding = embedding(
|
||||
p / "token_type_embeddings",
|
||||
config.type_vocab_size,
|
||||
config.hidden_size,
|
||||
Default::default());
|
||||
Default::default(),
|
||||
);
|
||||
|
||||
let layer_norm_config = nn::LayerNormConfig { eps: 1e-12, ..Default::default() };
|
||||
let layer_norm: nn::LayerNorm = nn::layer_norm(p / "LayerNorm", vec![config.hidden_size], layer_norm_config);
|
||||
let layer_norm_config = nn::LayerNormConfig {
|
||||
eps: 1e-12,
|
||||
..Default::default()
|
||||
};
|
||||
let layer_norm: nn::LayerNorm =
|
||||
nn::layer_norm(p / "LayerNorm", vec![config.hidden_size], layer_norm_config);
|
||||
let dropout: Dropout = Dropout::new(config.hidden_dropout_prob);
|
||||
BertEmbeddings { word_embeddings, position_embeddings, token_type_embeddings, layer_norm, dropout }
|
||||
BertEmbeddings {
|
||||
word_embeddings,
|
||||
position_embeddings,
|
||||
token_type_embeddings,
|
||||
layer_norm,
|
||||
dropout,
|
||||
}
|
||||
}
|
||||
|
||||
/// Forward pass through the embedding layer
|
||||
@ -118,36 +138,48 @@ impl BertEmbedding for BertEmbeddings {
|
||||
/// let (batch_size, sequence_length) = (64, 128);
|
||||
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
|
||||
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
||||
/// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true);
|
||||
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
|
||||
/// .expand(&[batch_size, sequence_length], true);
|
||||
///
|
||||
/// let embedded_output = no_grad(|| {
|
||||
/// bert_embeddings
|
||||
/// .forward_t(Some(input_tensor),
|
||||
/// .forward_t(
|
||||
/// Some(input_tensor),
|
||||
/// Some(token_type_ids),
|
||||
/// Some(position_ids),
|
||||
/// None,
|
||||
/// false).unwrap()
|
||||
/// false,
|
||||
/// )
|
||||
/// .unwrap()
|
||||
/// });
|
||||
/// ```
|
||||
///
|
||||
fn forward_t(&self,
|
||||
fn forward_t(
|
||||
&self,
|
||||
input_ids: Option<Tensor>,
|
||||
token_type_ids: Option<Tensor>,
|
||||
position_ids: Option<Tensor>,
|
||||
input_embeds: Option<Tensor>,
|
||||
train: bool) -> Result<Tensor, &'static str> {
|
||||
train: bool,
|
||||
) -> Result<Tensor, &'static str> {
|
||||
let (input_embeddings, input_shape) = match input_ids {
|
||||
Some(input_value) => match input_embeds {
|
||||
Some(_) => { return Err("Only one of input ids or input embeddings may be set"); }
|
||||
None => (input_value.apply_t(&self.word_embeddings, train), input_value.size())
|
||||
Some(_) => {
|
||||
return Err("Only one of input ids or input embeddings may be set");
|
||||
}
|
||||
None => (
|
||||
input_value.apply_t(&self.word_embeddings, train),
|
||||
input_value.size(),
|
||||
),
|
||||
},
|
||||
None => match input_embeds {
|
||||
Some(embeds) => {
|
||||
let size = vec!(embeds.size()[0], embeds.size()[1]);
|
||||
let size = vec![embeds.size()[0], embeds.size()[1]];
|
||||
(embeds, size)
|
||||
},
|
||||
None => { return Err("Only one of input ids or input embeddings may be set"); }
|
||||
}
|
||||
None => {
|
||||
return Err("Only one of input ids or input embeddings may be set");
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
let seq_length = input_embeddings.as_ref().size()[1].to_owned();
|
||||
@ -155,19 +187,22 @@ impl BertEmbedding for BertEmbeddings {
|
||||
let position_ids = match position_ids {
|
||||
Some(value) => value,
|
||||
None => Tensor::arange(seq_length, (Kind::Int64, input_embeddings.device()))
|
||||
.unsqueeze(0).
|
||||
expand(&input_shape, true)
|
||||
.unsqueeze(0)
|
||||
.expand(&input_shape, true),
|
||||
};
|
||||
|
||||
let token_type_ids = match token_type_ids {
|
||||
Some(value) => value,
|
||||
None => Tensor::zeros(&input_shape, (Kind::Int64, input_embeddings.device()))
|
||||
None => Tensor::zeros(&input_shape, (Kind::Int64, input_embeddings.device())),
|
||||
};
|
||||
|
||||
let position_embeddings = position_ids.apply(&self.position_embeddings);
|
||||
let token_type_embeddings = token_type_ids.apply(&self.token_type_embeddings);
|
||||
|
||||
let input_embeddings: Tensor = input_embeddings + position_embeddings + token_type_embeddings;
|
||||
Ok(input_embeddings.apply(&self.layer_norm).apply_t(&self.dropout, train))
|
||||
let input_embeddings: Tensor =
|
||||
input_embeddings + position_embeddings + token_type_embeddings;
|
||||
Ok(input_embeddings
|
||||
.apply(&self.layer_norm)
|
||||
.apply_t(&self.dropout, train))
|
||||
}
|
||||
}
|
@ -11,10 +11,10 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use tch::{Tensor, nn};
|
||||
use crate::bert::attention::{BertAttention, BertIntermediate, BertOutput};
|
||||
use std::borrow::BorrowMut;
|
||||
use crate::bert::bert::BertConfig;
|
||||
use std::borrow::BorrowMut;
|
||||
use tch::{nn, Tensor};
|
||||
|
||||
pub struct BertLayer {
|
||||
attention: BertAttention,
|
||||
@ -30,35 +30,55 @@ impl BertLayer {
|
||||
let (is_decoder, cross_attention) = match config.is_decoder {
|
||||
Some(value) => {
|
||||
if value == true {
|
||||
(value, Some(BertAttention::new(&(p / "cross_attention"), &config)))
|
||||
(
|
||||
value,
|
||||
Some(BertAttention::new(&(p / "cross_attention"), &config)),
|
||||
)
|
||||
} else {
|
||||
(value, None)
|
||||
}
|
||||
}
|
||||
None => (false, None)
|
||||
None => (false, None),
|
||||
};
|
||||
|
||||
let intermediate = BertIntermediate::new(&(p / "intermediate"), &config);
|
||||
let output = BertOutput::new(&(p / "output"), &config);
|
||||
|
||||
BertLayer { attention, is_decoder, cross_attention, intermediate, output }
|
||||
BertLayer {
|
||||
attention,
|
||||
is_decoder,
|
||||
cross_attention,
|
||||
intermediate,
|
||||
output,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn forward_t(&self,
|
||||
pub fn forward_t(
|
||||
&self,
|
||||
hidden_states: &Tensor,
|
||||
mask: &Option<Tensor>,
|
||||
encoder_hidden_states: &Option<Tensor>,
|
||||
encoder_mask: &Option<Tensor>,
|
||||
train: bool) -> (Tensor, Option<Tensor>, Option<Tensor>) {
|
||||
let (attention_output, attention_weights, cross_attention_weights) = if self.is_decoder & encoder_hidden_states.is_some() {
|
||||
train: bool,
|
||||
) -> (Tensor, Option<Tensor>, Option<Tensor>) {
|
||||
let (attention_output, attention_weights, cross_attention_weights) =
|
||||
if self.is_decoder & encoder_hidden_states.is_some() {
|
||||
let (attention_output, attention_weights) =
|
||||
self.attention.forward_t(hidden_states, mask, &None, &None, train);
|
||||
self.attention
|
||||
.forward_t(hidden_states, mask, &None, &None, train);
|
||||
let (attention_output, cross_attention_weights) =
|
||||
self.cross_attention.as_ref().unwrap().forward_t(&attention_output, mask, encoder_hidden_states, encoder_mask, train);
|
||||
self.cross_attention.as_ref().unwrap().forward_t(
|
||||
&attention_output,
|
||||
mask,
|
||||
encoder_hidden_states,
|
||||
encoder_mask,
|
||||
train,
|
||||
);
|
||||
(attention_output, attention_weights, cross_attention_weights)
|
||||
} else {
|
||||
let (attention_output, attention_weights) =
|
||||
self.attention.forward_t(hidden_states, mask, &None, &None, train);
|
||||
self.attention
|
||||
.forward_t(hidden_states, mask, &None, &None, train);
|
||||
(attention_output, attention_weights, None)
|
||||
};
|
||||
|
||||
@ -78,26 +98,47 @@ pub struct BertEncoder {
|
||||
impl BertEncoder {
|
||||
pub fn new(p: &nn::Path, config: &BertConfig) -> BertEncoder {
|
||||
let p = &(p / "layer");
|
||||
let output_attentions = if let Some(value) = config.output_attentions { value } else { false };
|
||||
let output_hidden_states = if let Some(value) = config.output_hidden_states { value } else { false };
|
||||
|
||||
let mut layers: Vec<BertLayer> = vec!();
|
||||
for layer_index in 0..config.num_hidden_layers {
|
||||
layers.push(BertLayer::new(&(p / layer_index), config));
|
||||
let output_attentions = if let Some(value) = config.output_attentions {
|
||||
value
|
||||
} else {
|
||||
false
|
||||
};
|
||||
let output_hidden_states = if let Some(value) = config.output_hidden_states {
|
||||
value
|
||||
} else {
|
||||
false
|
||||
};
|
||||
|
||||
BertEncoder { output_attentions, output_hidden_states, layers }
|
||||
let mut layers: Vec<BertLayer> = vec![];
|
||||
for layer_index in 0..config.num_hidden_layers {
|
||||
layers.push(BertLayer::new(&(p / layer_index), config));
|
||||
}
|
||||
|
||||
pub fn forward_t(&self,
|
||||
BertEncoder {
|
||||
output_attentions,
|
||||
output_hidden_states,
|
||||
layers,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn forward_t(
|
||||
&self,
|
||||
hidden_states: &Tensor,
|
||||
mask: &Option<Tensor>,
|
||||
encoder_hidden_states: &Option<Tensor>,
|
||||
encoder_mask: &Option<Tensor>,
|
||||
train: bool)
|
||||
-> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
|
||||
let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states { Some(vec!()) } else { None };
|
||||
let mut all_attentions: Option<Vec<Tensor>> = if self.output_attentions { Some(vec!()) } else { None };
|
||||
train: bool,
|
||||
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
|
||||
let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states {
|
||||
Some(vec![])
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let mut all_attentions: Option<Vec<Tensor>> = if self.output_attentions {
|
||||
Some(vec![])
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let mut hidden_state = hidden_states.copy();
|
||||
let mut attention_weights: Option<Tensor>;
|
||||
@ -109,16 +150,22 @@ impl BertEncoder {
|
||||
hidden_states.push(hidden_state.as_ref().copy());
|
||||
};
|
||||
|
||||
let temp = layer.forward_t(&hidden_state, &mask, encoder_hidden_states, encoder_mask, train);
|
||||
let temp = layer.forward_t(
|
||||
&hidden_state,
|
||||
&mask,
|
||||
encoder_hidden_states,
|
||||
encoder_mask,
|
||||
train,
|
||||
);
|
||||
hidden_state = temp.0;
|
||||
attention_weights = temp.1;
|
||||
if let Some(attentions) = all_attentions.borrow_mut() {
|
||||
attentions.push(attention_weights.as_ref().unwrap().copy());
|
||||
};
|
||||
}
|
||||
None => break
|
||||
};
|
||||
None => break,
|
||||
};
|
||||
}
|
||||
|
||||
(hidden_state, all_hidden_states, all_attentions)
|
||||
}
|
||||
@ -130,14 +177,16 @@ pub struct BertPooler {
|
||||
|
||||
impl BertPooler {
|
||||
pub fn new(p: &nn::Path, config: &BertConfig) -> BertPooler {
|
||||
let lin = nn::linear(&(p / "dense"), config.hidden_size, config.hidden_size, Default::default());
|
||||
let lin = nn::linear(
|
||||
&(p / "dense"),
|
||||
config.hidden_size,
|
||||
config.hidden_size,
|
||||
Default::default(),
|
||||
);
|
||||
BertPooler { lin }
|
||||
}
|
||||
|
||||
pub fn forward(&self, hidden_states: &Tensor) -> Tensor {
|
||||
hidden_states
|
||||
.select(1, 0)
|
||||
.apply(&self.lin)
|
||||
.tanh()
|
||||
hidden_states.select(1, 0).apply(&self.lin).tanh()
|
||||
}
|
||||
}
|
@ -24,13 +24,19 @@
|
||||
//! use rust_tokenizers::BertTokenizer;
|
||||
//! use tch::{nn, Device};
|
||||
//! # use std::path::PathBuf;
|
||||
//! use rust_bert::bert::{BertForMaskedLM, BertConfig};
|
||||
//! use rust_bert::bert::{BertConfig, BertForMaskedLM};
|
||||
//! use rust_bert::resources::{download_resource, LocalResource, Resource};
|
||||
//! use rust_bert::Config;
|
||||
//! use rust_bert::resources::{Resource, download_resource, LocalResource};
|
||||
//!
|
||||
//! let config_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/config.json")});
|
||||
//! let vocab_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/vocab.txt")});
|
||||
//! let weights_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/model.ot")});
|
||||
//! let config_resource = Resource::Local(LocalResource {
|
||||
//! local_path: PathBuf::from("path/to/config.json"),
|
||||
//! });
|
||||
//! let vocab_resource = Resource::Local(LocalResource {
|
||||
//! local_path: PathBuf::from("path/to/vocab.txt"),
|
||||
//! });
|
||||
//! let weights_resource = Resource::Local(LocalResource {
|
||||
//! local_path: PathBuf::from("path/to/model.ot"),
|
||||
//! });
|
||||
//! let config_path = download_resource(&config_resource)?;
|
||||
//! let vocab_path = download_resource(&vocab_resource)?;
|
||||
//! let weights_path = download_resource(&weights_resource)?;
|
||||
@ -45,13 +51,14 @@
|
||||
//! # }
|
||||
//! ```
|
||||
|
||||
|
||||
mod attention;
|
||||
mod bert;
|
||||
mod embeddings;
|
||||
mod attention;
|
||||
pub(crate) mod encoder;
|
||||
|
||||
pub use bert::{BertModelResources, BertConfigResources, BertVocabResources,
|
||||
BertConfig, Activation, BertModel, BertForTokenClassification, BertForMultipleChoice,
|
||||
BertForMaskedLM, BertForSequenceClassification, BertForQuestionAnswering};
|
||||
pub use bert::{
|
||||
Activation, BertConfig, BertConfigResources, BertForMaskedLM, BertForMultipleChoice,
|
||||
BertForQuestionAnswering, BertForSequenceClassification, BertForTokenClassification, BertModel,
|
||||
BertModelResources, BertVocabResources,
|
||||
};
|
||||
pub use embeddings::{BertEmbedding, BertEmbeddings};
|
@ -1,14 +1,26 @@
|
||||
use tch::Tensor;
|
||||
use std::f64::consts::PI;
|
||||
use tch::Tensor;
|
||||
|
||||
pub fn _gelu(x: &Tensor) -> Tensor { x * 0.5 * (1.0 + (x / ((2.0 as f64).sqrt())).erf()) }
|
||||
pub fn _gelu(x: &Tensor) -> Tensor {
|
||||
x * 0.5 * (1.0 + (x / ((2.0 as f64).sqrt())).erf())
|
||||
}
|
||||
|
||||
pub fn _relu(x: &Tensor) -> Tensor { x.relu() }
|
||||
pub fn _relu(x: &Tensor) -> Tensor {
|
||||
x.relu()
|
||||
}
|
||||
|
||||
pub fn _swish(x: &Tensor) -> Tensor { x * x.sigmoid() }
|
||||
pub fn _swish(x: &Tensor) -> Tensor {
|
||||
x * x.sigmoid()
|
||||
}
|
||||
|
||||
pub fn _mish(x: &Tensor) -> Tensor { x * (x.softplus().tanh()) }
|
||||
pub fn _mish(x: &Tensor) -> Tensor {
|
||||
x * (x.softplus().tanh())
|
||||
}
|
||||
|
||||
pub fn _gelu_new(x: &Tensor) -> Tensor { x * 0.5 * (((x.pow(3.0f64) * 0.044715 + x) * ((2f64 / PI).sqrt())).tanh() + 1) }
|
||||
pub fn _gelu_new(x: &Tensor) -> Tensor {
|
||||
x * 0.5 * (((x.pow(3.0f64) * 0.044715 + x) * ((2f64 / PI).sqrt())).tanh() + 1)
|
||||
}
|
||||
|
||||
pub fn _tanh(x: &Tensor) -> Tensor { x.tanh() }
|
||||
pub fn _tanh(x: &Tensor) -> Tensor {
|
||||
x.tanh()
|
||||
}
|
||||
|
@ -9,15 +9,16 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
|
||||
use std::path::Path;
|
||||
use serde::Deserialize;
|
||||
use std::fs::File;
|
||||
use std::io::BufReader;
|
||||
use serde::Deserialize;
|
||||
use std::path::Path;
|
||||
|
||||
/// # Utility to deserialize JSON config files
|
||||
pub trait Config<T>
|
||||
where for<'de> T: Deserialize<'de> {
|
||||
where
|
||||
for<'de> T: Deserialize<'de>,
|
||||
{
|
||||
/// Loads a `Config` object from a JSON file. The format is expected to be aligned with the [Transformers library](https://github.com/huggingface/transformers) configuration files for each model.
|
||||
/// The parsing will fail if non-optional keys expected by the model are missing.
|
||||
///
|
||||
@ -28,14 +29,13 @@ pub trait Config<T>
|
||||
/// # Example
|
||||
///
|
||||
/// ```no_run
|
||||
/// use rust_bert::gpt2::Gpt2Config;
|
||||
/// use rust_bert::Config;
|
||||
/// use std::path::Path;
|
||||
/// use rust_bert::gpt2::Gpt2Config;
|
||||
///
|
||||
/// let config_path = Path::new("path/to/config.json");
|
||||
/// let config = Gpt2Config::from_file(config_path);
|
||||
/// ```
|
||||
///
|
||||
fn from_file(path: &Path) -> T {
|
||||
let f = File::open(path).expect("Could not open configuration file.");
|
||||
let br = BufReader::new(f);
|
||||
|
@ -10,9 +10,9 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use tch::nn::{Init, Path, Module};
|
||||
use tch::Tensor;
|
||||
use std::borrow::Borrow;
|
||||
use tch::nn::{Init, Module, Path};
|
||||
use tch::Tensor;
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct LinearNoBiasConfig {
|
||||
@ -32,7 +32,6 @@ pub struct LinearNoBias {
|
||||
pub ws: Tensor,
|
||||
}
|
||||
|
||||
|
||||
pub fn linear_no_bias<'a, T: Borrow<Path<'a>>>(
|
||||
vs: T,
|
||||
in_dim: i64,
|
||||
|
@ -1,7 +1,7 @@
|
||||
pub mod config;
|
||||
pub mod resources;
|
||||
pub(crate) mod dropout;
|
||||
pub(crate) mod activations;
|
||||
pub mod config;
|
||||
pub(crate) mod dropout;
|
||||
pub(crate) mod linear;
|
||||
pub mod resources;
|
||||
|
||||
pub use config::Config;
|
@ -18,9 +18,9 @@
|
||||
//! pre-trained models in each model module.
|
||||
|
||||
use lazy_static::lazy_static;
|
||||
use std::path::PathBuf;
|
||||
use reqwest::Client;
|
||||
use std::{fs, env};
|
||||
use std::path::PathBuf;
|
||||
use std::{env, fs};
|
||||
use tokio::prelude::*;
|
||||
use tokio::runtime::Runtime;
|
||||
use tokio::task;
|
||||
@ -47,12 +47,13 @@ impl Resource {
|
||||
/// # Example
|
||||
///
|
||||
/// ```no_run
|
||||
/// use rust_bert::resources::{Resource, LocalResource};
|
||||
/// use rust_bert::resources::{LocalResource, Resource};
|
||||
/// use std::path::PathBuf;
|
||||
/// let config_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/config.json")});
|
||||
/// let config_resource = Resource::Local(LocalResource {
|
||||
/// local_path: PathBuf::from("path/to/config.json"),
|
||||
/// });
|
||||
/// let config_path = config_resource.get_local_path();
|
||||
/// ```
|
||||
///
|
||||
pub fn get_local_path(&self) -> &PathBuf {
|
||||
match self {
|
||||
Resource::Local(resource) => &resource.local_path,
|
||||
@ -65,7 +66,7 @@ impl Resource {
|
||||
#[derive(PartialEq, Clone)]
|
||||
pub struct LocalResource {
|
||||
/// Local path for the resource
|
||||
pub local_path: PathBuf
|
||||
pub local_path: PathBuf,
|
||||
}
|
||||
|
||||
/// # Remote resource
|
||||
@ -93,13 +94,18 @@ impl RemoteResource {
|
||||
/// # Example
|
||||
///
|
||||
/// ```no_run
|
||||
/// use rust_bert::resources::{Resource, RemoteResource};
|
||||
/// use rust_bert::resources::{RemoteResource, Resource};
|
||||
/// use std::path::PathBuf;
|
||||
/// let config_resource = Resource::Remote(RemoteResource::new("http://config_json_location", PathBuf::from("path/to/config.json")));
|
||||
/// let config_resource = Resource::Remote(RemoteResource::new(
|
||||
/// "http://config_json_location",
|
||||
/// PathBuf::from("path/to/config.json"),
|
||||
/// ));
|
||||
/// ```
|
||||
///
|
||||
pub fn new(url: &str, target: PathBuf) -> RemoteResource {
|
||||
RemoteResource { url: url.to_string(), local_path: target }
|
||||
RemoteResource {
|
||||
url: url.to_string(),
|
||||
local_path: target,
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a new RemoteResource from an URL and local name. Will define a local path pointing to
|
||||
@ -117,14 +123,12 @@ impl RemoteResource {
|
||||
/// # Example
|
||||
///
|
||||
/// ```no_run
|
||||
/// use rust_bert::resources::{Resource, RemoteResource};
|
||||
/// let model_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
/// ("distilbert-sst2/model.ot",
|
||||
/// "https://cdn.huggingface.co/distilbert-base-uncased-finetuned-sst-2-english-rust_model.ot"
|
||||
/// )
|
||||
/// ));
|
||||
/// use rust_bert::resources::{RemoteResource, Resource};
|
||||
/// let model_resource = Resource::Remote(RemoteResource::from_pretrained((
|
||||
/// "distilbert-sst2/model.ot",
|
||||
/// "https://cdn.huggingface.co/distilbert-base-uncased-finetuned-sst-2-english-rust_model.ot",
|
||||
/// )));
|
||||
/// ```
|
||||
///
|
||||
pub fn from_pretrained(name_url_tuple: (&str, &str)) -> RemoteResource {
|
||||
let name = name_url_tuple.0;
|
||||
let url = name_url_tuple.1.to_string();
|
||||
@ -171,15 +175,13 @@ fn _get_cache_directory() -> PathBuf {
|
||||
/// # Example
|
||||
///
|
||||
/// ```no_run
|
||||
/// use rust_bert::resources::{Resource, RemoteResource, download_resource};
|
||||
/// let model_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
/// ("distilbert-sst2/model.ot",
|
||||
/// "https://cdn.huggingface.co/distilbert-base-uncased-finetuned-sst-2-english-rust_model.ot"
|
||||
/// )
|
||||
/// ));
|
||||
/// use rust_bert::resources::{download_resource, RemoteResource, Resource};
|
||||
/// let model_resource = Resource::Remote(RemoteResource::from_pretrained((
|
||||
/// "distilbert-sst2/model.ot",
|
||||
/// "https://cdn.huggingface.co/distilbert-base-uncased-finetuned-sst-2-english-rust_model.ot",
|
||||
/// )));
|
||||
/// let local_path = download_resource(&model_resource);
|
||||
/// ```
|
||||
///
|
||||
pub fn download_resource(resource: &Resource) -> failure::Fallible<&PathBuf> {
|
||||
match resource {
|
||||
Resource::Remote(remote_resource) => {
|
||||
@ -202,8 +204,6 @@ pub fn download_resource(resource: &Resource) -> failure::Fallible<&PathBuf> {
|
||||
|
||||
Ok(resource.get_local_path())
|
||||
}
|
||||
Resource::Local(_) => {
|
||||
Ok(resource.get_local_path())
|
||||
}
|
||||
Resource::Local(_) => Ok(resource.get_local_path()),
|
||||
}
|
||||
}
|
@ -16,7 +16,11 @@ extern crate tch;
|
||||
|
||||
pub fn main() -> failure::Fallible<()> {
|
||||
let args: Vec<_> = std::env::args().collect();
|
||||
ensure!(args.len() == 3, "usage: {} source.npz destination.ot", args[0]);
|
||||
ensure!(
|
||||
args.len() == 3,
|
||||
"usage: {} source.npz destination.ot",
|
||||
args[0]
|
||||
);
|
||||
|
||||
let source_file = &args[1];
|
||||
let destination_file = &args[2];
|
||||
|
@ -10,11 +10,10 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use tch::{nn, Tensor};
|
||||
use crate::common::dropout::Dropout;
|
||||
use crate::distilbert::distilbert::DistilBertConfig;
|
||||
use tch::kind::Kind::Float;
|
||||
use crate::common::dropout::Dropout;
|
||||
|
||||
use tch::{nn, Tensor};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct MultiHeadSelfAttention {
|
||||
@ -39,7 +38,7 @@ impl MultiHeadSelfAttention {
|
||||
|
||||
let output_attentions = match config.output_attentions {
|
||||
Some(value) => value,
|
||||
None => false
|
||||
None => false,
|
||||
};
|
||||
|
||||
MultiHeadSelfAttention {
|
||||
@ -59,10 +58,19 @@ impl MultiHeadSelfAttention {
|
||||
}
|
||||
|
||||
fn flatten(&self, x: Tensor, bs: i64, dim_per_head: i64) -> Tensor {
|
||||
x.transpose(1, 2).contiguous().view((bs, -1, &self.n_heads * dim_per_head))
|
||||
x.transpose(1, 2)
|
||||
.contiguous()
|
||||
.view((bs, -1, &self.n_heads * dim_per_head))
|
||||
}
|
||||
|
||||
pub fn forward_t(&self, query: &Tensor, key: &Tensor, value: &Tensor, mask: &Option<Tensor>, train: bool) -> (Tensor, Option<Tensor>) {
|
||||
pub fn forward_t(
|
||||
&self,
|
||||
query: &Tensor,
|
||||
key: &Tensor,
|
||||
value: &Tensor,
|
||||
mask: &Option<Tensor>,
|
||||
train: bool,
|
||||
) -> (Tensor, Option<Tensor>) {
|
||||
let bs = query.size()[0];
|
||||
let k_length = key.size()[1];
|
||||
|
||||
@ -73,14 +81,19 @@ impl MultiHeadSelfAttention {
|
||||
|
||||
let scores = if let Some(mask) = mask {
|
||||
let unmasked_scores = q.matmul(&k.transpose(2, 3));
|
||||
let mask = mask.le1(&(mask.zeros_like() + 0.1)).view((bs, 1i64, 1i64, k_length)).expand_as(&unmasked_scores);
|
||||
let mask = mask
|
||||
.le1(&(mask.zeros_like() + 0.1))
|
||||
.view((bs, 1i64, 1i64, k_length))
|
||||
.expand_as(&unmasked_scores);
|
||||
unmasked_scores.masked_fill(&mask, std::f64::NEG_INFINITY)
|
||||
} else {
|
||||
q.matmul(&k.transpose(2, 3))
|
||||
};
|
||||
|
||||
let weights = scores.softmax(-1, Float).apply_t(&self.dropout, train);
|
||||
let context = self.flatten(weights.matmul(&v), bs, self.dim_per_head).apply(&self.out_lin);
|
||||
let context = self
|
||||
.flatten(weights.matmul(&v), bs, self.dim_per_head)
|
||||
.apply(&self.out_lin);
|
||||
|
||||
if !self.output_attentions {
|
||||
(context, None)
|
||||
|
@ -12,13 +12,13 @@
|
||||
|
||||
extern crate tch;
|
||||
|
||||
use std::collections::HashMap;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use crate::distilbert::embeddings::DistilBertEmbedding;
|
||||
use crate::distilbert::transformer::Transformer;
|
||||
use self::tch::{nn, Tensor};
|
||||
use crate::common::dropout::Dropout;
|
||||
use crate::distilbert::embeddings::DistilBertEmbedding;
|
||||
use crate::distilbert::transformer::Transformer;
|
||||
use crate::Config;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// # DistilBERT Pretrained model weight files
|
||||
pub struct DistilBertModelResources;
|
||||
@ -31,29 +31,56 @@ pub struct DistilBertVocabResources;
|
||||
|
||||
impl DistilBertModelResources {
|
||||
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
|
||||
pub const DISTIL_BERT_SST2: (&'static str, &'static str) = ("distilbert-sst2/model.ot", "https://cdn.huggingface.co/distilbert-base-uncased-finetuned-sst-2-english-rust_model.ot");
|
||||
pub const DISTIL_BERT_SST2: (&'static str, &'static str) = (
|
||||
"distilbert-sst2/model.ot",
|
||||
"https://cdn.huggingface.co/distilbert-base-uncased-finetuned-sst-2-english-rust_model.ot",
|
||||
);
|
||||
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
|
||||
pub const DISTIL_BERT: (&'static str, &'static str) = ("distilbert/model.ot", "https://cdn.huggingface.co/distilbert-base-uncased-rust_model.ot");
|
||||
pub const DISTIL_BERT: (&'static str, &'static str) = (
|
||||
"distilbert/model.ot",
|
||||
"https://cdn.huggingface.co/distilbert-base-uncased-rust_model.ot",
|
||||
);
|
||||
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
|
||||
pub const DISTIL_BERT_SQUAD: (&'static str, &'static str) = ("distilbert-qa/model.ot", "https://cdn.huggingface.co/distilbert-base-cased-distilled-squad-rust_model.ot");
|
||||
pub const DISTIL_BERT_SQUAD: (&'static str, &'static str) = (
|
||||
"distilbert-qa/model.ot",
|
||||
"https://cdn.huggingface.co/distilbert-base-cased-distilled-squad-rust_model.ot",
|
||||
);
|
||||
}
|
||||
|
||||
impl DistilBertConfigResources {
|
||||
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
|
||||
pub const DISTIL_BERT_SST2: (&'static str, &'static str) = ("distilbert-sst2/config.json", "https://cdn.huggingface.co/distilbert-base-uncased-finetuned-sst-2-english-config.json");
|
||||
pub const DISTIL_BERT_SST2: (&'static str, &'static str) = (
|
||||
"distilbert-sst2/config.json",
|
||||
"https://cdn.huggingface.co/distilbert-base-uncased-finetuned-sst-2-english-config.json",
|
||||
);
|
||||
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
|
||||
pub const DISTIL_BERT: (&'static str, &'static str) = ("distilbert/config.json", "https://cdn.huggingface.co/distilbert-base-uncased-config.json");
|
||||
pub const DISTIL_BERT: (&'static str, &'static str) = (
|
||||
"distilbert/config.json",
|
||||
"https://cdn.huggingface.co/distilbert-base-uncased-config.json",
|
||||
);
|
||||
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
|
||||
pub const DISTIL_BERT_SQUAD: (&'static str, &'static str) = ("distilbert-qa/config.json", "https://cdn.huggingface.co/distilbert-base-cased-distilled-squad-config.json");
|
||||
pub const DISTIL_BERT_SQUAD: (&'static str, &'static str) = (
|
||||
"distilbert-qa/config.json",
|
||||
"https://cdn.huggingface.co/distilbert-base-cased-distilled-squad-config.json",
|
||||
);
|
||||
}
|
||||
|
||||
impl DistilBertVocabResources {
|
||||
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
|
||||
pub const DISTIL_BERT_SST2: (&'static str, &'static str) = ("distilbert-sst2/vocab.txt", "https://cdn.huggingface.co/distilbert-base-uncased-finetuned-sst-2-english-vocab.txt");
|
||||
pub const DISTIL_BERT_SST2: (&'static str, &'static str) = (
|
||||
"distilbert-sst2/vocab.txt",
|
||||
"https://cdn.huggingface.co/distilbert-base-uncased-finetuned-sst-2-english-vocab.txt",
|
||||
);
|
||||
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
|
||||
pub const DISTIL_BERT: (&'static str, &'static str) = ("distilbert/vocab.txt", "https://cdn.huggingface.co/bert-base-uncased-vocab.txt");
|
||||
pub const DISTIL_BERT: (&'static str, &'static str) = (
|
||||
"distilbert/vocab.txt",
|
||||
"https://cdn.huggingface.co/bert-base-uncased-vocab.txt",
|
||||
);
|
||||
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
|
||||
pub const DISTIL_BERT_SQUAD: (&'static str, &'static str) = ("distilbert-qa/vocab.txt", "https://cdn.huggingface.co/bert-large-cased-vocab.txt");
|
||||
pub const DISTIL_BERT_SQUAD: (&'static str, &'static str) = (
|
||||
"distilbert-qa/vocab.txt",
|
||||
"https://cdn.huggingface.co/bert-large-cased-vocab.txt",
|
||||
);
|
||||
}
|
||||
|
||||
#[allow(non_camel_case_types)]
|
||||
@ -118,10 +145,10 @@ impl DistilBertModel {
|
||||
/// # Example
|
||||
///
|
||||
/// ```no_run
|
||||
/// use tch::{nn, Device};
|
||||
/// use rust_bert::distilbert::{DistilBertConfig, DistilBertModel};
|
||||
/// use rust_bert::Config;
|
||||
/// use std::path::Path;
|
||||
/// use rust_bert::distilbert::{DistilBertConfig, DistilBertModel};
|
||||
/// use tch::{nn, Device};
|
||||
///
|
||||
/// let config_path = Path::new("path/to/config.json");
|
||||
/// let device = Device::Cpu;
|
||||
@ -129,12 +156,14 @@ impl DistilBertModel {
|
||||
/// let config = DistilBertConfig::from_file(config_path);
|
||||
/// let distil_bert: DistilBertModel = DistilBertModel::new(&(&p.root() / "distilbert"), &config);
|
||||
/// ```
|
||||
///
|
||||
pub fn new(p: &nn::Path, config: &DistilBertConfig) -> DistilBertModel {
|
||||
let p = &(p / "distilbert");
|
||||
let embeddings = DistilBertEmbedding::new(&(p / "embeddings"), config);
|
||||
let transformer = Transformer::new(&(p / "transformer"), config);
|
||||
DistilBertModel { embeddings, transformer }
|
||||
DistilBertModel {
|
||||
embeddings,
|
||||
transformer,
|
||||
}
|
||||
}
|
||||
|
||||
/// Forward pass through the model
|
||||
@ -172,28 +201,32 @@ impl DistilBertModel {
|
||||
///
|
||||
/// let (output, all_hidden_states, all_attentions) = no_grad(|| {
|
||||
/// distilbert_model
|
||||
/// .forward_t(Some(input_tensor),
|
||||
/// Some(mask),
|
||||
/// None,
|
||||
/// false).unwrap()
|
||||
/// .forward_t(Some(input_tensor), Some(mask), None, false)
|
||||
/// .unwrap()
|
||||
/// });
|
||||
///
|
||||
/// ```
|
||||
///
|
||||
pub fn forward_t(&self, input: Option<Tensor>, mask: Option<Tensor>, input_embeds: Option<Tensor>, train: bool)
|
||||
-> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
|
||||
pub fn forward_t(
|
||||
&self,
|
||||
input: Option<Tensor>,
|
||||
mask: Option<Tensor>,
|
||||
input_embeds: Option<Tensor>,
|
||||
train: bool,
|
||||
) -> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
|
||||
let input_embeddings = match input {
|
||||
Some(input_value) => match input_embeds {
|
||||
Some(_) => { return Err("Only one of input ids or input embeddings may be set"); }
|
||||
None => input_value.apply_t(&self.embeddings, train)
|
||||
Some(_) => {
|
||||
return Err("Only one of input ids or input embeddings may be set");
|
||||
}
|
||||
None => input_value.apply_t(&self.embeddings, train),
|
||||
},
|
||||
None => match input_embeds {
|
||||
Some(embeds) => embeds,
|
||||
None => { return Err("At least one of input ids or input embeddings must be set"); }
|
||||
None => {
|
||||
return Err("At least one of input ids or input embeddings must be set");
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
|
||||
let transformer_output = (&self.transformer).forward_t(&input_embeddings, mask, train);
|
||||
Ok(transformer_output)
|
||||
}
|
||||
@ -223,28 +256,47 @@ impl DistilBertModelClassifier {
|
||||
/// # Example
|
||||
///
|
||||
/// ```no_run
|
||||
/// use tch::{nn, Device};
|
||||
/// use rust_bert::distilbert::{DistilBertConfig, DistilBertModelClassifier};
|
||||
/// use rust_bert::Config;
|
||||
/// use std::path::Path;
|
||||
/// use rust_bert::distilbert::{DistilBertConfig, DistilBertModelClassifier};
|
||||
/// use tch::{nn, Device};
|
||||
///
|
||||
/// let config_path = Path::new("path/to/config.json");
|
||||
/// let device = Device::Cpu;
|
||||
/// let p = nn::VarStore::new(device);
|
||||
/// let config = DistilBertConfig::from_file(config_path);
|
||||
/// let distil_bert: DistilBertModelClassifier = DistilBertModelClassifier::new(&(&p.root() / "distilbert"), &config);
|
||||
/// let distil_bert: DistilBertModelClassifier =
|
||||
/// DistilBertModelClassifier::new(&(&p.root() / "distilbert"), &config);
|
||||
/// ```
|
||||
///
|
||||
pub fn new(p: &nn::Path, config: &DistilBertConfig) -> DistilBertModelClassifier {
|
||||
let distil_bert_model = DistilBertModel::new(&p, config);
|
||||
|
||||
let num_labels = config.id2label.as_ref().expect("id2label must be provided for classifiers").len() as i64;
|
||||
let num_labels = config
|
||||
.id2label
|
||||
.as_ref()
|
||||
.expect("id2label must be provided for classifiers")
|
||||
.len() as i64;
|
||||
|
||||
let pre_classifier = nn::linear(&(p / "pre_classifier"), config.dim, config.dim, Default::default());
|
||||
let classifier = nn::linear(&(p / "classifier"), config.dim, num_labels, Default::default());
|
||||
let pre_classifier = nn::linear(
|
||||
&(p / "pre_classifier"),
|
||||
config.dim,
|
||||
config.dim,
|
||||
Default::default(),
|
||||
);
|
||||
let classifier = nn::linear(
|
||||
&(p / "classifier"),
|
||||
config.dim,
|
||||
num_labels,
|
||||
Default::default(),
|
||||
);
|
||||
let dropout = Dropout::new(config.seq_classif_dropout);
|
||||
|
||||
DistilBertModelClassifier { distil_bert_model, pre_classifier, classifier, dropout }
|
||||
DistilBertModelClassifier {
|
||||
distil_bert_model,
|
||||
pre_classifier,
|
||||
classifier,
|
||||
dropout,
|
||||
}
|
||||
}
|
||||
|
||||
/// Forward pass through the model
|
||||
@ -287,14 +339,21 @@ impl DistilBertModelClassifier {
|
||||
/// None,
|
||||
/// false).unwrap()
|
||||
/// });
|
||||
///
|
||||
/// ```
|
||||
///
|
||||
pub fn forward_t(&self, input: Option<Tensor>, mask: Option<Tensor>, input_embeds: Option<Tensor>, train: bool)
|
||||
-> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
|
||||
let (output, all_hidden_states, all_attentions) = match self.distil_bert_model.forward_t(input, mask, input_embeds, train) {
|
||||
pub fn forward_t(
|
||||
&self,
|
||||
input: Option<Tensor>,
|
||||
mask: Option<Tensor>,
|
||||
input_embeds: Option<Tensor>,
|
||||
train: bool,
|
||||
) -> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
|
||||
let (output, all_hidden_states, all_attentions) =
|
||||
match self
|
||||
.distil_bert_model
|
||||
.forward_t(input, mask, input_embeds, train)
|
||||
{
|
||||
Ok(value) => value,
|
||||
Err(err) => return Err(err)
|
||||
Err(err) => return Err(err),
|
||||
};
|
||||
|
||||
let output = output
|
||||
@ -322,7 +381,6 @@ pub struct DistilBertModelMaskedLM {
|
||||
vocab_projector: nn::Linear,
|
||||
}
|
||||
|
||||
|
||||
impl DistilBertModelMaskedLM {
|
||||
/// Build a new `DistilBertModelMaskedLM` for sequence classification
|
||||
///
|
||||
@ -334,10 +392,10 @@ impl DistilBertModelMaskedLM {
|
||||
/// # Example
|
||||
///
|
||||
/// ```no_run
|
||||
/// use tch::{nn, Device};
|
||||
/// use rust_bert::distilbert::{DistilBertConfig, DistilBertModelMaskedLM};
|
||||
/// use rust_bert::Config;
|
||||
/// use std::path::Path;
|
||||
/// use rust_bert::distilbert::{DistilBertConfig, DistilBertModelMaskedLM};
|
||||
/// use tch::{nn, Device};
|
||||
///
|
||||
/// let config_path = Path::new("path/to/config.json");
|
||||
/// let device = Device::Cpu;
|
||||
@ -345,15 +403,33 @@ impl DistilBertModelMaskedLM {
|
||||
/// let config = DistilBertConfig::from_file(config_path);
|
||||
/// let distil_bert = DistilBertModelMaskedLM::new(&(&p.root() / "distilbert"), &config);
|
||||
/// ```
|
||||
///
|
||||
pub fn new(p: &nn::Path, config: &DistilBertConfig) -> DistilBertModelMaskedLM {
|
||||
let distil_bert_model = DistilBertModel::new(&p, config);
|
||||
let vocab_transform = nn::linear(&(p / "vocab_transform"), config.dim, config.dim, Default::default());
|
||||
let layer_norm_config = nn::LayerNormConfig { eps: 1e-12, ..Default::default() };
|
||||
let vocab_layer_norm = nn::layer_norm(p / "vocab_layer_norm", vec![config.dim], layer_norm_config);
|
||||
let vocab_projector = nn::linear(&(p / "vocab_projector"), config.dim, config.vocab_size, Default::default());
|
||||
let vocab_transform = nn::linear(
|
||||
&(p / "vocab_transform"),
|
||||
config.dim,
|
||||
config.dim,
|
||||
Default::default(),
|
||||
);
|
||||
let layer_norm_config = nn::LayerNormConfig {
|
||||
eps: 1e-12,
|
||||
..Default::default()
|
||||
};
|
||||
let vocab_layer_norm =
|
||||
nn::layer_norm(p / "vocab_layer_norm", vec![config.dim], layer_norm_config);
|
||||
let vocab_projector = nn::linear(
|
||||
&(p / "vocab_projector"),
|
||||
config.dim,
|
||||
config.vocab_size,
|
||||
Default::default(),
|
||||
);
|
||||
|
||||
DistilBertModelMaskedLM { distil_bert_model, vocab_transform, vocab_layer_norm, vocab_projector }
|
||||
DistilBertModelMaskedLM {
|
||||
distil_bert_model,
|
||||
vocab_transform,
|
||||
vocab_layer_norm,
|
||||
vocab_projector,
|
||||
}
|
||||
}
|
||||
|
||||
/// Forward pass through the model
|
||||
@ -391,19 +467,24 @@ impl DistilBertModelMaskedLM {
|
||||
///
|
||||
/// let (output, all_hidden_states, all_attentions) = no_grad(|| {
|
||||
/// distilbert_model
|
||||
/// .forward_t(Some(input_tensor),
|
||||
/// Some(mask),
|
||||
/// None,
|
||||
/// false).unwrap()
|
||||
/// .forward_t(Some(input_tensor), Some(mask), None, false)
|
||||
/// .unwrap()
|
||||
/// });
|
||||
///
|
||||
/// ```
|
||||
///
|
||||
pub fn forward_t(&self, input: Option<Tensor>, mask: Option<Tensor>, input_embeds: Option<Tensor>, train: bool)
|
||||
-> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
|
||||
let (output, all_hidden_states, all_attentions) = match self.distil_bert_model.forward_t(input, mask, input_embeds, train) {
|
||||
pub fn forward_t(
|
||||
&self,
|
||||
input: Option<Tensor>,
|
||||
mask: Option<Tensor>,
|
||||
input_embeds: Option<Tensor>,
|
||||
train: bool,
|
||||
) -> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
|
||||
let (output, all_hidden_states, all_attentions) =
|
||||
match self
|
||||
.distil_bert_model
|
||||
.forward_t(input, mask, input_embeds, train)
|
||||
{
|
||||
Ok(value) => value,
|
||||
Err(err) => return Err(err)
|
||||
Err(err) => return Err(err),
|
||||
};
|
||||
|
||||
let output = output
|
||||
@ -440,10 +521,10 @@ impl DistilBertForQuestionAnswering {
|
||||
/// # Example
|
||||
///
|
||||
/// ```no_run
|
||||
/// use tch::{nn, Device};
|
||||
/// use rust_bert::distilbert::{DistilBertConfig, DistilBertForQuestionAnswering};
|
||||
/// use rust_bert::Config;
|
||||
/// use std::path::Path;
|
||||
/// use rust_bert::distilbert::{DistilBertConfig, DistilBertForQuestionAnswering};
|
||||
/// use tch::{nn, Device};
|
||||
///
|
||||
/// let config_path = Path::new("path/to/config.json");
|
||||
/// let device = Device::Cpu;
|
||||
@ -451,13 +532,16 @@ impl DistilBertForQuestionAnswering {
|
||||
/// let config = DistilBertConfig::from_file(config_path);
|
||||
/// let distil_bert = DistilBertForQuestionAnswering::new(&(&p.root() / "distilbert"), &config);
|
||||
/// ```
|
||||
///
|
||||
pub fn new(p: &nn::Path, config: &DistilBertConfig) -> DistilBertForQuestionAnswering {
|
||||
let distil_bert_model = DistilBertModel::new(&p, config);
|
||||
let qa_outputs = nn::linear(&(p / "qa_outputs"), config.dim, 2, Default::default());
|
||||
let dropout = Dropout::new(config.qa_dropout);
|
||||
|
||||
DistilBertForQuestionAnswering { distil_bert_model, qa_outputs, dropout }
|
||||
DistilBertForQuestionAnswering {
|
||||
distil_bert_model,
|
||||
qa_outputs,
|
||||
dropout,
|
||||
}
|
||||
}
|
||||
|
||||
/// Forward pass through the model
|
||||
@ -496,35 +580,33 @@ impl DistilBertForQuestionAnswering {
|
||||
///
|
||||
/// let (start_scores, end_score, all_hidden_states, all_attentions) = no_grad(|| {
|
||||
/// distilbert_model
|
||||
/// .forward_t(Some(input_tensor),
|
||||
/// Some(mask),
|
||||
/// None,
|
||||
/// false).unwrap()
|
||||
/// .forward_t(Some(input_tensor), Some(mask), None, false)
|
||||
/// .unwrap()
|
||||
/// });
|
||||
///
|
||||
/// ```
|
||||
///
|
||||
pub fn forward_t(&self,
|
||||
pub fn forward_t(
|
||||
&self,
|
||||
input: Option<Tensor>,
|
||||
mask: Option<Tensor>,
|
||||
input_embeds: Option<Tensor>,
|
||||
train: bool)
|
||||
-> Result<(Tensor, Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
|
||||
let (output, all_hidden_states, all_attentions) = match self.distil_bert_model.forward_t(input, mask, input_embeds, train) {
|
||||
train: bool,
|
||||
) -> Result<(Tensor, Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
|
||||
let (output, all_hidden_states, all_attentions) =
|
||||
match self
|
||||
.distil_bert_model
|
||||
.forward_t(input, mask, input_embeds, train)
|
||||
{
|
||||
Ok(value) => value,
|
||||
Err(err) => return Err(err)
|
||||
Err(err) => return Err(err),
|
||||
};
|
||||
|
||||
let output = output
|
||||
.apply_t(&self.dropout, train)
|
||||
.apply(&self.qa_outputs);
|
||||
let output = output.apply_t(&self.dropout, train).apply(&self.qa_outputs);
|
||||
|
||||
let logits = output.split(1, -1);
|
||||
let (start_logits, end_logits) = (&logits[0], &logits[1]);
|
||||
let start_logits = start_logits.squeeze1(-1);
|
||||
let end_logits = end_logits.squeeze1(-1);
|
||||
|
||||
|
||||
Ok((start_logits, end_logits, all_hidden_states, all_attentions))
|
||||
}
|
||||
}
|
||||
@ -552,10 +634,10 @@ impl DistilBertForTokenClassification {
|
||||
/// # Example
|
||||
///
|
||||
/// ```no_run
|
||||
/// use tch::{nn, Device};
|
||||
/// use rust_bert::distilbert::{DistilBertConfig, DistilBertForTokenClassification};
|
||||
/// use rust_bert::Config;
|
||||
/// use std::path::Path;
|
||||
/// use rust_bert::distilbert::{DistilBertConfig, DistilBertForTokenClassification};
|
||||
/// use tch::{nn, Device};
|
||||
///
|
||||
/// let config_path = Path::new("path/to/config.json");
|
||||
/// let device = Device::Cpu;
|
||||
@ -563,16 +645,28 @@ impl DistilBertForTokenClassification {
|
||||
/// let config = DistilBertConfig::from_file(config_path);
|
||||
/// let distil_bert = DistilBertForTokenClassification::new(&(&p.root() / "distilbert"), &config);
|
||||
/// ```
|
||||
///
|
||||
pub fn new(p: &nn::Path, config: &DistilBertConfig) -> DistilBertForTokenClassification {
|
||||
let distil_bert_model = DistilBertModel::new(&p, config);
|
||||
|
||||
let num_labels = config.id2label.as_ref().expect("id2label must be provided for classifiers").len() as i64;
|
||||
let num_labels = config
|
||||
.id2label
|
||||
.as_ref()
|
||||
.expect("id2label must be provided for classifiers")
|
||||
.len() as i64;
|
||||
|
||||
let classifier = nn::linear(&(p / "classifier"), config.dim, num_labels, Default::default());
|
||||
let classifier = nn::linear(
|
||||
&(p / "classifier"),
|
||||
config.dim,
|
||||
num_labels,
|
||||
Default::default(),
|
||||
);
|
||||
let dropout = Dropout::new(config.seq_classif_dropout);
|
||||
|
||||
DistilBertForTokenClassification { distil_bert_model, classifier, dropout }
|
||||
DistilBertForTokenClassification {
|
||||
distil_bert_model,
|
||||
classifier,
|
||||
dropout,
|
||||
}
|
||||
}
|
||||
|
||||
/// Forward pass through the model
|
||||
@ -610,24 +704,27 @@ impl DistilBertForTokenClassification {
|
||||
///
|
||||
/// let (output, all_hidden_states, all_attentions) = no_grad(|| {
|
||||
/// distilbert_model
|
||||
/// .forward_t(Some(input_tensor),
|
||||
/// Some(mask),
|
||||
/// None,
|
||||
/// false).unwrap()
|
||||
/// .forward_t(Some(input_tensor), Some(mask), None, false)
|
||||
/// .unwrap()
|
||||
/// });
|
||||
///
|
||||
/// ```
|
||||
///
|
||||
pub fn forward_t(&self, input: Option<Tensor>, mask: Option<Tensor>, input_embeds: Option<Tensor>, train: bool)
|
||||
-> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
|
||||
let (output, all_hidden_states, all_attentions) = match self.distil_bert_model.forward_t(input, mask, input_embeds, train) {
|
||||
pub fn forward_t(
|
||||
&self,
|
||||
input: Option<Tensor>,
|
||||
mask: Option<Tensor>,
|
||||
input_embeds: Option<Tensor>,
|
||||
train: bool,
|
||||
) -> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
|
||||
let (output, all_hidden_states, all_attentions) =
|
||||
match self
|
||||
.distil_bert_model
|
||||
.forward_t(input, mask, input_embeds, train)
|
||||
{
|
||||
Ok(value) => value,
|
||||
Err(err) => return Err(err)
|
||||
Err(err) => return Err(err),
|
||||
};
|
||||
|
||||
let output = output
|
||||
.apply_t(&self.dropout, train)
|
||||
.apply(&self.classifier);
|
||||
let output = output.apply_t(&self.dropout, train).apply(&self.classifier);
|
||||
|
||||
Ok((output, all_hidden_states, all_attentions))
|
||||
}
|
||||
|
@ -10,22 +10,26 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use tch::{nn, Tensor, Kind, Device};
|
||||
use tch::nn::{ModuleT, embedding, EmbeddingConfig};
|
||||
use crate::distilbert::distilbert::DistilBertConfig;
|
||||
use crate::common::dropout::Dropout;
|
||||
use crate::distilbert::distilbert::DistilBertConfig;
|
||||
use tch::kind::Kind::Float;
|
||||
|
||||
use tch::nn::{embedding, EmbeddingConfig, ModuleT};
|
||||
use tch::{nn, Device, Kind, Tensor};
|
||||
|
||||
fn create_sinusoidal_embeddings(config: &DistilBertConfig, device: Device) -> nn::Embedding {
|
||||
let mut sinusoidal_embedding: Vec<Tensor> = Vec::with_capacity(config.max_position_embeddings as usize);
|
||||
let mut sinusoidal_embedding: Vec<Tensor> =
|
||||
Vec::with_capacity(config.max_position_embeddings as usize);
|
||||
for pos in 0..config.max_position_embeddings {
|
||||
let mut temp_vec: Vec<f64> = Vec::with_capacity(config.dim as usize);
|
||||
for j in 0..config.dim {
|
||||
if j % 2 == 0 {
|
||||
temp_vec.push((pos as f64 / 10000f64.powf((2 * (j / 2)) as f64 / config.dim as f64)).sin());
|
||||
temp_vec.push(
|
||||
(pos as f64 / 10000f64.powf((2 * (j / 2)) as f64 / config.dim as f64)).sin(),
|
||||
);
|
||||
} else {
|
||||
temp_vec.push((pos as f64 / 10000f64.powf((2 * (j / 2)) as f64 / config.dim as f64)).cos());
|
||||
temp_vec.push(
|
||||
(pos as f64 / 10000f64.powf((2 * (j / 2)) as f64 / config.dim as f64)).cos(),
|
||||
);
|
||||
}
|
||||
}
|
||||
let temp_vec = Tensor::of_slice(&temp_vec);
|
||||
@ -35,17 +39,21 @@ fn create_sinusoidal_embeddings(config: &DistilBertConfig, device: Device) -> nn
|
||||
.to_kind(Float)
|
||||
.to_device(device);
|
||||
|
||||
let embedding_config = EmbeddingConfig { padding_idx: 0, ..Default::default() };
|
||||
let mut embeddings = embedding(&nn::VarStore::new(device).root(),
|
||||
let embedding_config = EmbeddingConfig {
|
||||
padding_idx: 0,
|
||||
..Default::default()
|
||||
};
|
||||
let mut embeddings = embedding(
|
||||
&nn::VarStore::new(device).root(),
|
||||
config.max_position_embeddings,
|
||||
config.dim,
|
||||
embedding_config);
|
||||
embedding_config,
|
||||
);
|
||||
|
||||
embeddings.ws = sinusoidal_embedding;
|
||||
embeddings
|
||||
}
|
||||
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct DistilBertEmbedding {
|
||||
word_embeddings: nn::Embedding,
|
||||
@ -56,24 +64,40 @@ pub struct DistilBertEmbedding {
|
||||
|
||||
impl DistilBertEmbedding {
|
||||
pub fn new(p: &nn::Path, config: &DistilBertConfig) -> DistilBertEmbedding {
|
||||
let embedding_config = EmbeddingConfig { padding_idx: 0, ..Default::default() };
|
||||
let embedding_config = EmbeddingConfig {
|
||||
padding_idx: 0,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let word_embeddings: nn::Embedding = embedding(p / "word_embeddings",
|
||||
let word_embeddings: nn::Embedding = embedding(
|
||||
p / "word_embeddings",
|
||||
config.vocab_size,
|
||||
config.dim,
|
||||
embedding_config);
|
||||
embedding_config,
|
||||
);
|
||||
let position_embeddings: nn::Embedding = match config.sinusoidal_pos_embds {
|
||||
false => embedding(p / "position_embeddings",
|
||||
false => embedding(
|
||||
p / "position_embeddings",
|
||||
config.max_position_embeddings,
|
||||
config.dim,
|
||||
embedding_config),
|
||||
embedding_config,
|
||||
),
|
||||
|
||||
true => create_sinusoidal_embeddings(&config, p.device())
|
||||
true => create_sinusoidal_embeddings(&config, p.device()),
|
||||
};
|
||||
let layer_norm_config = nn::LayerNormConfig { eps: 1e-12, ..Default::default() };
|
||||
let layer_norm: nn::LayerNorm = nn::layer_norm(p / "LayerNorm", vec![config.dim], layer_norm_config);
|
||||
let layer_norm_config = nn::LayerNormConfig {
|
||||
eps: 1e-12,
|
||||
..Default::default()
|
||||
};
|
||||
let layer_norm: nn::LayerNorm =
|
||||
nn::layer_norm(p / "LayerNorm", vec![config.dim], layer_norm_config);
|
||||
let dropout: Dropout = Dropout::new(config.dropout);
|
||||
DistilBertEmbedding { word_embeddings, position_embeddings, layer_norm, dropout }
|
||||
DistilBertEmbedding {
|
||||
word_embeddings,
|
||||
position_embeddings,
|
||||
layer_norm,
|
||||
dropout,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn _get_word_embeddings(&self) -> &nn::Embedding {
|
||||
@ -96,7 +120,9 @@ impl ModuleT for DistilBertEmbedding {
|
||||
|
||||
// position_embed.get(0).get(0).print();
|
||||
let embeddings = word_embed + position_embed;
|
||||
let embeddings = embeddings.apply(&self.layer_norm).apply_t(&self.dropout, train);
|
||||
let embeddings = embeddings
|
||||
.apply(&self.layer_norm)
|
||||
.apply_t(&self.dropout, train);
|
||||
|
||||
embeddings
|
||||
}
|
||||
|
@ -23,13 +23,22 @@
|
||||
//! use rust_tokenizers::BertTokenizer;
|
||||
//! use tch::{nn, Device};
|
||||
//! # use std::path::PathBuf;
|
||||
//! use rust_bert::distilbert::{
|
||||
//! DistilBertConfig, DistilBertConfigResources, DistilBertModelMaskedLM,
|
||||
//! DistilBertModelResources, DistilBertVocabResources,
|
||||
//! };
|
||||
//! use rust_bert::resources::{download_resource, LocalResource, RemoteResource, Resource};
|
||||
//! use rust_bert::Config;
|
||||
//! use rust_bert::distilbert::{DistilBertModelMaskedLM, DistilBertConfig, DistilBertConfigResources, DistilBertVocabResources, DistilBertModelResources};
|
||||
//! use rust_bert::resources::{Resource, download_resource, RemoteResource, LocalResource};
|
||||
//!
|
||||
//! let config_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/config.json")});
|
||||
//! let vocab_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/vocab.txt")});
|
||||
//! let weights_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/model.ot")});
|
||||
//! let config_resource = Resource::Local(LocalResource {
|
||||
//! local_path: PathBuf::from("path/to/config.json"),
|
||||
//! });
|
||||
//! let vocab_resource = Resource::Local(LocalResource {
|
||||
//! local_path: PathBuf::from("path/to/vocab.txt"),
|
||||
//! });
|
||||
//! let weights_resource = Resource::Local(LocalResource {
|
||||
//! local_path: PathBuf::from("path/to/model.ot"),
|
||||
//! });
|
||||
//! let config_path = download_resource(&config_resource)?;
|
||||
//! let vocab_path = download_resource(&vocab_resource)?;
|
||||
//! let weights_path = download_resource(&weights_resource)?;
|
||||
@ -44,13 +53,13 @@
|
||||
//! # }
|
||||
//! ```
|
||||
|
||||
|
||||
|
||||
mod attention;
|
||||
mod distilbert;
|
||||
mod embeddings;
|
||||
mod attention;
|
||||
mod transformer;
|
||||
|
||||
pub use distilbert::{DistilBertModelResources, DistilBertConfigResources, DistilBertVocabResources,
|
||||
DistilBertConfig, Activation, DistilBertModel, DistilBertForQuestionAnswering, DistilBertForTokenClassification,
|
||||
DistilBertModelMaskedLM, DistilBertModelClassifier};
|
||||
pub use distilbert::{
|
||||
Activation, DistilBertConfig, DistilBertConfigResources, DistilBertForQuestionAnswering,
|
||||
DistilBertForTokenClassification, DistilBertModel, DistilBertModelClassifier,
|
||||
DistilBertModelMaskedLM, DistilBertModelResources, DistilBertVocabResources,
|
||||
};
|
||||
|
@ -10,13 +10,13 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use tch::{Tensor, nn};
|
||||
use crate::distilbert::distilbert::{DistilBertConfig, Activation};
|
||||
use crate::distilbert::attention::MultiHeadSelfAttention;
|
||||
use tch::nn::LayerNorm;
|
||||
use std::borrow::BorrowMut;
|
||||
use crate::common::dropout::Dropout;
|
||||
use crate::common::activations::{_gelu, _relu};
|
||||
use crate::common::dropout::Dropout;
|
||||
use crate::distilbert::attention::MultiHeadSelfAttention;
|
||||
use crate::distilbert::distilbert::{Activation, DistilBertConfig};
|
||||
use std::borrow::BorrowMut;
|
||||
use tch::nn::LayerNorm;
|
||||
use tch::{nn, Tensor};
|
||||
|
||||
pub struct FeedForwardNetwork {
|
||||
lin1: nn::Linear,
|
||||
@ -27,18 +27,35 @@ pub struct FeedForwardNetwork {
|
||||
|
||||
impl FeedForwardNetwork {
|
||||
pub fn new(p: nn::Path, config: &DistilBertConfig) -> FeedForwardNetwork {
|
||||
let lin1 = nn::linear(&p / "lin1", config.dim, config.hidden_dim, Default::default());
|
||||
let lin2 = nn::linear(&p / "lin2", config.hidden_dim, config.dim, Default::default());
|
||||
let lin1 = nn::linear(
|
||||
&p / "lin1",
|
||||
config.dim,
|
||||
config.hidden_dim,
|
||||
Default::default(),
|
||||
);
|
||||
let lin2 = nn::linear(
|
||||
&p / "lin2",
|
||||
config.hidden_dim,
|
||||
config.dim,
|
||||
Default::default(),
|
||||
);
|
||||
let dropout = Dropout::new(config.dropout);
|
||||
let activation = Box::new(match &config.activation {
|
||||
Activation::gelu => _gelu,
|
||||
Activation::relu => _relu
|
||||
Activation::relu => _relu,
|
||||
});
|
||||
FeedForwardNetwork { lin1, lin2, dropout, activation }
|
||||
FeedForwardNetwork {
|
||||
lin1,
|
||||
lin2,
|
||||
dropout,
|
||||
activation,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn forward_t(&self, input: &Tensor, train: bool) -> Tensor {
|
||||
(self.activation)(&input.apply(&self.lin1)).apply(&self.lin2).apply_t(&self.dropout, train)
|
||||
(self.activation)(&input.apply(&self.lin1))
|
||||
.apply(&self.lin2)
|
||||
.apply_t(&self.dropout, train)
|
||||
}
|
||||
}
|
||||
|
||||
@ -52,10 +69,15 @@ pub struct TransformerBlock {
|
||||
impl TransformerBlock {
|
||||
pub fn new(p: &nn::Path, config: &DistilBertConfig) -> TransformerBlock {
|
||||
let attention = MultiHeadSelfAttention::new(p / "attention", &config);
|
||||
let layer_norm_config = nn::LayerNormConfig { eps: 1e-12, ..Default::default() };
|
||||
let sa_layer_norm = nn::layer_norm(p / "sa_layer_norm", vec![config.dim], layer_norm_config);
|
||||
let layer_norm_config = nn::LayerNormConfig {
|
||||
eps: 1e-12,
|
||||
..Default::default()
|
||||
};
|
||||
let sa_layer_norm =
|
||||
nn::layer_norm(p / "sa_layer_norm", vec![config.dim], layer_norm_config);
|
||||
let ffn = FeedForwardNetwork::new(p / "ffn", &config);
|
||||
let output_layer_norm = nn::layer_norm(p / "output_layer_norm", vec![config.dim], layer_norm_config);
|
||||
let output_layer_norm =
|
||||
nn::layer_norm(p / "output_layer_norm", vec![config.dim], layer_norm_config);
|
||||
|
||||
TransformerBlock {
|
||||
attention,
|
||||
@ -65,8 +87,15 @@ impl TransformerBlock {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn forward_t(&self, input: &Tensor, mask: &Option<Tensor>, train: bool) -> (Tensor, Option<Tensor>) {
|
||||
let (output, sa_weights) = self.attention.forward_t(&input, &input, &input, mask, train);
|
||||
pub fn forward_t(
|
||||
&self,
|
||||
input: &Tensor,
|
||||
mask: &Option<Tensor>,
|
||||
train: bool,
|
||||
) -> (Tensor, Option<Tensor>) {
|
||||
let (output, sa_weights) = self
|
||||
.attention
|
||||
.forward_t(&input, &input, &input, mask, train);
|
||||
let output = (input + &output).apply(&self.sa_layer_norm);
|
||||
let output = (&output + self.ffn.forward_t(&output, train)).apply(&self.output_layer_norm);
|
||||
(output, sa_weights)
|
||||
@ -84,25 +113,41 @@ impl Transformer {
|
||||
let p = &(p / "layer");
|
||||
let output_attentions = match config.output_attentions {
|
||||
Some(value) => value,
|
||||
None => false
|
||||
None => false,
|
||||
};
|
||||
let output_hidden_states = match config.output_hidden_states {
|
||||
Some(value) => value,
|
||||
None => false
|
||||
None => false,
|
||||
};
|
||||
|
||||
let mut layers: Vec<TransformerBlock> = vec!();
|
||||
let mut layers: Vec<TransformerBlock> = vec![];
|
||||
for layer_index in 0..config.n_layers {
|
||||
layers.push(TransformerBlock::new(&(p / layer_index), config));
|
||||
};
|
||||
|
||||
Transformer { output_attentions, output_hidden_states, layers }
|
||||
}
|
||||
|
||||
pub fn forward_t(&self, input: &Tensor, mask: Option<Tensor>, train: bool)
|
||||
-> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
|
||||
let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states { Some(vec!()) } else { None };
|
||||
let mut all_attentions: Option<Vec<Tensor>> = if self.output_attentions { Some(vec!()) } else { None };
|
||||
Transformer {
|
||||
output_attentions,
|
||||
output_hidden_states,
|
||||
layers,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn forward_t(
|
||||
&self,
|
||||
input: &Tensor,
|
||||
mask: Option<Tensor>,
|
||||
train: bool,
|
||||
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
|
||||
let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states {
|
||||
Some(vec![])
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let mut all_attentions: Option<Vec<Tensor>> = if self.output_attentions {
|
||||
Some(vec![])
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let mut hidden_state = input.copy();
|
||||
let mut attention_weights: Option<Tensor>;
|
||||
@ -121,9 +166,9 @@ impl Transformer {
|
||||
attentions.push(attention_weights.as_ref().unwrap().copy());
|
||||
};
|
||||
}
|
||||
None => break
|
||||
};
|
||||
None => break,
|
||||
};
|
||||
}
|
||||
|
||||
(hidden_state, all_hidden_states, all_attentions)
|
||||
}
|
||||
|
@ -12,15 +12,15 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use crate::bert::encoder::BertEncoder;
|
||||
use crate::bert::{Activation, BertConfig};
|
||||
use crate::common::activations::{_gelu, _mish, _relu};
|
||||
use crate::common::dropout::Dropout;
|
||||
use crate::electra::embeddings::ElectraEmbeddings;
|
||||
use crate::Config;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use crate::bert::{Activation, BertConfig};
|
||||
use crate::Config;
|
||||
use crate::electra::embeddings::ElectraEmbeddings;
|
||||
use tch::{nn, Tensor, Kind};
|
||||
use crate::bert::encoder::BertEncoder;
|
||||
use crate::common::activations::{_gelu, _relu, _mish};
|
||||
use crate::common::dropout::Dropout;
|
||||
use tch::{nn, Kind, Tensor};
|
||||
|
||||
/// # Electra Pretrained model weight files
|
||||
pub struct ElectraModelResources;
|
||||
@ -33,23 +33,41 @@ pub struct ElectraVocabResources;
|
||||
|
||||
impl ElectraModelResources {
|
||||
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/electra. Modified with conversion to C-array format.
|
||||
pub const BASE_GENERATOR: (&'static str, &'static str) = ("electra-base-generator/model.ot", "https://cdn.huggingface.co/google/electra-base-generator/rust_model.ot");
|
||||
pub const BASE_GENERATOR: (&'static str, &'static str) = (
|
||||
"electra-base-generator/model.ot",
|
||||
"https://cdn.huggingface.co/google/electra-base-generator/rust_model.ot",
|
||||
);
|
||||
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/electra. Modified with conversion to C-array format.
|
||||
pub const BASE_DISCRIMINATOR: (&'static str, &'static str) = ("electra-base-discriminator/model.ot", "https://cdn.huggingface.co/google/electra-base-discriminator/rust_model.ot");
|
||||
pub const BASE_DISCRIMINATOR: (&'static str, &'static str) = (
|
||||
"electra-base-discriminator/model.ot",
|
||||
"https://cdn.huggingface.co/google/electra-base-discriminator/rust_model.ot",
|
||||
);
|
||||
}
|
||||
|
||||
impl ElectraConfigResources {
|
||||
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/electra. Modified with conversion to C-array format.
|
||||
pub const BASE_GENERATOR: (&'static str, &'static str) = ("electra-base-generator/config.json", "https://cdn.huggingface.co/google/electra-base-generator/config.json");
|
||||
pub const BASE_GENERATOR: (&'static str, &'static str) = (
|
||||
"electra-base-generator/config.json",
|
||||
"https://cdn.huggingface.co/google/electra-base-generator/config.json",
|
||||
);
|
||||
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/electra. Modified with conversion to C-array format.
|
||||
pub const BASE_DISCRIMINATOR: (&'static str, &'static str) = ("electra-base-discriminator/config.json", "https://cdn.huggingface.co/google/electra-base-discriminator/config.json");
|
||||
pub const BASE_DISCRIMINATOR: (&'static str, &'static str) = (
|
||||
"electra-base-discriminator/config.json",
|
||||
"https://cdn.huggingface.co/google/electra-base-discriminator/config.json",
|
||||
);
|
||||
}
|
||||
|
||||
impl ElectraVocabResources {
|
||||
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/electra. Modified with conversion to C-array format.
|
||||
pub const BASE_GENERATOR: (&'static str, &'static str) = ("electra-base-generator/vocab.txt", "https://cdn.huggingface.co/google/electra-base-generator/vocab.txt");
|
||||
pub const BASE_GENERATOR: (&'static str, &'static str) = (
|
||||
"electra-base-generator/vocab.txt",
|
||||
"https://cdn.huggingface.co/google/electra-base-generator/vocab.txt",
|
||||
);
|
||||
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/electra. Modified with conversion to C-array format.
|
||||
pub const BASE_DISCRIMINATOR: (&'static str, &'static str) = ("electra-base-discriminator/vocab.txt", "https://cdn.huggingface.co/google/electra-base-discriminator/vocab.txt");
|
||||
pub const BASE_DISCRIMINATOR: (&'static str, &'static str) = (
|
||||
"electra-base-discriminator/vocab.txt",
|
||||
"https://cdn.huggingface.co/google/electra-base-discriminator/vocab.txt",
|
||||
);
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
@ -103,10 +121,10 @@ impl ElectraModel {
|
||||
/// # Example
|
||||
///
|
||||
/// ```no_run
|
||||
/// use rust_bert::electra::{ElectraModel, ElectraConfig};
|
||||
/// use tch::{nn, Device};
|
||||
/// use rust_bert::electra::{ElectraConfig, ElectraModel};
|
||||
/// use rust_bert::Config;
|
||||
/// use std::path::Path;
|
||||
/// use tch::{nn, Device};
|
||||
///
|
||||
/// let config_path = Path::new("path/to/config.json");
|
||||
/// let device = Device::Cpu;
|
||||
@ -114,11 +132,15 @@ impl ElectraModel {
|
||||
/// let config = ElectraConfig::from_file(config_path);
|
||||
/// let electra_model: ElectraModel = ElectraModel::new(&(&p.root() / "electra"), &config);
|
||||
/// ```
|
||||
///
|
||||
pub fn new(p: &nn::Path, config: &ElectraConfig) -> ElectraModel {
|
||||
let embeddings = ElectraEmbeddings::new(&(p / "embeddings"), config);
|
||||
let embeddings_project = if config.embedding_size != config.hidden_size {
|
||||
Some(nn::linear(&(p / "embeddings_project"), config.embedding_size, config.hidden_size, Default::default()))
|
||||
Some(nn::linear(
|
||||
&(p / "embeddings_project"),
|
||||
config.embedding_size,
|
||||
config.hidden_size,
|
||||
Default::default(),
|
||||
))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
@ -141,7 +163,11 @@ impl ElectraModel {
|
||||
label2id: config.label2id.clone(),
|
||||
};
|
||||
let encoder = BertEncoder::new(&(p / "encoder"), &bert_config);
|
||||
ElectraModel { embeddings, embeddings_project, encoder }
|
||||
ElectraModel {
|
||||
embeddings,
|
||||
embeddings_project,
|
||||
encoder,
|
||||
}
|
||||
}
|
||||
|
||||
/// Forward pass through the model
|
||||
@ -178,66 +204,84 @@ impl ElectraModel {
|
||||
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
|
||||
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
||||
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
||||
/// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true);
|
||||
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
|
||||
/// .expand(&[batch_size, sequence_length], true);
|
||||
///
|
||||
/// let (output, all_hidden_states, all_attentions) = no_grad(|| {
|
||||
/// electra_model
|
||||
/// .forward_t(Some(input_tensor),
|
||||
/// .forward_t(
|
||||
/// Some(input_tensor),
|
||||
/// Some(mask),
|
||||
/// Some(token_type_ids),
|
||||
/// Some(position_ids),
|
||||
/// None,
|
||||
/// false).unwrap()
|
||||
/// false,
|
||||
/// )
|
||||
/// .unwrap()
|
||||
/// });
|
||||
///
|
||||
/// ```
|
||||
///
|
||||
pub fn forward_t(&self,
|
||||
pub fn forward_t(
|
||||
&self,
|
||||
input_ids: Option<Tensor>,
|
||||
mask: Option<Tensor>,
|
||||
token_type_ids: Option<Tensor>,
|
||||
position_ids: Option<Tensor>,
|
||||
input_embeds: Option<Tensor>,
|
||||
train: bool)
|
||||
-> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
|
||||
train: bool,
|
||||
) -> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
|
||||
let (input_shape, device) = match &input_ids {
|
||||
Some(input_value) => match &input_embeds {
|
||||
Some(_) => { return Err("Only one of input ids or input embeddings may be set"); }
|
||||
None => (input_value.size(), input_value.device())
|
||||
Some(_) => {
|
||||
return Err("Only one of input ids or input embeddings may be set");
|
||||
}
|
||||
None => (input_value.size(), input_value.device()),
|
||||
},
|
||||
None => match &input_embeds {
|
||||
Some(embeds) => (vec!(embeds.size()[0], embeds.size()[1]), embeds.device()),
|
||||
None => { return Err("At least one of input ids or input embeddings must be set"); }
|
||||
Some(embeds) => (vec![embeds.size()[0], embeds.size()[1]], embeds.device()),
|
||||
None => {
|
||||
return Err("At least one of input ids or input embeddings must be set");
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
let mask = match mask {
|
||||
Some(value) => value,
|
||||
None => Tensor::ones(&input_shape, (Kind::Int64, device))
|
||||
None => Tensor::ones(&input_shape, (Kind::Int64, device)),
|
||||
};
|
||||
|
||||
let extended_attention_mask = match mask.dim() {
|
||||
3 => mask.unsqueeze(1),
|
||||
2 => mask.unsqueeze(1).unsqueeze(1),
|
||||
_ => { return Err("Invalid attention mask dimension, must be 2 or 3"); }
|
||||
_ => {
|
||||
return Err("Invalid attention mask dimension, must be 2 or 3");
|
||||
}
|
||||
};
|
||||
|
||||
let hidden_states = match self.embeddings.forward_t(input_ids, token_type_ids, position_ids, input_embeds, train) {
|
||||
let hidden_states = match self.embeddings.forward_t(
|
||||
input_ids,
|
||||
token_type_ids,
|
||||
position_ids,
|
||||
input_embeds,
|
||||
train,
|
||||
) {
|
||||
Ok(value) => value,
|
||||
Err(e) => { return Err(e); }
|
||||
Err(e) => {
|
||||
return Err(e);
|
||||
}
|
||||
};
|
||||
|
||||
let hidden_states = match &self.embeddings_project {
|
||||
Some(layer) => hidden_states.apply(layer),
|
||||
None => hidden_states
|
||||
None => hidden_states,
|
||||
};
|
||||
|
||||
let (hidden_state, all_hidden_states, all_attentions) =
|
||||
self.encoder.forward_t(&hidden_states,
|
||||
let (hidden_state, all_hidden_states, all_attentions) = self.encoder.forward_t(
|
||||
&hidden_states,
|
||||
&Some(extended_attention_mask),
|
||||
&None,
|
||||
&None,
|
||||
train);
|
||||
train,
|
||||
);
|
||||
|
||||
Ok((hidden_state, all_hidden_states, all_attentions))
|
||||
}
|
||||
@ -268,9 +312,9 @@ impl ElectraDiscriminatorHead {
|
||||
///
|
||||
/// ```no_run
|
||||
/// use rust_bert::electra::{ElectraConfig, ElectraDiscriminatorHead};
|
||||
/// use tch::{nn, Device};
|
||||
/// use rust_bert::Config;
|
||||
/// use std::path::Path;
|
||||
/// use tch::{nn, Device};
|
||||
///
|
||||
/// let config_path = Path::new("path/to/config.json");
|
||||
/// let device = Device::Cpu;
|
||||
@ -278,16 +322,29 @@ impl ElectraDiscriminatorHead {
|
||||
/// let config = ElectraConfig::from_file(config_path);
|
||||
/// let discriminator_head = ElectraDiscriminatorHead::new(&(&p.root() / "electra"), &config);
|
||||
/// ```
|
||||
///
|
||||
pub fn new(p: &nn::Path, config: &ElectraConfig) -> ElectraDiscriminatorHead {
|
||||
let dense = nn::linear(&(p / "dense"), config.hidden_size, config.hidden_size, Default::default());
|
||||
let dense_prediction = nn::linear(&(p / "dense_prediction"), config.hidden_size, 1, Default::default());
|
||||
let dense = nn::linear(
|
||||
&(p / "dense"),
|
||||
config.hidden_size,
|
||||
config.hidden_size,
|
||||
Default::default(),
|
||||
);
|
||||
let dense_prediction = nn::linear(
|
||||
&(p / "dense_prediction"),
|
||||
config.hidden_size,
|
||||
1,
|
||||
Default::default(),
|
||||
);
|
||||
let activation = Box::new(match &config.hidden_act {
|
||||
Activation::gelu => _gelu,
|
||||
Activation::relu => _relu,
|
||||
Activation::mish => _mish
|
||||
Activation::mish => _mish,
|
||||
});
|
||||
ElectraDiscriminatorHead { dense, dense_prediction, activation }
|
||||
ElectraDiscriminatorHead {
|
||||
dense,
|
||||
dense_prediction,
|
||||
activation,
|
||||
}
|
||||
}
|
||||
|
||||
/// Forward pass through the discriminator head
|
||||
@ -316,14 +373,13 @@ impl ElectraDiscriminatorHead {
|
||||
/// # let config = ElectraConfig::from_file(config_path);
|
||||
/// # let discriminator_head = ElectraDiscriminatorHead::new(&vs.root(), &config);
|
||||
/// let (batch_size, sequence_length) = (64, 128);
|
||||
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length, config.hidden_size], (Float, device));
|
||||
///
|
||||
/// let output = no_grad(|| {
|
||||
/// discriminator_head.forward(&input_tensor)
|
||||
/// });
|
||||
/// let input_tensor = Tensor::rand(
|
||||
/// &[batch_size, sequence_length, config.hidden_size],
|
||||
/// (Float, device),
|
||||
/// );
|
||||
///
|
||||
/// let output = no_grad(|| discriminator_head.forward(&input_tensor));
|
||||
/// ```
|
||||
///
|
||||
pub fn forward(&self, encoder_hidden_states: &Tensor) -> Tensor {
|
||||
let output = encoder_hidden_states.apply(&self.dense);
|
||||
let output = (self.activation)(&output);
|
||||
@ -356,9 +412,9 @@ impl ElectraGeneratorHead {
|
||||
///
|
||||
/// ```no_run
|
||||
/// use rust_bert::electra::{ElectraConfig, ElectraGeneratorHead};
|
||||
/// use tch::{nn, Device};
|
||||
/// use rust_bert::Config;
|
||||
/// use std::path::Path;
|
||||
/// use tch::{nn, Device};
|
||||
///
|
||||
/// let config_path = Path::new("path/to/config.json");
|
||||
/// let device = Device::Cpu;
|
||||
@ -366,13 +422,25 @@ impl ElectraGeneratorHead {
|
||||
/// let config = ElectraConfig::from_file(config_path);
|
||||
/// let generator_head = ElectraGeneratorHead::new(&(&p.root() / "electra"), &config);
|
||||
/// ```
|
||||
///
|
||||
pub fn new(p: &nn::Path, config: &ElectraConfig) -> ElectraGeneratorHead {
|
||||
let layer_norm = nn::layer_norm(p / "LayerNorm", vec![config.embedding_size], Default::default());
|
||||
let dense = nn::linear(&(p / "dense"), config.hidden_size, config.embedding_size, Default::default());
|
||||
let layer_norm = nn::layer_norm(
|
||||
p / "LayerNorm",
|
||||
vec![config.embedding_size],
|
||||
Default::default(),
|
||||
);
|
||||
let dense = nn::linear(
|
||||
&(p / "dense"),
|
||||
config.hidden_size,
|
||||
config.embedding_size,
|
||||
Default::default(),
|
||||
);
|
||||
let activation = Box::new(_gelu);
|
||||
|
||||
ElectraGeneratorHead { layer_norm, dense, activation }
|
||||
ElectraGeneratorHead {
|
||||
layer_norm,
|
||||
dense,
|
||||
activation,
|
||||
}
|
||||
}
|
||||
|
||||
/// Forward pass through the generator head
|
||||
@ -399,14 +467,13 @@ impl ElectraGeneratorHead {
|
||||
/// # let config = ElectraConfig::from_file(config_path);
|
||||
/// # let generator_head = ElectraGeneratorHead::new(&vs.root(), &config);
|
||||
/// let (batch_size, sequence_length) = (64, 128);
|
||||
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length, config.hidden_size], (Float, device));
|
||||
///
|
||||
/// let output = no_grad(|| {
|
||||
/// generator_head.forward(&input_tensor)
|
||||
/// });
|
||||
/// let input_tensor = Tensor::rand(
|
||||
/// &[batch_size, sequence_length, config.hidden_size],
|
||||
/// (Float, device),
|
||||
/// );
|
||||
///
|
||||
/// let output = no_grad(|| generator_head.forward(&input_tensor));
|
||||
/// ```
|
||||
///
|
||||
pub fn forward(&self, encoder_hidden_states: &Tensor) -> Tensor {
|
||||
let output = encoder_hidden_states.apply(&self.dense);
|
||||
let output = (self.activation)(&output);
|
||||
@ -438,10 +505,10 @@ impl ElectraForMaskedLM {
|
||||
/// # Example
|
||||
///
|
||||
/// ```no_run
|
||||
/// use rust_bert::electra::{ElectraForMaskedLM, ElectraConfig};
|
||||
/// use tch::{nn, Device};
|
||||
/// use rust_bert::electra::{ElectraConfig, ElectraForMaskedLM};
|
||||
/// use rust_bert::Config;
|
||||
/// use std::path::Path;
|
||||
/// use tch::{nn, Device};
|
||||
///
|
||||
/// let config_path = Path::new("path/to/config.json");
|
||||
/// let device = Device::Cpu;
|
||||
@ -449,13 +516,21 @@ impl ElectraForMaskedLM {
|
||||
/// let config = ElectraConfig::from_file(config_path);
|
||||
/// let electra_model: ElectraForMaskedLM = ElectraForMaskedLM::new(&p.root(), &config);
|
||||
/// ```
|
||||
///
|
||||
pub fn new(p: &nn::Path, config: &ElectraConfig) -> ElectraForMaskedLM {
|
||||
let electra = ElectraModel::new(&(p / "electra"), config);
|
||||
let generator_head = ElectraGeneratorHead::new(&(p / "generator_predictions"), config);
|
||||
let lm_head = nn::linear(&(p / "generator_lm_head"), config.embedding_size, config.vocab_size, Default::default());
|
||||
let lm_head = nn::linear(
|
||||
&(p / "generator_lm_head"),
|
||||
config.embedding_size,
|
||||
config.vocab_size,
|
||||
Default::default(),
|
||||
);
|
||||
|
||||
ElectraForMaskedLM { electra, generator_head, lm_head }
|
||||
ElectraForMaskedLM {
|
||||
electra,
|
||||
generator_head,
|
||||
lm_head,
|
||||
}
|
||||
}
|
||||
|
||||
/// Forward pass through the model
|
||||
@ -492,32 +567,39 @@ impl ElectraForMaskedLM {
|
||||
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
|
||||
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
||||
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
||||
/// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true);
|
||||
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
|
||||
/// .expand(&[batch_size, sequence_length], true);
|
||||
///
|
||||
/// let (output, all_hidden_states, all_attentions) = no_grad(|| {
|
||||
/// electra_model
|
||||
/// .forward_t(Some(input_tensor),
|
||||
/// electra_model.forward_t(
|
||||
/// Some(input_tensor),
|
||||
/// Some(mask),
|
||||
/// Some(token_type_ids),
|
||||
/// Some(position_ids),
|
||||
/// None,
|
||||
/// false)
|
||||
/// false,
|
||||
/// )
|
||||
/// });
|
||||
///
|
||||
/// ```
|
||||
///
|
||||
pub fn forward_t(&self,
|
||||
pub fn forward_t(
|
||||
&self,
|
||||
input_ids: Option<Tensor>,
|
||||
mask: Option<Tensor>,
|
||||
token_type_ids: Option<Tensor>,
|
||||
position_ids: Option<Tensor>,
|
||||
input_embeds: Option<Tensor>,
|
||||
train: bool)
|
||||
-> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
|
||||
let (hidden_states,
|
||||
all_hidden_states,
|
||||
all_attentions) = self.electra
|
||||
.forward_t(input_ids, mask, token_type_ids, position_ids, input_embeds, train)
|
||||
train: bool,
|
||||
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
|
||||
let (hidden_states, all_hidden_states, all_attentions) = self
|
||||
.electra
|
||||
.forward_t(
|
||||
input_ids,
|
||||
mask,
|
||||
token_type_ids,
|
||||
position_ids,
|
||||
input_embeds,
|
||||
train,
|
||||
)
|
||||
.unwrap();
|
||||
let hidden_states = self.generator_head.forward(&hidden_states);
|
||||
let hidden_states = hidden_states.apply(&self.lm_head);
|
||||
@ -547,10 +629,10 @@ impl ElectraDiscriminator {
|
||||
/// # Example
|
||||
///
|
||||
/// ```no_run
|
||||
/// use rust_bert::electra::{ElectraDiscriminator, ElectraConfig};
|
||||
/// use tch::{nn, Device};
|
||||
/// use rust_bert::electra::{ElectraConfig, ElectraDiscriminator};
|
||||
/// use rust_bert::Config;
|
||||
/// use std::path::Path;
|
||||
/// use tch::{nn, Device};
|
||||
///
|
||||
/// let config_path = Path::new("path/to/config.json");
|
||||
/// let device = Device::Cpu;
|
||||
@ -558,12 +640,15 @@ impl ElectraDiscriminator {
|
||||
/// let config = ElectraConfig::from_file(config_path);
|
||||
/// let electra_model: ElectraDiscriminator = ElectraDiscriminator::new(&p.root(), &config);
|
||||
/// ```
|
||||
///
|
||||
pub fn new(p: &nn::Path, config: &ElectraConfig) -> ElectraDiscriminator {
|
||||
let electra = ElectraModel::new(&(p / "electra"), config);
|
||||
let discriminator_head = ElectraDiscriminatorHead::new(&(p / "discriminator_predictions"), config);
|
||||
let discriminator_head =
|
||||
ElectraDiscriminatorHead::new(&(p / "discriminator_predictions"), config);
|
||||
|
||||
ElectraDiscriminator { electra, discriminator_head }
|
||||
ElectraDiscriminator {
|
||||
electra,
|
||||
discriminator_head,
|
||||
}
|
||||
}
|
||||
|
||||
/// Forward pass through the model
|
||||
@ -611,21 +696,26 @@ impl ElectraDiscriminator {
|
||||
/// None,
|
||||
/// false)
|
||||
/// });
|
||||
///
|
||||
/// ```
|
||||
///
|
||||
pub fn forward_t(&self,
|
||||
pub fn forward_t(
|
||||
&self,
|
||||
input_ids: Option<Tensor>,
|
||||
mask: Option<Tensor>,
|
||||
token_type_ids: Option<Tensor>,
|
||||
position_ids: Option<Tensor>,
|
||||
input_embeds: Option<Tensor>,
|
||||
train: bool)
|
||||
-> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
|
||||
let (hidden_states,
|
||||
all_hidden_states,
|
||||
all_attentions) = self.electra
|
||||
.forward_t(input_ids, mask, token_type_ids, position_ids, input_embeds, train)
|
||||
train: bool,
|
||||
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
|
||||
let (hidden_states, all_hidden_states, all_attentions) = self
|
||||
.electra
|
||||
.forward_t(
|
||||
input_ids,
|
||||
mask,
|
||||
token_type_ids,
|
||||
position_ids,
|
||||
input_embeds,
|
||||
train,
|
||||
)
|
||||
.unwrap();
|
||||
let probabilities = self.discriminator_head.forward(&hidden_states).sigmoid();
|
||||
(probabilities, all_hidden_states, all_attentions)
|
||||
@ -656,24 +746,37 @@ impl ElectraForTokenClassification {
|
||||
/// # Example
|
||||
///
|
||||
/// ```no_run
|
||||
/// use rust_bert::electra::{ElectraForTokenClassification, ElectraConfig};
|
||||
/// use tch::{nn, Device};
|
||||
/// use rust_bert::electra::{ElectraConfig, ElectraForTokenClassification};
|
||||
/// use rust_bert::Config;
|
||||
/// use std::path::Path;
|
||||
/// use tch::{nn, Device};
|
||||
/// let config_path = Path::new("path/to/config.json");
|
||||
/// let device = Device::Cpu;
|
||||
/// let p = nn::VarStore::new(device);
|
||||
/// let config = ElectraConfig::from_file(config_path);
|
||||
/// let electra_model: ElectraForTokenClassification = ElectraForTokenClassification::new(&p.root(), &config);
|
||||
/// let electra_model: ElectraForTokenClassification =
|
||||
/// ElectraForTokenClassification::new(&p.root(), &config);
|
||||
/// ```
|
||||
///
|
||||
pub fn new(p: &nn::Path, config: &ElectraConfig) -> ElectraForTokenClassification {
|
||||
let electra = ElectraModel::new(&(p / "electra"), config);
|
||||
let dropout = Dropout::new(config.hidden_dropout_prob);
|
||||
let num_labels = config.id2label.as_ref().expect("id2label must be provided for classifiers").len() as i64;
|
||||
let classifier = nn::linear(&(p / "classifier"), config.hidden_size, num_labels, Default::default());
|
||||
let num_labels = config
|
||||
.id2label
|
||||
.as_ref()
|
||||
.expect("id2label must be provided for classifiers")
|
||||
.len() as i64;
|
||||
let classifier = nn::linear(
|
||||
&(p / "classifier"),
|
||||
config.hidden_size,
|
||||
num_labels,
|
||||
Default::default(),
|
||||
);
|
||||
|
||||
ElectraForTokenClassification { electra, dropout, classifier }
|
||||
ElectraForTokenClassification {
|
||||
electra,
|
||||
dropout,
|
||||
classifier,
|
||||
}
|
||||
}
|
||||
|
||||
/// Forward pass through the model
|
||||
@ -721,21 +824,26 @@ impl ElectraForTokenClassification {
|
||||
/// None,
|
||||
/// false)
|
||||
/// });
|
||||
///
|
||||
/// ```
|
||||
///
|
||||
pub fn forward_t(&self,
|
||||
pub fn forward_t(
|
||||
&self,
|
||||
input_ids: Option<Tensor>,
|
||||
mask: Option<Tensor>,
|
||||
token_type_ids: Option<Tensor>,
|
||||
position_ids: Option<Tensor>,
|
||||
input_embeds: Option<Tensor>,
|
||||
train: bool)
|
||||
-> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
|
||||
let (hidden_states,
|
||||
all_hidden_states,
|
||||
all_attentions) = self.electra
|
||||
.forward_t(input_ids, mask, token_type_ids, position_ids, input_embeds, train)
|
||||
train: bool,
|
||||
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
|
||||
let (hidden_states, all_hidden_states, all_attentions) = self
|
||||
.electra
|
||||
.forward_t(
|
||||
input_ids,
|
||||
mask,
|
||||
token_type_ids,
|
||||
position_ids,
|
||||
input_embeds,
|
||||
train,
|
||||
)
|
||||
.unwrap();
|
||||
let output = hidden_states
|
||||
.apply_t(&self.dropout, train)
|
||||
|
@ -12,10 +12,10 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use tch::{nn, Tensor, Kind};
|
||||
use crate::common::dropout::Dropout;
|
||||
use crate::electra::electra::ElectraConfig;
|
||||
use tch::nn::{EmbeddingConfig, embedding};
|
||||
use tch::nn::{embedding, EmbeddingConfig};
|
||||
use tch::{nn, Kind, Tensor};
|
||||
|
||||
#[derive(Debug)]
|
||||
/// # Embeddings implementation for Electra model
|
||||
@ -34,49 +34,77 @@ impl ElectraEmbeddings {
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let word_embeddings: nn::Embedding = embedding(p / "word_embeddings",
|
||||
let word_embeddings: nn::Embedding = embedding(
|
||||
p / "word_embeddings",
|
||||
config.vocab_size,
|
||||
config.embedding_size,
|
||||
embedding_config);
|
||||
embedding_config,
|
||||
);
|
||||
|
||||
let position_embeddings: nn::Embedding = embedding(p / "position_embeddings",
|
||||
let position_embeddings: nn::Embedding = embedding(
|
||||
p / "position_embeddings",
|
||||
config.max_position_embeddings,
|
||||
config.embedding_size,
|
||||
Default::default());
|
||||
Default::default(),
|
||||
);
|
||||
|
||||
let token_type_embeddings: nn::Embedding = embedding(p / "token_type_embeddings",
|
||||
let token_type_embeddings: nn::Embedding = embedding(
|
||||
p / "token_type_embeddings",
|
||||
config.type_vocab_size,
|
||||
config.embedding_size,
|
||||
Default::default());
|
||||
Default::default(),
|
||||
);
|
||||
|
||||
let layer_norm_eps = match config.layer_norm_eps {
|
||||
Some(value) => value,
|
||||
None => 1e-12
|
||||
None => 1e-12,
|
||||
};
|
||||
let layer_norm_config = nn::LayerNormConfig { eps: layer_norm_eps, ..Default::default() };
|
||||
let layer_norm: nn::LayerNorm = nn::layer_norm(p / "LayerNorm", vec![config.embedding_size], layer_norm_config);
|
||||
let layer_norm_config = nn::LayerNormConfig {
|
||||
eps: layer_norm_eps,
|
||||
..Default::default()
|
||||
};
|
||||
let layer_norm: nn::LayerNorm = nn::layer_norm(
|
||||
p / "LayerNorm",
|
||||
vec![config.embedding_size],
|
||||
layer_norm_config,
|
||||
);
|
||||
let dropout: Dropout = Dropout::new(config.hidden_dropout_prob);
|
||||
ElectraEmbeddings { word_embeddings, position_embeddings, token_type_embeddings, layer_norm, dropout}
|
||||
ElectraEmbeddings {
|
||||
word_embeddings,
|
||||
position_embeddings,
|
||||
token_type_embeddings,
|
||||
layer_norm,
|
||||
dropout,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn forward_t(&self,
|
||||
pub fn forward_t(
|
||||
&self,
|
||||
input_ids: Option<Tensor>,
|
||||
token_type_ids: Option<Tensor>,
|
||||
position_ids: Option<Tensor>,
|
||||
input_embeds: Option<Tensor>,
|
||||
train: bool) -> Result<Tensor, &'static str> {
|
||||
train: bool,
|
||||
) -> Result<Tensor, &'static str> {
|
||||
let (input_embeddings, input_shape) = match input_ids {
|
||||
Some(input_value) => match input_embeds {
|
||||
Some(_) => { return Err("Only one of input ids or input embeddings may be set"); }
|
||||
None => (input_value.apply_t(&self.word_embeddings, train), input_value.size())
|
||||
Some(_) => {
|
||||
return Err("Only one of input ids or input embeddings may be set");
|
||||
}
|
||||
None => (
|
||||
input_value.apply_t(&self.word_embeddings, train),
|
||||
input_value.size(),
|
||||
),
|
||||
},
|
||||
None => match input_embeds {
|
||||
Some(embeds) => {
|
||||
let size = vec!(embeds.size()[0], embeds.size()[1]);
|
||||
let size = vec![embeds.size()[0], embeds.size()[1]];
|
||||
(embeds, size)
|
||||
},
|
||||
None => { return Err("Only one of input ids or input embeddings may be set"); }
|
||||
}
|
||||
None => {
|
||||
return Err("Only one of input ids or input embeddings may be set");
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
let seq_length = input_embeddings.as_ref().size()[1].to_owned();
|
||||
@ -84,19 +112,22 @@ impl ElectraEmbeddings {
|
||||
let position_ids = match position_ids {
|
||||
Some(value) => value,
|
||||
None => Tensor::arange(seq_length, (Kind::Int64, input_embeddings.device()))
|
||||
.unsqueeze(0).
|
||||
expand(&input_shape, true)
|
||||
.unsqueeze(0)
|
||||
.expand(&input_shape, true),
|
||||
};
|
||||
|
||||
let token_type_ids = match token_type_ids {
|
||||
Some(value) => value,
|
||||
None => Tensor::zeros(&input_shape, (Kind::Int64, input_embeddings.device()))
|
||||
None => Tensor::zeros(&input_shape, (Kind::Int64, input_embeddings.device())),
|
||||
};
|
||||
|
||||
let position_embeddings = position_ids.apply(&self.position_embeddings);
|
||||
let token_type_embeddings = token_type_ids.apply(&self.token_type_embeddings);
|
||||
|
||||
let input_embeddings: Tensor = input_embeddings + position_embeddings + token_type_embeddings;
|
||||
Ok(input_embeddings.apply(&self.layer_norm).apply_t(&self.dropout, train))
|
||||
let input_embeddings: Tensor =
|
||||
input_embeddings + position_embeddings + token_type_embeddings;
|
||||
Ok(input_embeddings
|
||||
.apply(&self.layer_norm)
|
||||
.apply_t(&self.dropout, train))
|
||||
}
|
||||
}
|
||||
|
@ -28,13 +28,19 @@
|
||||
//! use rust_tokenizers::BertTokenizer;
|
||||
//! use tch::{nn, Device};
|
||||
//! # use std::path::PathBuf;
|
||||
//! use rust_bert::electra::{ElectraForMaskedLM, ElectraConfig};
|
||||
//! use rust_bert::electra::{ElectraConfig, ElectraForMaskedLM};
|
||||
//! use rust_bert::resources::{download_resource, LocalResource, Resource};
|
||||
//! use rust_bert::Config;
|
||||
//! use rust_bert::resources::{Resource, download_resource, LocalResource};
|
||||
//!
|
||||
//! let config_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/config.json")});
|
||||
//! let vocab_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/vocab.txt")});
|
||||
//! let weights_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/model.ot")});
|
||||
//! let config_resource = Resource::Local(LocalResource {
|
||||
//! local_path: PathBuf::from("path/to/config.json"),
|
||||
//! });
|
||||
//! let vocab_resource = Resource::Local(LocalResource {
|
||||
//! local_path: PathBuf::from("path/to/vocab.txt"),
|
||||
//! });
|
||||
//! let weights_resource = Resource::Local(LocalResource {
|
||||
//! local_path: PathBuf::from("path/to/model.ot"),
|
||||
//! });
|
||||
//! let config_path = download_resource(&config_resource)?;
|
||||
//! let vocab_path = download_resource(&vocab_resource)?;
|
||||
//! let weights_path = download_resource(&weights_resource)?;
|
||||
@ -49,9 +55,11 @@
|
||||
//! # }
|
||||
//! ```
|
||||
|
||||
|
||||
mod embeddings;
|
||||
mod electra;
|
||||
mod embeddings;
|
||||
|
||||
pub use electra::{ElectraModelResources, ElectraVocabResources, ElectraConfigResources, ElectraConfig,
|
||||
ElectraModel, ElectraDiscriminator, ElectraForMaskedLM, ElectraDiscriminatorHead, ElectraGeneratorHead, ElectraForTokenClassification};
|
||||
pub use electra::{
|
||||
ElectraConfig, ElectraConfigResources, ElectraDiscriminator, ElectraDiscriminatorHead,
|
||||
ElectraForMaskedLM, ElectraForTokenClassification, ElectraGeneratorHead, ElectraModel,
|
||||
ElectraModelResources, ElectraVocabResources,
|
||||
};
|
||||
|
@ -12,11 +12,11 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use tch::{Tensor, nn};
|
||||
use crate::common::dropout::Dropout;
|
||||
use crate::gpt2::gpt2::Gpt2Config;
|
||||
use tch::kind::Kind::Float;
|
||||
use tch::nn::{Init, Module};
|
||||
use tch::{nn, Tensor};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct GPTConv1D {
|
||||
@ -26,7 +26,14 @@ pub struct GPTConv1D {
|
||||
|
||||
impl GPTConv1D {
|
||||
pub fn new(p: &nn::Path, nf: i64, nx: i64) -> GPTConv1D {
|
||||
let weight = p.var("weight", &[nx, nf], Init::Randn { mean: 0., stdev: 0.02 });
|
||||
let weight = p.var(
|
||||
"weight",
|
||||
&[nx, nf],
|
||||
Init::Randn {
|
||||
mean: 0.,
|
||||
stdev: 0.02,
|
||||
},
|
||||
);
|
||||
let bias = p.var("bias", &[nf], Init::Const(0.));
|
||||
GPTConv1D { weight, bias }
|
||||
}
|
||||
@ -38,7 +45,6 @@ impl Module for GPTConv1D {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
pub struct Attention {
|
||||
bias: Tensor,
|
||||
c_attn: GPTConv1D,
|
||||
@ -62,23 +68,27 @@ impl Attention {
|
||||
|
||||
let attn_pdrop = match config.attn_pdrop {
|
||||
Some(value) => value,
|
||||
None => 0.1
|
||||
None => 0.1,
|
||||
};
|
||||
|
||||
let resid_pdrop = match config.resid_pdrop {
|
||||
Some(value) => value,
|
||||
None => 0.1
|
||||
None => 0.1,
|
||||
};
|
||||
|
||||
let output_attentions = match config.output_attentions {
|
||||
Some(value) => value,
|
||||
None => false
|
||||
None => false,
|
||||
};
|
||||
|
||||
let attn_dropout = Dropout::new(attn_pdrop);
|
||||
let resid_dropout = Dropout::new(resid_pdrop);
|
||||
|
||||
assert_eq!(config.n_embd % config.n_head, 0, "Attention hidden states not a multiple of the number of heads");
|
||||
assert_eq!(
|
||||
config.n_embd % config.n_head,
|
||||
0,
|
||||
"Attention hidden states not a multiple of the number of heads"
|
||||
);
|
||||
let dim_per_head = config.n_embd / config.n_head;
|
||||
|
||||
Attention {
|
||||
@ -105,19 +115,31 @@ impl Attention {
|
||||
}
|
||||
|
||||
fn flatten(&self, x: Tensor) -> Tensor {
|
||||
x.transpose(1, 2).contiguous().view((x.size()[0], -1, &self.n_head * self.dim_per_head))
|
||||
x.transpose(1, 2)
|
||||
.contiguous()
|
||||
.view((x.size()[0], -1, &self.n_head * self.dim_per_head))
|
||||
}
|
||||
|
||||
fn attention(&self, q: &Tensor, k: &Tensor, v: &Tensor, attention_mask: &Option<Tensor>, train: bool)
|
||||
-> (Tensor, Option<Tensor>) {
|
||||
fn attention(
|
||||
&self,
|
||||
q: &Tensor,
|
||||
k: &Tensor,
|
||||
v: &Tensor,
|
||||
attention_mask: &Option<Tensor>,
|
||||
train: bool,
|
||||
) -> (Tensor, Option<Tensor>) {
|
||||
let mut w = q.matmul(&k);
|
||||
if self.scale { w = w / (*v.size().last().unwrap() as f64).sqrt(); }
|
||||
if self.scale {
|
||||
w = w / (*v.size().last().unwrap() as f64).sqrt();
|
||||
}
|
||||
|
||||
let (nd, ns) = (w.size()[2], w.size()[3]);
|
||||
let b = self.bias.narrow(2, ns - nd, nd).narrow(3, 0, ns);
|
||||
|
||||
let mut w: Tensor = w * &b + 1e4 * (&b - 1);
|
||||
if let Some(mask) = attention_mask { w = w + mask; }
|
||||
if let Some(mask) = attention_mask {
|
||||
w = w + mask;
|
||||
}
|
||||
w = w.softmax(-1, Float).apply_t(&self.attn_dropout, train);
|
||||
let output = w.matmul(&v);
|
||||
|
||||
@ -128,15 +150,19 @@ impl Attention {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn forward_t(&self, x: &Tensor, layer_past: &Option<Tensor>, attention_mask: &Option<Tensor>, train: bool)
|
||||
-> (Tensor, Tensor, Option<Tensor>) {
|
||||
pub fn forward_t(
|
||||
&self,
|
||||
x: &Tensor,
|
||||
layer_past: &Option<Tensor>,
|
||||
attention_mask: &Option<Tensor>,
|
||||
train: bool,
|
||||
) -> (Tensor, Tensor, Option<Tensor>) {
|
||||
let x = x.apply(&self.c_attn).split(self.n_state, 2);
|
||||
|
||||
let (query, key, value) =
|
||||
(
|
||||
let (query, key, value) = (
|
||||
self.split_heads(&x[0], false),
|
||||
self.split_heads(&x[1], true),
|
||||
self.split_heads(&x[2], false)
|
||||
self.split_heads(&x[2], false),
|
||||
);
|
||||
let (key, value) = match layer_past {
|
||||
Some(past) => {
|
||||
@ -144,12 +170,15 @@ impl Attention {
|
||||
let value = Tensor::cat(&[past.get(1), value], -2);
|
||||
(key, value)
|
||||
}
|
||||
None => (key, value)
|
||||
None => (key, value),
|
||||
};
|
||||
let present = Tensor::stack(&[key.transpose(-2, -1), value.copy()], 0);
|
||||
let (a, attentions) = self.attention(&query, &key, &value, &attention_mask, train);
|
||||
|
||||
let a = self.flatten(a).apply(&self.c_proj).apply_t(&self.resid_dropout, train);
|
||||
let a = self
|
||||
.flatten(a)
|
||||
.apply(&self.c_proj)
|
||||
.apply_t(&self.resid_dropout, train);
|
||||
|
||||
(a, present, attentions)
|
||||
}
|
||||
|
362
src/gpt2/gpt2.rs
362
src/gpt2/gpt2.rs
@ -12,16 +12,16 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tch::{nn, Tensor};
|
||||
use crate::common::dropout::Dropout;
|
||||
use tch::nn::embedding;
|
||||
use crate::common::linear::{linear_no_bias, LinearNoBias};
|
||||
use crate::gpt2::transformer::Block;
|
||||
use tch::kind::Kind::Int64;
|
||||
use std::borrow::BorrowMut;
|
||||
use crate::common::linear::{LinearNoBias, linear_no_bias};
|
||||
use crate::pipelines::generation::{Cache, LMHeadModel};
|
||||
use crate::Config;
|
||||
use crate::pipelines::generation::{LMHeadModel, Cache};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::borrow::BorrowMut;
|
||||
use tch::kind::Kind::Int64;
|
||||
use tch::nn::embedding;
|
||||
use tch::{nn, Tensor};
|
||||
|
||||
/// # GPT2 Pretrained model weight files
|
||||
pub struct Gpt2ModelResources;
|
||||
@ -37,54 +37,114 @@ pub struct Gpt2MergesResources;
|
||||
|
||||
impl Gpt2ModelResources {
|
||||
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
|
||||
pub const GPT2: (&'static str, &'static str) = ("gpt2/model.ot", "https://cdn.huggingface.co/gpt2-rust_model.ot");
|
||||
pub const GPT2: (&'static str, &'static str) = (
|
||||
"gpt2/model.ot",
|
||||
"https://cdn.huggingface.co/gpt2-rust_model.ot",
|
||||
);
|
||||
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
|
||||
pub const GPT2_MEDIUM: (&'static str, &'static str) = ("gpt2-medium/model.ot", "https://cdn.huggingface.co/gpt2-medium-rust_model.ot");
|
||||
pub const GPT2_MEDIUM: (&'static str, &'static str) = (
|
||||
"gpt2-medium/model.ot",
|
||||
"https://cdn.huggingface.co/gpt2-medium-rust_model.ot",
|
||||
);
|
||||
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
|
||||
pub const GPT2_LARGE: (&'static str, &'static str) = ("gpt2-large/model.ot", "https://cdn.huggingface.co/gpt2-large-rust_model.ot");
|
||||
pub const GPT2_LARGE: (&'static str, &'static str) = (
|
||||
"gpt2-large/model.ot",
|
||||
"https://cdn.huggingface.co/gpt2-large-rust_model.ot",
|
||||
);
|
||||
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
|
||||
pub const GPT2_XL: (&'static str, &'static str) = ("gpt2-xl/model.ot", "https://cdn.huggingface.co/gpt2-xl-rust_model.ot");
|
||||
pub const GPT2_XL: (&'static str, &'static str) = (
|
||||
"gpt2-xl/model.ot",
|
||||
"https://cdn.huggingface.co/gpt2-xl-rust_model.ot",
|
||||
);
|
||||
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
|
||||
pub const DISTIL_GPT2: (&'static str, &'static str) = ("distilgpt2/model.ot", "https://cdn.huggingface.co/distilgpt2-rust_model.ot");
|
||||
pub const DISTIL_GPT2: (&'static str, &'static str) = (
|
||||
"distilgpt2/model.ot",
|
||||
"https://cdn.huggingface.co/distilgpt2-rust_model.ot",
|
||||
);
|
||||
}
|
||||
|
||||
impl Gpt2ConfigResources {
|
||||
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
|
||||
pub const GPT2: (&'static str, &'static str) = ("gpt2/config.json", "https://cdn.huggingface.co/gpt2-config.json");
|
||||
pub const GPT2: (&'static str, &'static str) = (
|
||||
"gpt2/config.json",
|
||||
"https://cdn.huggingface.co/gpt2-config.json",
|
||||
);
|
||||
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
|
||||
pub const GPT2_MEDIUM: (&'static str, &'static str) = ("gpt2-medium/config.json", "https://cdn.huggingface.co/gpt2-medium-config.json");
|
||||
pub const GPT2_MEDIUM: (&'static str, &'static str) = (
|
||||
"gpt2-medium/config.json",
|
||||
"https://cdn.huggingface.co/gpt2-medium-config.json",
|
||||
);
|
||||
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
|
||||
pub const GPT2_LARGE: (&'static str, &'static str) = ("gpt2-large/config.json", "https://cdn.huggingface.co/gpt2-large-config.json");
|
||||
pub const GPT2_LARGE: (&'static str, &'static str) = (
|
||||
"gpt2-large/config.json",
|
||||
"https://cdn.huggingface.co/gpt2-large-config.json",
|
||||
);
|
||||
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
|
||||
pub const GPT2_XL: (&'static str, &'static str) = ("gpt2-xl/config.json", "https://cdn.huggingface.co/gpt2-xl-config.json");
|
||||
pub const GPT2_XL: (&'static str, &'static str) = (
|
||||
"gpt2-xl/config.json",
|
||||
"https://cdn.huggingface.co/gpt2-xl-config.json",
|
||||
);
|
||||
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
|
||||
pub const DISTIL_GPT2: (&'static str, &'static str) = ("distilgpt2/config.json", "https://cdn.huggingface.co/distilgpt2-config.json");
|
||||
pub const DISTIL_GPT2: (&'static str, &'static str) = (
|
||||
"distilgpt2/config.json",
|
||||
"https://cdn.huggingface.co/distilgpt2-config.json",
|
||||
);
|
||||
}
|
||||
|
||||
impl Gpt2VocabResources {
|
||||
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
|
||||
pub const GPT2: (&'static str, &'static str) = ("gpt2/vocab.txt", "https://cdn.huggingface.co/gpt2-vocab.json");
|
||||
pub const GPT2: (&'static str, &'static str) = (
|
||||
"gpt2/vocab.txt",
|
||||
"https://cdn.huggingface.co/gpt2-vocab.json",
|
||||
);
|
||||
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
|
||||
pub const GPT2_MEDIUM: (&'static str, &'static str) = ("gpt2-medium/vocab.txt", "https://cdn.huggingface.co/gpt2-medium-vocab.json");
|
||||
pub const GPT2_MEDIUM: (&'static str, &'static str) = (
|
||||
"gpt2-medium/vocab.txt",
|
||||
"https://cdn.huggingface.co/gpt2-medium-vocab.json",
|
||||
);
|
||||
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
|
||||
pub const GPT2_LARGE: (&'static str, &'static str) = ("gpt2-large/vocab.txt", "https://cdn.huggingface.co/gpt2-large-vocab.json");
|
||||
pub const GPT2_LARGE: (&'static str, &'static str) = (
|
||||
"gpt2-large/vocab.txt",
|
||||
"https://cdn.huggingface.co/gpt2-large-vocab.json",
|
||||
);
|
||||
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
|
||||
pub const GPT2_XL: (&'static str, &'static str) = ("gpt2-xl/vocab.txt", "https://cdn.huggingface.co/gpt2-xl-vocab.json");
|
||||
pub const GPT2_XL: (&'static str, &'static str) = (
|
||||
"gpt2-xl/vocab.txt",
|
||||
"https://cdn.huggingface.co/gpt2-xl-vocab.json",
|
||||
);
|
||||
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
|
||||
pub const DISTIL_GPT2: (&'static str, &'static str) = ("distilgpt2/vocab.txt", "https://cdn.huggingface.co/distilgpt2-vocab.json");
|
||||
pub const DISTIL_GPT2: (&'static str, &'static str) = (
|
||||
"distilgpt2/vocab.txt",
|
||||
"https://cdn.huggingface.co/distilgpt2-vocab.json",
|
||||
);
|
||||
}
|
||||
|
||||
impl Gpt2MergesResources {
|
||||
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
|
||||
pub const GPT2: (&'static str, &'static str) = ("gpt2/merges.txt", "https://cdn.huggingface.co/gpt2-merges.txt");
|
||||
pub const GPT2: (&'static str, &'static str) = (
|
||||
"gpt2/merges.txt",
|
||||
"https://cdn.huggingface.co/gpt2-merges.txt",
|
||||
);
|
||||
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
|
||||
pub const GPT2_MEDIUM: (&'static str, &'static str) = ("gpt2-medium/merges.txt", "https://cdn.huggingface.co/gpt2-medium-merges.txt");
|
||||
pub const GPT2_MEDIUM: (&'static str, &'static str) = (
|
||||
"gpt2-medium/merges.txt",
|
||||
"https://cdn.huggingface.co/gpt2-medium-merges.txt",
|
||||
);
|
||||
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
|
||||
pub const GPT2_LARGE: (&'static str, &'static str) = ("gpt2-large/merges.txt", "https://cdn.huggingface.co/gpt2-large-merges.txt");
|
||||
pub const GPT2_LARGE: (&'static str, &'static str) = (
|
||||
"gpt2-large/merges.txt",
|
||||
"https://cdn.huggingface.co/gpt2-large-merges.txt",
|
||||
);
|
||||
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
|
||||
pub const GPT2_XL: (&'static str, &'static str) = ("gpt2-xl/merges.txt", "https://cdn.huggingface.co/gpt2-xl-merges.txt");
|
||||
pub const GPT2_XL: (&'static str, &'static str) = (
|
||||
"gpt2-xl/merges.txt",
|
||||
"https://cdn.huggingface.co/gpt2-xl-merges.txt",
|
||||
);
|
||||
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
|
||||
pub const DISTIL_GPT2: (&'static str, &'static str) = ("distilgpt2/merges.txt", "https://cdn.huggingface.co/distilgpt2-merges.txt");
|
||||
pub const DISTIL_GPT2: (&'static str, &'static str) = (
|
||||
"distilgpt2/merges.txt",
|
||||
"https://cdn.huggingface.co/distilgpt2-merges.txt",
|
||||
);
|
||||
}
|
||||
|
||||
#[allow(non_camel_case_types)]
|
||||
@ -156,10 +216,10 @@ impl Gpt2Model {
|
||||
/// # Example
|
||||
///
|
||||
/// ```no_run
|
||||
/// use tch::{nn, Device};
|
||||
/// use rust_bert::gpt2::{Gpt2Config, Gpt2Model};
|
||||
/// use rust_bert::Config;
|
||||
/// use std::path::Path;
|
||||
/// use rust_bert::gpt2::{Gpt2Config, Gpt2Model};
|
||||
/// use tch::{nn, Device};
|
||||
///
|
||||
/// let config_path = Path::new("path/to/config.json");
|
||||
/// let device = Device::Cpu;
|
||||
@ -167,37 +227,58 @@ impl Gpt2Model {
|
||||
/// let config = Gpt2Config::from_file(config_path);
|
||||
/// let gpt2: Gpt2Model = Gpt2Model::new(&(&p.root() / "gpt2"), &config);
|
||||
/// ```
|
||||
///
|
||||
pub fn new(p: &nn::Path, config: &Gpt2Config) -> Gpt2Model {
|
||||
let p = &(p / "transformer");
|
||||
let wte = embedding(&(p / "wte"), config.vocab_size, config.n_embd, Default::default());
|
||||
let wpe = embedding(&(p / "wpe"), config.n_positions, config.n_embd, Default::default());
|
||||
let wte = embedding(
|
||||
&(p / "wte"),
|
||||
config.vocab_size,
|
||||
config.n_embd,
|
||||
Default::default(),
|
||||
);
|
||||
let wpe = embedding(
|
||||
&(p / "wpe"),
|
||||
config.n_positions,
|
||||
config.n_embd,
|
||||
Default::default(),
|
||||
);
|
||||
|
||||
let embd_pdrop = match config.embd_pdrop {
|
||||
Some(value) => value,
|
||||
None => 0.1
|
||||
None => 0.1,
|
||||
};
|
||||
let drop = Dropout::new(embd_pdrop);
|
||||
let layer_norm_config = nn::LayerNormConfig { eps: config.layer_norm_epsilon, ..Default::default() };
|
||||
let layer_norm_config = nn::LayerNormConfig {
|
||||
eps: config.layer_norm_epsilon,
|
||||
..Default::default()
|
||||
};
|
||||
let ln_f = nn::layer_norm(p / "ln_f", vec![config.n_embd], layer_norm_config);
|
||||
let mut h: Vec<Block> = vec!();
|
||||
let mut h: Vec<Block> = vec![];
|
||||
let h_path = &(p / "h");
|
||||
for layer_index in 0..config.n_layer {
|
||||
h.push(Block::new(&(h_path / layer_index), config, true));
|
||||
};
|
||||
}
|
||||
let output_attentions = match config.output_attentions {
|
||||
Some(value) => value,
|
||||
None => false
|
||||
None => false,
|
||||
};
|
||||
let output_past = match config.output_past {
|
||||
Some(value) => value,
|
||||
None => true
|
||||
None => true,
|
||||
};
|
||||
let output_hidden_states = match config.output_hidden_states {
|
||||
Some(value) => value,
|
||||
None => false
|
||||
None => false,
|
||||
};
|
||||
Gpt2Model { wte, wpe, drop, ln_f, h, output_past, output_hidden_states, output_attentions }
|
||||
Gpt2Model {
|
||||
wte,
|
||||
wpe,
|
||||
drop,
|
||||
ln_f,
|
||||
h,
|
||||
output_past,
|
||||
output_hidden_states,
|
||||
output_attentions,
|
||||
}
|
||||
}
|
||||
|
||||
/// Forward pass through the model
|
||||
@ -226,7 +307,7 @@ impl Gpt2Model {
|
||||
/// # use rust_bert::Config;
|
||||
/// # use std::path::Path;
|
||||
/// # use tch::kind::Kind::{Int64, Double};
|
||||
/// use rust_bert::gpt2::{Gpt2Model, Gpt2Config};
|
||||
/// use rust_bert::gpt2::{Gpt2Config, Gpt2Model};
|
||||
/// # let config_path = Path::new("path/to/config.json");
|
||||
/// # let vocab_path = Path::new("path/to/vocab.txt");
|
||||
/// # let device = Device::Cpu;
|
||||
@ -237,48 +318,86 @@ impl Gpt2Model {
|
||||
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
|
||||
/// let mut past: Vec<Tensor> = Vec::with_capacity(config.n_layer as usize);
|
||||
/// for _ in 0..config.n_layer as usize {
|
||||
/// past.push(Tensor::rand(&[2, batch_size, config.n_head, past_sequence_length, config.n_embd / config.n_head], (Double, device)))
|
||||
/// past.push(Tensor::rand(
|
||||
/// &[
|
||||
/// 2,
|
||||
/// batch_size,
|
||||
/// config.n_head,
|
||||
/// past_sequence_length,
|
||||
/// config.n_embd / config.n_head,
|
||||
/// ],
|
||||
/// (Double, device),
|
||||
/// ))
|
||||
/// }
|
||||
/// let attention_mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
||||
/// let token_type_ids = Tensor::ones(&[batch_size, sequence_length], (Int64, device));
|
||||
/// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true);
|
||||
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
|
||||
/// .expand(&[batch_size, sequence_length], true);
|
||||
///
|
||||
/// let (output, past, hidden_states, attentions) = no_grad(|| {
|
||||
/// gpt2_model
|
||||
/// .forward_t(&Some(input_tensor),
|
||||
/// .forward_t(
|
||||
/// &Some(input_tensor),
|
||||
/// &Some(past),
|
||||
/// &Some(attention_mask),
|
||||
/// &Some(token_type_ids),
|
||||
/// &Some(position_ids),
|
||||
/// &None,
|
||||
/// false).unwrap()
|
||||
/// false,
|
||||
/// )
|
||||
/// .unwrap()
|
||||
/// });
|
||||
///
|
||||
/// ```
|
||||
///
|
||||
pub fn forward_t(&self,
|
||||
pub fn forward_t(
|
||||
&self,
|
||||
input_ids: &Option<Tensor>,
|
||||
layer_past: &Option<Vec<Tensor>>,
|
||||
attention_mask: &Option<Tensor>,
|
||||
token_type_ids: &Option<Tensor>,
|
||||
position_ids: &Option<Tensor>,
|
||||
input_embeds: &Option<Tensor>,
|
||||
train: bool) -> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
|
||||
train: bool,
|
||||
) -> Result<
|
||||
(
|
||||
Tensor,
|
||||
Option<Vec<Tensor>>,
|
||||
Option<Vec<Tensor>>,
|
||||
Option<Vec<Tensor>>,
|
||||
),
|
||||
&'static str,
|
||||
> {
|
||||
let (input_embeddings, seq_length) = match input_ids {
|
||||
Some(input_value) => match input_embeds {
|
||||
Some(_) => { return Err("Only one of input ids or input embeddings may be set"); }
|
||||
None => (input_value.apply(&self.wte), *input_value.size().last().unwrap())
|
||||
Some(_) => {
|
||||
return Err("Only one of input ids or input embeddings may be set");
|
||||
}
|
||||
None => (
|
||||
input_value.apply(&self.wte),
|
||||
*input_value.size().last().unwrap(),
|
||||
),
|
||||
},
|
||||
None => match input_embeds {
|
||||
Some(embeds) => (embeds.copy(), embeds.size()[1]),
|
||||
None => { return Err("At least one of input ids or input embeddings must be set"); }
|
||||
None => {
|
||||
return Err("At least one of input ids or input embeddings must be set");
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
let (layer_past, layer_past_length) = match layer_past {
|
||||
Some(value) => {
|
||||
assert_eq!(value.len(), self.h.len(), "Past activations vector must be of length equal to the number of layers");
|
||||
(value.iter().map(|v| Some(v.copy())).collect::<Vec<Option<Tensor>>>(), value[0].size()[3])
|
||||
assert_eq!(
|
||||
value.len(),
|
||||
self.h.len(),
|
||||
"Past activations vector must be of length equal to the number of layers"
|
||||
);
|
||||
(
|
||||
value
|
||||
.iter()
|
||||
.map(|v| Some(v.copy()))
|
||||
.collect::<Vec<Option<Tensor>>>(),
|
||||
value[0].size()[3],
|
||||
)
|
||||
}
|
||||
None => {
|
||||
let mut out = Vec::with_capacity(self.h.len());
|
||||
@ -289,31 +408,45 @@ impl Gpt2Model {
|
||||
|
||||
let position_ids = match position_ids {
|
||||
Some(value) => value.copy(),
|
||||
None => Tensor::arange1(layer_past_length, seq_length + layer_past_length, (Int64, input_embeddings.device())).unsqueeze(0)
|
||||
None => Tensor::arange1(
|
||||
layer_past_length,
|
||||
seq_length + layer_past_length,
|
||||
(Int64, input_embeddings.device()),
|
||||
)
|
||||
.unsqueeze(0),
|
||||
};
|
||||
|
||||
let attention_mask: Option<Tensor> = match attention_mask {
|
||||
Some(value) => {
|
||||
Some(
|
||||
Some(value) => Some(
|
||||
(value
|
||||
.view((input_embeddings.size()[0], -1))
|
||||
.unsqueeze(1)
|
||||
.unsqueeze(2)
|
||||
- 1.0
|
||||
) * 10000.0)
|
||||
}
|
||||
None => None
|
||||
- 1.0)
|
||||
* 10000.0,
|
||||
),
|
||||
None => None,
|
||||
};
|
||||
|
||||
let position_embeds = position_ids.apply(&self.wpe);
|
||||
let token_type_embeds = match token_type_ids {
|
||||
Some(value) => value.apply(&self.wte),
|
||||
None => Tensor::zeros_like(&position_embeds)
|
||||
None => Tensor::zeros_like(&position_embeds),
|
||||
};
|
||||
let mut hidden_state: Tensor =
|
||||
(input_embeddings + position_embeds + token_type_embeds).apply_t(&self.drop, train);
|
||||
let mut all_presents: Option<Vec<Tensor>> =
|
||||
if self.output_past { Some(vec![]) } else { None };
|
||||
let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states {
|
||||
Some(vec![])
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let mut all_attentions: Option<Vec<Tensor>> = if self.output_attentions {
|
||||
Some(vec![])
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let mut hidden_state: Tensor = (input_embeddings + position_embeds + token_type_embeds).apply_t(&self.drop, train);
|
||||
let mut all_presents: Option<Vec<Tensor>> = if self.output_past { Some(vec!()) } else { None };
|
||||
let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states { Some(vec!()) } else { None };
|
||||
let mut all_attentions: Option<Vec<Tensor>> = if self.output_attentions { Some(vec!()) } else { None };
|
||||
|
||||
let mut layer_iter = self.h.iter().zip(layer_past);
|
||||
loop {
|
||||
@ -333,11 +466,16 @@ impl Gpt2Model {
|
||||
attentions.push(temp.2.as_ref().unwrap().copy());
|
||||
};
|
||||
}
|
||||
None => break
|
||||
};
|
||||
None => break,
|
||||
};
|
||||
}
|
||||
|
||||
Ok((hidden_state.apply(&self.ln_f), all_presents, all_hidden_states, all_attentions))
|
||||
Ok((
|
||||
hidden_state.apply(&self.ln_f),
|
||||
all_presents,
|
||||
all_hidden_states,
|
||||
all_attentions,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
@ -362,10 +500,10 @@ impl GPT2LMHeadModel {
|
||||
/// # Example
|
||||
///
|
||||
/// ```no_run
|
||||
/// use tch::{nn, Device};
|
||||
/// use rust_bert::gpt2::{GPT2LMHeadModel, Gpt2Config};
|
||||
/// use rust_bert::Config;
|
||||
/// use std::path::Path;
|
||||
/// use rust_bert::gpt2::{Gpt2Config, GPT2LMHeadModel};
|
||||
/// use tch::{nn, Device};
|
||||
///
|
||||
/// let config_path = Path::new("path/to/config.json");
|
||||
/// let device = Device::Cpu;
|
||||
@ -373,11 +511,18 @@ impl GPT2LMHeadModel {
|
||||
/// let config = Gpt2Config::from_file(config_path);
|
||||
/// let gpt2: GPT2LMHeadModel = GPT2LMHeadModel::new(&(&p.root() / "gpt2"), &config);
|
||||
/// ```
|
||||
///
|
||||
pub fn new(p: &nn::Path, config: &Gpt2Config) -> GPT2LMHeadModel {
|
||||
let transformer = Gpt2Model::new(&p, config);
|
||||
let lm_head = linear_no_bias(&(p / "lm_head"), config.n_embd, config.vocab_size, Default::default());
|
||||
GPT2LMHeadModel { transformer, lm_head }
|
||||
let lm_head = linear_no_bias(
|
||||
&(p / "lm_head"),
|
||||
config.n_embd,
|
||||
config.vocab_size,
|
||||
Default::default(),
|
||||
);
|
||||
GPT2LMHeadModel {
|
||||
transformer,
|
||||
lm_head,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -412,8 +557,8 @@ impl LMHeadModel for GPT2LMHeadModel {
|
||||
/// # use rust_bert::Config;
|
||||
/// # use std::path::Path;
|
||||
/// # use tch::kind::Kind::{Int64, Double};
|
||||
/// use rust_bert::gpt2::{Gpt2Config, GPT2LMHeadModel};
|
||||
/// use rust_bert::pipelines::generation::{LMHeadModel, Cache};
|
||||
/// use rust_bert::gpt2::{GPT2LMHeadModel, Gpt2Config};
|
||||
/// use rust_bert::pipelines::generation::{Cache, LMHeadModel};
|
||||
/// # let config_path = Path::new("path/to/config.json");
|
||||
/// # let vocab_path = Path::new("path/to/vocab.txt");
|
||||
/// # let device = Device::Cpu;
|
||||
@ -424,15 +569,26 @@ impl LMHeadModel for GPT2LMHeadModel {
|
||||
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
|
||||
/// let mut past: Vec<Tensor> = Vec::with_capacity(config.n_layer as usize);
|
||||
/// for _ in 0..config.n_layer as usize {
|
||||
/// past.push(Tensor::rand(&[2, batch_size, config.n_head, past_sequence_length, config.n_embd / config.n_head], (Double, device)))
|
||||
/// past.push(Tensor::rand(
|
||||
/// &[
|
||||
/// 2,
|
||||
/// batch_size,
|
||||
/// config.n_head,
|
||||
/// past_sequence_length,
|
||||
/// config.n_embd / config.n_head,
|
||||
/// ],
|
||||
/// (Double, device),
|
||||
/// ))
|
||||
/// }
|
||||
/// let attention_mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
||||
/// let token_type_ids = Tensor::ones(&[batch_size, sequence_length], (Int64, device));
|
||||
/// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true);
|
||||
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
|
||||
/// .expand(&[batch_size, sequence_length], true);
|
||||
///
|
||||
/// let (output, _, past, hidden_states, attentions) = no_grad(|| {
|
||||
/// gpt2_model
|
||||
/// .forward_t(&Some(input_tensor),
|
||||
/// .forward_t(
|
||||
/// &Some(input_tensor),
|
||||
/// Cache::GPT2Cache(Some(past)),
|
||||
/// &Some(attention_mask),
|
||||
/// &Some(token_type_ids),
|
||||
@ -440,12 +596,13 @@ impl LMHeadModel for GPT2LMHeadModel {
|
||||
/// &None,
|
||||
/// None,
|
||||
/// &None,
|
||||
/// false).unwrap()
|
||||
/// false,
|
||||
/// )
|
||||
/// .unwrap()
|
||||
/// });
|
||||
///
|
||||
/// ```
|
||||
///
|
||||
fn forward_t(&self,
|
||||
fn forward_t(
|
||||
&self,
|
||||
input_ids: &Option<Tensor>,
|
||||
layer_past: Cache,
|
||||
attention_mask: &Option<Tensor>,
|
||||
@ -454,29 +611,46 @@ impl LMHeadModel for GPT2LMHeadModel {
|
||||
input_embeds: &Option<Tensor>,
|
||||
_encoder_outputs: Option<&Tensor>,
|
||||
_decoder_input_ids: &Option<Tensor>,
|
||||
train: bool) -> Result<(Tensor, Option<Tensor>, Cache, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
|
||||
let (output,
|
||||
past,
|
||||
all_hidden_states,
|
||||
all_attentions) = match layer_past {
|
||||
Cache::GPT2Cache(layer_past) => Ok(self.transformer.forward_t(input_ids,
|
||||
train: bool,
|
||||
) -> Result<
|
||||
(
|
||||
Tensor,
|
||||
Option<Tensor>,
|
||||
Cache,
|
||||
Option<Vec<Tensor>>,
|
||||
Option<Vec<Tensor>>,
|
||||
),
|
||||
&'static str,
|
||||
> {
|
||||
let (output, past, all_hidden_states, all_attentions) = match layer_past {
|
||||
Cache::GPT2Cache(layer_past) => Ok(self.transformer.forward_t(
|
||||
input_ids,
|
||||
&layer_past,
|
||||
attention_mask,
|
||||
token_type_ids,
|
||||
position_ids,
|
||||
input_embeds,
|
||||
train)?),
|
||||
Cache::None => Ok(self.transformer.forward_t(input_ids,
|
||||
train,
|
||||
)?),
|
||||
Cache::None => Ok(self.transformer.forward_t(
|
||||
input_ids,
|
||||
&None,
|
||||
attention_mask,
|
||||
token_type_ids,
|
||||
position_ids,
|
||||
input_embeds,
|
||||
train)?),
|
||||
_ => Err("Cache not compatible with GPT2 model")
|
||||
train,
|
||||
)?),
|
||||
_ => Err("Cache not compatible with GPT2 model"),
|
||||
}?;
|
||||
|
||||
let lm_logits = output.apply(&self.lm_head);
|
||||
Ok((lm_logits, None, Cache::GPT2Cache(past), all_hidden_states, all_attentions))
|
||||
Ok((
|
||||
lm_logits,
|
||||
None,
|
||||
Cache::GPT2Cache(past),
|
||||
all_hidden_states,
|
||||
all_attentions,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
@ -19,14 +19,22 @@
|
||||
//! use rust_tokenizers::Gpt2Tokenizer;
|
||||
//! use tch::{nn, Device};
|
||||
//! # use std::path::PathBuf;
|
||||
//! use rust_bert::gpt2::{GPT2LMHeadModel, Gpt2Config};
|
||||
//! use rust_bert::resources::{download_resource, LocalResource, Resource};
|
||||
//! use rust_bert::Config;
|
||||
//! use rust_bert::gpt2::{Gpt2Config, GPT2LMHeadModel};
|
||||
//! use rust_bert::resources::{Resource, download_resource, LocalResource};
|
||||
//!
|
||||
//! let config_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/config.json")});
|
||||
//! let vocab_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/vocab.txt")});
|
||||
//! let merges_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/vocab.txt")});
|
||||
//! let weights_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/model.ot")});
|
||||
//! let config_resource = Resource::Local(LocalResource {
|
||||
//! local_path: PathBuf::from("path/to/config.json"),
|
||||
//! });
|
||||
//! let vocab_resource = Resource::Local(LocalResource {
|
||||
//! local_path: PathBuf::from("path/to/vocab.txt"),
|
||||
//! });
|
||||
//! let merges_resource = Resource::Local(LocalResource {
|
||||
//! local_path: PathBuf::from("path/to/vocab.txt"),
|
||||
//! });
|
||||
//! let weights_resource = Resource::Local(LocalResource {
|
||||
//! local_path: PathBuf::from("path/to/model.ot"),
|
||||
//! });
|
||||
//! let config_path = download_resource(&config_resource)?;
|
||||
//! let vocab_path = download_resource(&vocab_resource)?;
|
||||
//! let merges_path = download_resource(&merges_resource)?;
|
||||
@ -34,7 +42,11 @@
|
||||
//!
|
||||
//! let device = Device::cuda_if_available();
|
||||
//! let mut vs = nn::VarStore::new(device);
|
||||
//! let tokenizer: Gpt2Tokenizer = Gpt2Tokenizer::from_file(vocab_path.to_str().unwrap(), merges_path.to_str().unwrap(), true);
|
||||
//! let tokenizer: Gpt2Tokenizer = Gpt2Tokenizer::from_file(
|
||||
//! vocab_path.to_str().unwrap(),
|
||||
//! merges_path.to_str().unwrap(),
|
||||
//! true,
|
||||
//! );
|
||||
//! let config = Gpt2Config::from_file(config_path);
|
||||
//! let gpt2_model = GPT2LMHeadModel::new(&vs.root(), &config);
|
||||
//! vs.load(weights_path)?;
|
||||
@ -43,9 +55,11 @@
|
||||
//! # }
|
||||
//! ```
|
||||
|
||||
mod gpt2;
|
||||
pub(crate) mod attention;
|
||||
mod gpt2;
|
||||
pub(crate) mod transformer;
|
||||
|
||||
pub use gpt2::{Gpt2ModelResources, Gpt2ConfigResources, Gpt2VocabResources, Gpt2MergesResources,
|
||||
Gpt2Config, Gpt2Model, GptActivation, GPT2LMHeadModel};
|
||||
pub use gpt2::{
|
||||
GPT2LMHeadModel, Gpt2Config, Gpt2ConfigResources, Gpt2MergesResources, Gpt2Model,
|
||||
Gpt2ModelResources, Gpt2VocabResources, GptActivation,
|
||||
};
|
||||
|
@ -12,11 +12,11 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use crate::gpt2::attention::{GPTConv1D, Attention};
|
||||
use tch::{Tensor, nn};
|
||||
use crate::common::dropout::Dropout;
|
||||
use crate::gpt2::gpt2::{Gpt2Config, GptActivation};
|
||||
use crate::common::activations::{_gelu_new, _relu, _swish};
|
||||
use crate::common::dropout::Dropout;
|
||||
use crate::gpt2::attention::{Attention, GPTConv1D};
|
||||
use crate::gpt2::gpt2::{Gpt2Config, GptActivation};
|
||||
use tch::{nn, Tensor};
|
||||
|
||||
pub struct MLP {
|
||||
c_fc: GPTConv1D,
|
||||
@ -35,14 +35,19 @@ impl MLP {
|
||||
GptActivation::relu => _relu,
|
||||
GptActivation::swish => _swish,
|
||||
},
|
||||
None => _gelu_new
|
||||
None => _gelu_new,
|
||||
});
|
||||
let resid_pdrop = match config.resid_pdrop {
|
||||
Some(value) => value,
|
||||
None => 0.1
|
||||
None => 0.1,
|
||||
};
|
||||
let dropout = Dropout::new(resid_pdrop);
|
||||
MLP { c_fc, c_proj, activation, dropout }
|
||||
MLP {
|
||||
c_fc,
|
||||
c_proj,
|
||||
activation,
|
||||
dropout,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn forward_t(&self, x: &Tensor, train: bool) -> Tensor {
|
||||
@ -60,18 +65,33 @@ pub struct Block {
|
||||
|
||||
impl Block {
|
||||
pub fn new(p: &nn::Path, config: &Gpt2Config, scale: bool) -> Block {
|
||||
let layer_norm_config = nn::LayerNormConfig { eps: config.layer_norm_epsilon, ..Default::default() };
|
||||
let layer_norm_config = nn::LayerNormConfig {
|
||||
eps: config.layer_norm_epsilon,
|
||||
..Default::default()
|
||||
};
|
||||
let ln_1 = nn::layer_norm(p / "ln_1", vec![config.n_embd], layer_norm_config);
|
||||
let ln_2 = nn::layer_norm(p / "ln_2", vec![config.n_embd], layer_norm_config);
|
||||
let attn = Attention::new(&(p / "attn"), config, scale);
|
||||
let mlp = MLP::new(&(p / "mlp"), config);
|
||||
|
||||
Block { ln_1, attn, ln_2, mlp }
|
||||
Block {
|
||||
ln_1,
|
||||
attn,
|
||||
ln_2,
|
||||
mlp,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn forward_t(&self, x: &Tensor, layer_past: &Option<Tensor>, attention_mask: &Option<Tensor>, train: bool)
|
||||
-> (Tensor, Tensor, Option<Tensor>) {
|
||||
let (output, present, attentions) = self.attn.forward_t(&x.apply(&self.ln_1), layer_past, attention_mask, train);
|
||||
pub fn forward_t(
|
||||
&self,
|
||||
x: &Tensor,
|
||||
layer_past: &Option<Tensor>,
|
||||
attention_mask: &Option<Tensor>,
|
||||
train: bool,
|
||||
) -> (Tensor, Tensor, Option<Tensor>) {
|
||||
let (output, present, attentions) =
|
||||
self.attn
|
||||
.forward_t(&x.apply(&self.ln_1), layer_past, attention_mask, train);
|
||||
let x = x + output;
|
||||
let m = self.mlp.forward_t(&x.apply(&self.ln_2), train);
|
||||
let x = x + m;
|
||||
|
23
src/lib.rs
23
src/lib.rs
@ -16,14 +16,14 @@
|
||||
//!
|
||||
//! More information on these can be found in the [`pipelines` module](./pipelines/index.html)
|
||||
//! ```no_run
|
||||
//! use rust_bert::pipelines::question_answering::{QuestionAnsweringModel, QaInput};
|
||||
//! use rust_bert::pipelines::question_answering::{QaInput, QuestionAnsweringModel};
|
||||
//!
|
||||
//! # fn main() -> failure::Fallible<()> {
|
||||
//! let qa_model = QuestionAnsweringModel::new(Default::default())?;
|
||||
//!
|
||||
//! let question = String::from("Where does Amy live ?");
|
||||
//! let context = String::from("Amy lives in Amsterdam");
|
||||
//! let answers = qa_model.predict(&vec!(QaInput { question, context }), 1, 32);
|
||||
//! let answers = qa_model.predict(&vec![QaInput { question, context }], 1, 32);
|
||||
//! # Ok(())
|
||||
//! # }
|
||||
//! ```
|
||||
@ -55,19 +55,18 @@
|
||||
//! - Set-up a virtual environment and install dependencies
|
||||
//! - run the conversion script python /utils/download-dependencies_{MODEL_TO_DOWNLOAD}.py. The dependencies will be downloaded to the user's home directory, under ~/rustbert/{}
|
||||
//! 3. Run the example cargo run --release
|
||||
//!
|
||||
|
||||
pub mod distilbert;
|
||||
pub mod bert;
|
||||
pub mod roberta;
|
||||
pub mod openai_gpt;
|
||||
pub mod gpt2;
|
||||
pub mod bart;
|
||||
pub mod electra;
|
||||
pub mod marian;
|
||||
pub mod albert;
|
||||
pub mod bart;
|
||||
pub mod bert;
|
||||
mod common;
|
||||
pub mod distilbert;
|
||||
pub mod electra;
|
||||
pub mod gpt2;
|
||||
pub mod marian;
|
||||
pub mod openai_gpt;
|
||||
pub mod pipelines;
|
||||
pub mod roberta;
|
||||
|
||||
pub use common::Config;
|
||||
pub use common::resources;
|
||||
pub use common::Config;
|
||||
|
@ -11,10 +11,10 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use crate::bart::{BartModel, BartConfig, LayerState};
|
||||
use tch::{Tensor, nn};
|
||||
use crate::pipelines::generation::{LMHeadModel, Cache};
|
||||
use crate::bart::{BartConfig, BartModel, LayerState};
|
||||
use crate::pipelines::generation::{Cache, LMHeadModel};
|
||||
use tch::nn::Init;
|
||||
use tch::{nn, Tensor};
|
||||
|
||||
/// # Marian Pretrained model weight files
|
||||
pub struct MarianModelResources;
|
||||
@ -33,78 +33,174 @@ pub struct MarianPrefix;
|
||||
|
||||
impl MarianModelResources {
|
||||
/// Shared under the Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format.
|
||||
pub const ENGLISH2ROMANCE: (&'static str, &'static str) = ("marian-mt-en-ROMANCE/model.ot", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-ROMANCE/rust_model.ot");
|
||||
pub const ENGLISH2ROMANCE: (&'static str, &'static str) = (
|
||||
"marian-mt-en-ROMANCE/model.ot",
|
||||
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-ROMANCE/rust_model.ot",
|
||||
);
|
||||
/// Shared under the Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format.
|
||||
pub const ROMANCE2ENGLISH: (&'static str, &'static str) = ("marian-mt-ROMANCE-en/model.ot", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-ROMANCE-en/rust_model.ot");
|
||||
pub const ROMANCE2ENGLISH: (&'static str, &'static str) = (
|
||||
"marian-mt-ROMANCE-en/model.ot",
|
||||
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-ROMANCE-en/rust_model.ot",
|
||||
);
|
||||
/// Shared under the Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format.
|
||||
pub const ENGLISH2GERMAN: (&'static str, &'static str) = ("marian-mt-en-de/model.ot", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-de/rust_model.ot");
|
||||
pub const ENGLISH2GERMAN: (&'static str, &'static str) = (
|
||||
"marian-mt-en-de/model.ot",
|
||||
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-de/rust_model.ot",
|
||||
);
|
||||
/// Shared under the Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format.
|
||||
pub const GERMAN2ENGLISH: (&'static str, &'static str) = ("marian-mt-de-en/model.ot", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-de-en/rust_model.ot");
|
||||
pub const GERMAN2ENGLISH: (&'static str, &'static str) = (
|
||||
"marian-mt-de-en/model.ot",
|
||||
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-de-en/rust_model.ot",
|
||||
);
|
||||
/// Shared under the Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format.
|
||||
pub const ENGLISH2RUSSIAN: (&'static str, &'static str) = ("marian-mt-en-ru/model.ot", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-ru/rust_model.ot");
|
||||
pub const ENGLISH2RUSSIAN: (&'static str, &'static str) = (
|
||||
"marian-mt-en-ru/model.ot",
|
||||
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-ru/rust_model.ot",
|
||||
);
|
||||
/// Shared under the Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format.
|
||||
pub const RUSSIAN2ENGLISH: (&'static str, &'static str) = ("marian-mt-ru-en/model.ot", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-ru-en/rust_model.ot");
|
||||
pub const RUSSIAN2ENGLISH: (&'static str, &'static str) = (
|
||||
"marian-mt-ru-en/model.ot",
|
||||
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-ru-en/rust_model.ot",
|
||||
);
|
||||
/// Shared under the Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format.
|
||||
pub const FRENCH2GERMAN: (&'static str, &'static str) = ("marian-mt-fr-de/model.ot", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-fr-de/rust_model.ot");
|
||||
pub const FRENCH2GERMAN: (&'static str, &'static str) = (
|
||||
"marian-mt-fr-de/model.ot",
|
||||
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-fr-de/rust_model.ot",
|
||||
);
|
||||
/// Shared under the Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format.
|
||||
pub const GERMAN2FRENCH: (&'static str, &'static str) = ("marian-mt-de-fr/model.ot", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-de-fr/rust_model.ot");
|
||||
pub const GERMAN2FRENCH: (&'static str, &'static str) = (
|
||||
"marian-mt-de-fr/model.ot",
|
||||
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-de-fr/rust_model.ot",
|
||||
);
|
||||
}
|
||||
|
||||
impl MarianConfigResources {
|
||||
/// Shared under the Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
|
||||
pub const ENGLISH2ROMANCE: (&'static str, &'static str) = ("marian-mt-en-ROMANCE/config.json", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-ROMANCE/config.json");
|
||||
pub const ENGLISH2ROMANCE: (&'static str, &'static str) = (
|
||||
"marian-mt-en-ROMANCE/config.json",
|
||||
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-ROMANCE/config.json",
|
||||
);
|
||||
/// Shared under the Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
|
||||
pub const ROMANCE2ENGLISH: (&'static str, &'static str) = ("marian-mt-ROMANCE-en/config.json", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-ROMANCE-en/config.json");
|
||||
pub const ROMANCE2ENGLISH: (&'static str, &'static str) = (
|
||||
"marian-mt-ROMANCE-en/config.json",
|
||||
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-ROMANCE-en/config.json",
|
||||
);
|
||||
/// Shared under the Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
|
||||
pub const ENGLISH2GERMAN: (&'static str, &'static str) = ("marian-mt-en-de/config.json", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-de/config.json");
|
||||
pub const ENGLISH2GERMAN: (&'static str, &'static str) = (
|
||||
"marian-mt-en-de/config.json",
|
||||
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-de/config.json",
|
||||
);
|
||||
/// Shared under the Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
|
||||
pub const GERMAN2ENGLISH: (&'static str, &'static str) = ("marian-mt-de-en/config.json", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-de-en/config.json");
|
||||
pub const GERMAN2ENGLISH: (&'static str, &'static str) = (
|
||||
"marian-mt-de-en/config.json",
|
||||
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-de-en/config.json",
|
||||
);
|
||||
/// Shared under the Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
|
||||
pub const ENGLISH2RUSSIAN: (&'static str, &'static str) = ("marian-mt-en-ru/config.json", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-ru/config.json");
|
||||
pub const ENGLISH2RUSSIAN: (&'static str, &'static str) = (
|
||||
"marian-mt-en-ru/config.json",
|
||||
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-ru/config.json",
|
||||
);
|
||||
/// Shared under the Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
|
||||
pub const RUSSIAN2ENGLISH: (&'static str, &'static str) = ("marian-mt-ru-en/config.json", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-ru-en/config.json");
|
||||
pub const RUSSIAN2ENGLISH: (&'static str, &'static str) = (
|
||||
"marian-mt-ru-en/config.json",
|
||||
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-ru-en/config.json",
|
||||
);
|
||||
/// Shared under the Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
|
||||
pub const FRENCH2GERMAN: (&'static str, &'static str) = ("marian-mt-fr-de/config.json", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-fr-de/config.json");
|
||||
pub const FRENCH2GERMAN: (&'static str, &'static str) = (
|
||||
"marian-mt-fr-de/config.json",
|
||||
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-fr-de/config.json",
|
||||
);
|
||||
/// Shared under the Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
|
||||
pub const GERMAN2FRENCH: (&'static str, &'static str) = ("marian-mt-de-fr/config.json", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-de-fr/config.json");
|
||||
pub const GERMAN2FRENCH: (&'static str, &'static str) = (
|
||||
"marian-mt-de-fr/config.json",
|
||||
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-de-fr/config.json",
|
||||
);
|
||||
}
|
||||
|
||||
impl MarianVocabResources {
|
||||
/// Shared under the Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
|
||||
pub const ENGLISH2ROMANCE: (&'static str, &'static str) = ("marian-mt-en-ROMANCE/vocab.json", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-ROMANCE/vocab.json");
|
||||
pub const ENGLISH2ROMANCE: (&'static str, &'static str) = (
|
||||
"marian-mt-en-ROMANCE/vocab.json",
|
||||
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-ROMANCE/vocab.json",
|
||||
);
|
||||
/// Shared under the Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
|
||||
pub const ROMANCE2ENGLISH: (&'static str, &'static str) = ("marian-mt-ROMANCE-en/vocab.json", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-ROMANCE-en/vocab.json");
|
||||
pub const ROMANCE2ENGLISH: (&'static str, &'static str) = (
|
||||
"marian-mt-ROMANCE-en/vocab.json",
|
||||
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-ROMANCE-en/vocab.json",
|
||||
);
|
||||
/// Shared under the Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
|
||||
pub const ENGLISH2GERMAN: (&'static str, &'static str) = ("marian-mt-en-de/vocab.json", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-de/vocab.json");
|
||||
pub const ENGLISH2GERMAN: (&'static str, &'static str) = (
|
||||
"marian-mt-en-de/vocab.json",
|
||||
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-de/vocab.json",
|
||||
);
|
||||
/// Shared under the Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
|
||||
pub const GERMAN2ENGLISH: (&'static str, &'static str) = ("marian-mt-de-en/vocab.json", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-de-en/vocab.json");
|
||||
pub const GERMAN2ENGLISH: (&'static str, &'static str) = (
|
||||
"marian-mt-de-en/vocab.json",
|
||||
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-de-en/vocab.json",
|
||||
);
|
||||
/// Shared under the Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
|
||||
pub const ENGLISH2RUSSIAN: (&'static str, &'static str) = ("marian-mt-en-ru/vocab.json", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-ru/vocab.json");
|
||||
pub const ENGLISH2RUSSIAN: (&'static str, &'static str) = (
|
||||
"marian-mt-en-ru/vocab.json",
|
||||
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-ru/vocab.json",
|
||||
);
|
||||
/// Shared under the Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
|
||||
pub const RUSSIAN2ENGLISH: (&'static str, &'static str) = ("marian-mt-ru-en/vocab.json", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-ru-en/vocab.json");
|
||||
pub const RUSSIAN2ENGLISH: (&'static str, &'static str) = (
|
||||
"marian-mt-ru-en/vocab.json",
|
||||
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-ru-en/vocab.json",
|
||||
);
|
||||
/// Shared under the Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
|
||||
pub const FRENCH2GERMAN: (&'static str, &'static str) = ("marian-mt-fr-de/vocab.json", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-fr-de/vocab.json");
|
||||
pub const FRENCH2GERMAN: (&'static str, &'static str) = (
|
||||
"marian-mt-fr-de/vocab.json",
|
||||
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-fr-de/vocab.json",
|
||||
);
|
||||
/// Shared under the Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
|
||||
pub const GERMAN2FRENCH: (&'static str, &'static str) = ("marian-mt-de-fr/vocab.json", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-de-fr/vocab.json");
|
||||
pub const GERMAN2FRENCH: (&'static str, &'static str) = (
|
||||
"marian-mt-de-fr/vocab.json",
|
||||
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-de-fr/vocab.json",
|
||||
);
|
||||
}
|
||||
|
||||
impl MarianSpmResources {
|
||||
/// Shared under the Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
|
||||
pub const ENGLISH2ROMANCE: (&'static str, &'static str) = ("marian-mt-en-ROMANCE/spiece.model", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-ROMANCE/source.spm");
|
||||
pub const ENGLISH2ROMANCE: (&'static str, &'static str) = (
|
||||
"marian-mt-en-ROMANCE/spiece.model",
|
||||
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-ROMANCE/source.spm",
|
||||
);
|
||||
/// Shared under the Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
|
||||
pub const ROMANCE2ENGLISH: (&'static str, &'static str) = ("marian-mt-ROMANCE-en/spiece.model", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-ROMANCE-en/source.spm");
|
||||
pub const ROMANCE2ENGLISH: (&'static str, &'static str) = (
|
||||
"marian-mt-ROMANCE-en/spiece.model",
|
||||
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-ROMANCE-en/source.spm",
|
||||
);
|
||||
/// Shared under the Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
|
||||
pub const ENGLISH2GERMAN: (&'static str, &'static str) = ("marian-mt-en-de/spiece.model", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-de/source.spm");
|
||||
pub const ENGLISH2GERMAN: (&'static str, &'static str) = (
|
||||
"marian-mt-en-de/spiece.model",
|
||||
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-de/source.spm",
|
||||
);
|
||||
/// Shared under the Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
|
||||
pub const GERMAN2ENGLISH: (&'static str, &'static str) = ("marian-mt-de-en/spiece.model", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-de-en/source.spm");
|
||||
pub const GERMAN2ENGLISH: (&'static str, &'static str) = (
|
||||
"marian-mt-de-en/spiece.model",
|
||||
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-de-en/source.spm",
|
||||
);
|
||||
/// Shared under the Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
|
||||
pub const ENGLISH2RUSSIAN: (&'static str, &'static str) = ("marian-mt-en-ru/spiece.model", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-ru/source.spm");
|
||||
pub const ENGLISH2RUSSIAN: (&'static str, &'static str) = (
|
||||
"marian-mt-en-ru/spiece.model",
|
||||
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-ru/source.spm",
|
||||
);
|
||||
/// Shared under the Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
|
||||
pub const RUSSIAN2ENGLISH: (&'static str, &'static str) = ("marian-mt-ru-en/spiece.model", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-ru-en/source.spm");
|
||||
pub const RUSSIAN2ENGLISH: (&'static str, &'static str) = (
|
||||
"marian-mt-ru-en/spiece.model",
|
||||
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-ru-en/source.spm",
|
||||
);
|
||||
/// Shared under the Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
|
||||
pub const FRENCH2GERMAN: (&'static str, &'static str) = ("marian-mt-fr-de/spiece.model", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-fr-de/source.spm");
|
||||
pub const FRENCH2GERMAN: (&'static str, &'static str) = (
|
||||
"marian-mt-fr-de/spiece.model",
|
||||
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-fr-de/source.spm",
|
||||
);
|
||||
/// Shared under the Creative Commons Attribution 4.0 International License by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
|
||||
pub const GERMAN2FRENCH: (&'static str, &'static str) = ("marian-mt-de-fr/spiece.model", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-de-fr/source.spm");
|
||||
pub const GERMAN2FRENCH: (&'static str, &'static str) = (
|
||||
"marian-mt-de-fr/spiece.model",
|
||||
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-de-fr/source.spm",
|
||||
);
|
||||
}
|
||||
|
||||
impl MarianPrefix {
|
||||
@ -150,23 +246,34 @@ impl MarianForConditionalGeneration {
|
||||
/// # Example
|
||||
///
|
||||
/// ```no_run
|
||||
/// use tch::{nn, Device};
|
||||
/// use rust_bert::bart::{BartConfig, BartForConditionalGeneration};
|
||||
/// use rust_bert::Config;
|
||||
/// use std::path::Path;
|
||||
/// use rust_bert::bart::{BartConfig, BartForConditionalGeneration};
|
||||
/// use tch::{nn, Device};
|
||||
///
|
||||
/// let config_path = Path::new("path/to/config.json");
|
||||
/// let device = Device::Cpu;
|
||||
/// let p = nn::VarStore::new(device);
|
||||
/// let config = BartConfig::from_file(config_path);
|
||||
/// let generation_mode = true;
|
||||
/// let bart: BartForConditionalGeneration = BartForConditionalGeneration::new(&(&p.root() / "bart"), &config, generation_mode);
|
||||
/// let bart: BartForConditionalGeneration =
|
||||
/// BartForConditionalGeneration::new(&(&p.root() / "bart"), &config, generation_mode);
|
||||
/// ```
|
||||
///
|
||||
pub fn new(p: &nn::Path, config: &BartConfig, generation_mode: bool) -> MarianForConditionalGeneration {
|
||||
pub fn new(
|
||||
p: &nn::Path,
|
||||
config: &BartConfig,
|
||||
generation_mode: bool,
|
||||
) -> MarianForConditionalGeneration {
|
||||
let base_model = BartModel::new(&(p / "model"), config, generation_mode);
|
||||
let final_logits_bias = p.var("final_logits_bias", &[1, config.vocab_size], Init::Const(0.));
|
||||
MarianForConditionalGeneration { base_model, final_logits_bias }
|
||||
let final_logits_bias = p.var(
|
||||
"final_logits_bias",
|
||||
&[1, config.vocab_size],
|
||||
Init::Const(0.),
|
||||
);
|
||||
MarianForConditionalGeneration {
|
||||
base_model,
|
||||
final_logits_bias,
|
||||
}
|
||||
}
|
||||
|
||||
/// Forward pass through the model
|
||||
@ -197,7 +304,7 @@ impl MarianForConditionalGeneration {
|
||||
/// # use rust_bert::Config;
|
||||
/// # use std::path::Path;
|
||||
/// # use tch::kind::Kind::{Int64, Double};
|
||||
/// use rust_bert::bart::{BartConfig};
|
||||
/// use rust_bert::bart::BartConfig;
|
||||
/// use rust_bert::marian::MarianForConditionalGeneration;
|
||||
/// # let config_path = Path::new("path/to/config.json");
|
||||
/// # let vocab_path = Path::new("path/to/vocab.txt");
|
||||
@ -208,49 +315,86 @@ impl MarianForConditionalGeneration {
|
||||
/// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56);
|
||||
/// let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device));
|
||||
/// let target_tensor = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device));
|
||||
/// let encoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
|
||||
/// let decoder_attention_mask = Tensor::ones(&[batch_size, target_sequence_length], (Int64, device));
|
||||
/// let encoder_attention_mask =
|
||||
/// Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
|
||||
/// let decoder_attention_mask =
|
||||
/// Tensor::ones(&[batch_size, target_sequence_length], (Int64, device));
|
||||
///
|
||||
/// let (decoder_output, encoder_hidden_states, cache,
|
||||
/// all_encoder_hidden_states, all_encoder_attentions,
|
||||
/// all_decoder_hidden_states, all_decoder_attentions) = no_grad(|| {
|
||||
/// marian_model
|
||||
/// .forward_t(Some(&input_tensor),
|
||||
/// let (
|
||||
/// decoder_output,
|
||||
/// encoder_hidden_states,
|
||||
/// cache,
|
||||
/// all_encoder_hidden_states,
|
||||
/// all_encoder_attentions,
|
||||
/// all_decoder_hidden_states,
|
||||
/// all_decoder_attentions,
|
||||
/// ) = no_grad(|| {
|
||||
/// marian_model.forward_t(
|
||||
/// Some(&input_tensor),
|
||||
/// Some(&encoder_attention_mask),
|
||||
/// None,
|
||||
/// Some(&target_tensor),
|
||||
/// Some(&decoder_attention_mask),
|
||||
/// None,
|
||||
/// false)
|
||||
/// false,
|
||||
/// )
|
||||
/// });
|
||||
///
|
||||
/// ```
|
||||
///
|
||||
pub fn forward_t(&self,
|
||||
pub fn forward_t(
|
||||
&self,
|
||||
input_ids: Option<&Tensor>,
|
||||
attention_mask: Option<&Tensor>,
|
||||
encoder_outputs: Option<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>)>,
|
||||
decoder_input_ids: Option<&Tensor>,
|
||||
decoder_attention_mask: Option<&Tensor>,
|
||||
old_layer_states: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
|
||||
train: bool)
|
||||
-> (Tensor, Tensor, Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
|
||||
Option<Vec<Tensor>>, Option<Vec<Tensor>>,
|
||||
Option<Vec<Tensor>>, Option<Vec<Tensor>>)
|
||||
{
|
||||
let (decoder_outputs, encoder_hidden_states, decoder_cache,
|
||||
all_decoder_hidden_states, all_decoder_attentions,
|
||||
all_encoder_hidden_states, all_encoder_attentions) =
|
||||
self.base_model.forward_t(input_ids, attention_mask, decoder_input_ids, encoder_outputs, decoder_attention_mask, old_layer_states, train);
|
||||
train: bool,
|
||||
) -> (
|
||||
Tensor,
|
||||
Tensor,
|
||||
Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
|
||||
Option<Vec<Tensor>>,
|
||||
Option<Vec<Tensor>>,
|
||||
Option<Vec<Tensor>>,
|
||||
Option<Vec<Tensor>>,
|
||||
) {
|
||||
let (
|
||||
decoder_outputs,
|
||||
encoder_hidden_states,
|
||||
decoder_cache,
|
||||
all_decoder_hidden_states,
|
||||
all_decoder_attentions,
|
||||
all_encoder_hidden_states,
|
||||
all_encoder_attentions,
|
||||
) = self.base_model.forward_t(
|
||||
input_ids,
|
||||
attention_mask,
|
||||
decoder_input_ids,
|
||||
encoder_outputs,
|
||||
decoder_attention_mask,
|
||||
old_layer_states,
|
||||
train,
|
||||
);
|
||||
|
||||
let lm_logits = decoder_outputs.linear::<Tensor>(&self.base_model.embeddings.ws, None);
|
||||
(lm_logits, encoder_hidden_states, decoder_cache,
|
||||
all_decoder_hidden_states, all_decoder_attentions,
|
||||
all_encoder_hidden_states, all_encoder_attentions)
|
||||
(
|
||||
lm_logits,
|
||||
encoder_hidden_states,
|
||||
decoder_cache,
|
||||
all_decoder_hidden_states,
|
||||
all_decoder_attentions,
|
||||
all_encoder_hidden_states,
|
||||
all_encoder_attentions,
|
||||
)
|
||||
}
|
||||
|
||||
pub fn encode(&self, input_ids: &Tensor, attention_mask: Option<&Tensor>) -> Tensor {
|
||||
let (encoder_hidden_states, _, _) = self.base_model.encoder.forward_t(input_ids, attention_mask, &self.base_model.embeddings, false);
|
||||
let (encoder_hidden_states, _, _) = self.base_model.encoder.forward_t(
|
||||
input_ids,
|
||||
attention_mask,
|
||||
&self.base_model.embeddings,
|
||||
false,
|
||||
);
|
||||
encoder_hidden_states
|
||||
}
|
||||
}
|
||||
@ -287,7 +431,7 @@ impl LMHeadModel for MarianForConditionalGeneration {
|
||||
/// # use rust_bert::Config;
|
||||
/// # use std::path::Path;
|
||||
/// # use tch::kind::Kind::{Int64, Double};
|
||||
/// use rust_bert::bart::{BartConfig};
|
||||
/// use rust_bert::bart::BartConfig;
|
||||
/// use rust_bert::marian::MarianForConditionalGeneration;
|
||||
/// # let config_path = Path::new("path/to/config.json");
|
||||
/// # let vocab_path = Path::new("path/to/vocab.txt");
|
||||
@ -298,25 +442,33 @@ impl LMHeadModel for MarianForConditionalGeneration {
|
||||
/// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56);
|
||||
/// let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device));
|
||||
/// let target_tensor = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device));
|
||||
/// let encoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
|
||||
/// let decoder_attention_mask = Tensor::ones(&[batch_size, target_sequence_length], (Int64, device));
|
||||
/// let encoder_attention_mask =
|
||||
/// Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
|
||||
/// let decoder_attention_mask =
|
||||
/// Tensor::ones(&[batch_size, target_sequence_length], (Int64, device));
|
||||
///
|
||||
/// let (decoder_output, encoder_hidden_states, cache,
|
||||
/// all_encoder_hidden_states, all_encoder_attentions,
|
||||
/// all_decoder_hidden_states, all_decoder_attentions) = no_grad(|| {
|
||||
/// marian_model
|
||||
/// .forward_t(Some(&input_tensor),
|
||||
/// let (
|
||||
/// decoder_output,
|
||||
/// encoder_hidden_states,
|
||||
/// cache,
|
||||
/// all_encoder_hidden_states,
|
||||
/// all_encoder_attentions,
|
||||
/// all_decoder_hidden_states,
|
||||
/// all_decoder_attentions,
|
||||
/// ) = no_grad(|| {
|
||||
/// marian_model.forward_t(
|
||||
/// Some(&input_tensor),
|
||||
/// Some(&encoder_attention_mask),
|
||||
/// None,
|
||||
/// Some(&target_tensor),
|
||||
/// Some(&decoder_attention_mask),
|
||||
/// None,
|
||||
/// false)
|
||||
/// false,
|
||||
/// )
|
||||
/// });
|
||||
///
|
||||
/// ```
|
||||
///
|
||||
fn forward_t(&self,
|
||||
fn forward_t(
|
||||
&self,
|
||||
input_ids: &Option<Tensor>,
|
||||
cache: Cache,
|
||||
attention_mask: &Option<Tensor>,
|
||||
@ -325,26 +477,47 @@ impl LMHeadModel for MarianForConditionalGeneration {
|
||||
_input_embeds: &Option<Tensor>,
|
||||
encoder_outputs: Option<&Tensor>,
|
||||
decoder_input_ids: &Option<Tensor>,
|
||||
train: bool) -> Result<(Tensor, Option<Tensor>, Cache, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
|
||||
train: bool,
|
||||
) -> Result<
|
||||
(
|
||||
Tensor,
|
||||
Option<Tensor>,
|
||||
Cache,
|
||||
Option<Vec<Tensor>>,
|
||||
Option<Vec<Tensor>>,
|
||||
),
|
||||
&'static str,
|
||||
> {
|
||||
let (decoder_output, encoder_hidden_states, new_cache, _, _, _, _) = match cache {
|
||||
Cache::BARTCache(cached_layer_states) => self.base_model.forward_t(input_ids.as_ref(),
|
||||
Cache::BARTCache(cached_layer_states) => self.base_model.forward_t(
|
||||
input_ids.as_ref(),
|
||||
attention_mask.as_ref(),
|
||||
decoder_input_ids.as_ref(),
|
||||
Some((encoder_outputs.as_ref().unwrap().copy(), None, None)),
|
||||
None,
|
||||
cached_layer_states,
|
||||
train),
|
||||
Cache::None => self.base_model.forward_t(input_ids.as_ref(),
|
||||
train,
|
||||
),
|
||||
Cache::None => self.base_model.forward_t(
|
||||
input_ids.as_ref(),
|
||||
attention_mask.as_ref(),
|
||||
decoder_input_ids.as_ref(),
|
||||
Some((encoder_outputs.as_ref().unwrap().copy(), None, None)),
|
||||
None,
|
||||
None,
|
||||
train),
|
||||
_ => Err("Cache not compatible with Marian Model")?
|
||||
train,
|
||||
),
|
||||
_ => Err("Cache not compatible with Marian Model")?,
|
||||
};
|
||||
|
||||
let lm_logits = decoder_output.linear::<Tensor>(&self.base_model.embeddings.ws, None) + &self.final_logits_bias;
|
||||
Ok((lm_logits, Some(encoder_hidden_states), Cache::BARTCache(new_cache), None, None))
|
||||
let lm_logits = decoder_output.linear::<Tensor>(&self.base_model.embeddings.ws, None)
|
||||
+ &self.final_logits_bias;
|
||||
Ok((
|
||||
lm_logits,
|
||||
Some(encoder_hidden_states),
|
||||
Cache::BARTCache(new_cache),
|
||||
None,
|
||||
None,
|
||||
))
|
||||
}
|
||||
}
|
@ -19,16 +19,24 @@
|
||||
//! #
|
||||
//! use tch::{nn, Device};
|
||||
//! # use std::path::PathBuf;
|
||||
//! use rust_bert::Config;
|
||||
//! use rust_bert::bart::{BartConfig, BartModel};
|
||||
//! use rust_bert::resources::{Resource, download_resource, LocalResource};
|
||||
//! use rust_tokenizers::preprocessing::tokenizer::marian_tokenizer::MarianTokenizer;
|
||||
//! use rust_bert::marian::MarianForConditionalGeneration;
|
||||
//! use rust_bert::resources::{download_resource, LocalResource, Resource};
|
||||
//! use rust_bert::Config;
|
||||
//! use rust_tokenizers::preprocessing::tokenizer::marian_tokenizer::MarianTokenizer;
|
||||
//!
|
||||
//! let config_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/config.json")});
|
||||
//! let vocab_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/vocab.json")});
|
||||
//! let sentence_piece_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/spiece.model")});
|
||||
//! let weights_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/model.ot")});
|
||||
//! let config_resource = Resource::Local(LocalResource {
|
||||
//! local_path: PathBuf::from("path/to/config.json"),
|
||||
//! });
|
||||
//! let vocab_resource = Resource::Local(LocalResource {
|
||||
//! local_path: PathBuf::from("path/to/vocab.json"),
|
||||
//! });
|
||||
//! let sentence_piece_resource = Resource::Local(LocalResource {
|
||||
//! local_path: PathBuf::from("path/to/spiece.model"),
|
||||
//! });
|
||||
//! let weights_resource = Resource::Local(LocalResource {
|
||||
//! local_path: PathBuf::from("path/to/model.ot"),
|
||||
//! });
|
||||
//! let config_path = download_resource(&config_resource)?;
|
||||
//! let vocab_path = download_resource(&vocab_resource)?;
|
||||
//! let spiece_path = download_resource(&sentence_piece_resource)?;
|
||||
@ -36,7 +44,11 @@
|
||||
//!
|
||||
//! let device = Device::cuda_if_available();
|
||||
//! let mut vs = nn::VarStore::new(device);
|
||||
//! let tokenizer = MarianTokenizer::from_files(vocab_path.to_str().unwrap(), spiece_path.to_str().unwrap(), true);
|
||||
//! let tokenizer = MarianTokenizer::from_files(
|
||||
//! vocab_path.to_str().unwrap(),
|
||||
//! spiece_path.to_str().unwrap(),
|
||||
//! true,
|
||||
//! );
|
||||
//! let config = BartConfig::from_file(config_path);
|
||||
//! let marian_model = MarianForConditionalGeneration::new(&vs.root(), &config, false);
|
||||
//! vs.load(weights_path)?;
|
||||
@ -47,4 +59,7 @@
|
||||
|
||||
mod marian;
|
||||
|
||||
pub use marian::{MarianForConditionalGeneration, MarianModelResources, MarianConfigResources, MarianVocabResources, MarianSpmResources, MarianPrefix};
|
||||
pub use marian::{
|
||||
MarianConfigResources, MarianForConditionalGeneration, MarianModelResources, MarianPrefix,
|
||||
MarianSpmResources, MarianVocabResources,
|
||||
};
|
||||
|
@ -18,15 +18,23 @@
|
||||
//! use rust_tokenizers::OpenAiGptTokenizer;
|
||||
//! use tch::{nn, Device};
|
||||
//! # use std::path::PathBuf;
|
||||
//! use rust_bert::Config;
|
||||
//! use rust_bert::gpt2::Gpt2Config;
|
||||
//! use rust_bert::openai_gpt::OpenAiGptModel;
|
||||
//! use rust_bert::resources::{Resource, download_resource, LocalResource};
|
||||
//! use rust_bert::resources::{download_resource, LocalResource, Resource};
|
||||
//! use rust_bert::Config;
|
||||
//!
|
||||
//! let config_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/config.json")});
|
||||
//! let vocab_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/vocab.txt")});
|
||||
//! let merges_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/vocab.txt")});
|
||||
//! let weights_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/model.ot")});
|
||||
//! let config_resource = Resource::Local(LocalResource {
|
||||
//! local_path: PathBuf::from("path/to/config.json"),
|
||||
//! });
|
||||
//! let vocab_resource = Resource::Local(LocalResource {
|
||||
//! local_path: PathBuf::from("path/to/vocab.txt"),
|
||||
//! });
|
||||
//! let merges_resource = Resource::Local(LocalResource {
|
||||
//! local_path: PathBuf::from("path/to/merges.txt"),
|
||||
//! });
|
||||
//! let weights_resource = Resource::Local(LocalResource {
|
||||
//! local_path: PathBuf::from("path/to/model.ot"),
|
||||
//! });
|
||||
//! let config_path = download_resource(&config_resource)?;
|
||||
//! let vocab_path = download_resource(&vocab_resource)?;
|
||||
//! let merges_path = download_resource(&merges_resource)?;
|
||||
@ -34,7 +42,11 @@
|
||||
//!
|
||||
//! let device = Device::cuda_if_available();
|
||||
//! let mut vs = nn::VarStore::new(device);
|
||||
//! let tokenizer: OpenAiGptTokenizer = OpenAiGptTokenizer::from_file(vocab_path.to_str().unwrap(), merges_path.to_str().unwrap(), true);
|
||||
//! let tokenizer: OpenAiGptTokenizer = OpenAiGptTokenizer::from_file(
|
||||
//! vocab_path.to_str().unwrap(),
|
||||
//! merges_path.to_str().unwrap(),
|
||||
//! true,
|
||||
//! );
|
||||
//! let config = Gpt2Config::from_file(config_path);
|
||||
//! let gpt_model = OpenAiGptModel::new(&vs.root(), &config);
|
||||
//! vs.load(weights_path)?;
|
||||
@ -46,5 +58,7 @@
|
||||
mod openai_gpt;
|
||||
mod transformer;
|
||||
|
||||
pub use openai_gpt::{OpenAiGptModelResources, OpenAiGptConfigResources, OpenAiGptVocabResources, OpenAiGptMergesResources,
|
||||
OpenAiGptModel, OpenAIGPTLMHeadModel};
|
||||
pub use openai_gpt::{
|
||||
OpenAIGPTLMHeadModel, OpenAiGptConfigResources, OpenAiGptMergesResources, OpenAiGptModel,
|
||||
OpenAiGptModelResources, OpenAiGptVocabResources,
|
||||
};
|
||||
|
@ -12,15 +12,15 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use tch::{nn, Tensor};
|
||||
use crate::common::dropout::Dropout;
|
||||
use tch::nn::embedding;
|
||||
use tch::kind::Kind::Int64;
|
||||
use std::borrow::BorrowMut;
|
||||
use crate::common::linear::{LinearNoBias, linear_no_bias};
|
||||
use crate::openai_gpt::transformer::Block;
|
||||
use crate::common::linear::{linear_no_bias, LinearNoBias};
|
||||
use crate::gpt2::Gpt2Config;
|
||||
use crate::pipelines::generation::{LMHeadModel, Cache};
|
||||
use crate::openai_gpt::transformer::Block;
|
||||
use crate::pipelines::generation::{Cache, LMHeadModel};
|
||||
use std::borrow::BorrowMut;
|
||||
use tch::kind::Kind::Int64;
|
||||
use tch::nn::embedding;
|
||||
use tch::{nn, Tensor};
|
||||
|
||||
/// # GPT Pretrained model weight files
|
||||
pub struct OpenAiGptModelResources;
|
||||
@ -36,22 +36,34 @@ pub struct OpenAiGptMergesResources;
|
||||
|
||||
impl OpenAiGptModelResources {
|
||||
/// Shared under MIT license by the OpenAI team at https://github.com/openai/finetune-transformer-lm. Modified with conversion to C-array format.
|
||||
pub const GPT: (&'static str, &'static str) = ("openai-gpt/model.ot", "https://cdn.huggingface.co/openai-gpt-rust_model.ot");
|
||||
pub const GPT: (&'static str, &'static str) = (
|
||||
"openai-gpt/model.ot",
|
||||
"https://cdn.huggingface.co/openai-gpt-rust_model.ot",
|
||||
);
|
||||
}
|
||||
|
||||
impl OpenAiGptConfigResources {
|
||||
/// Shared under MIT license by the OpenAI team at https://github.com/openai/finetune-transformer-lm. Modified with conversion to C-array format.
|
||||
pub const GPT: (&'static str, &'static str) = ("openai-gpt/config.json", "https://cdn.huggingface.co/openai-gpt-config.json");
|
||||
pub const GPT: (&'static str, &'static str) = (
|
||||
"openai-gpt/config.json",
|
||||
"https://cdn.huggingface.co/openai-gpt-config.json",
|
||||
);
|
||||
}
|
||||
|
||||
impl OpenAiGptVocabResources {
|
||||
/// Shared under MIT license by the OpenAI team at https://github.com/openai/finetune-transformer-lm. Modified with conversion to C-array format.
|
||||
pub const GPT: (&'static str, &'static str) = ("openai-gpt/vocab.txt", "https://cdn.huggingface.co/openai-gpt-vocab.json");
|
||||
pub const GPT: (&'static str, &'static str) = (
|
||||
"openai-gpt/vocab.txt",
|
||||
"https://cdn.huggingface.co/openai-gpt-vocab.json",
|
||||
);
|
||||
}
|
||||
|
||||
impl OpenAiGptMergesResources {
|
||||
/// Shared under MIT license by the OpenAI team at https://github.com/openai/finetune-transformer-lm. Modified with conversion to C-array format.
|
||||
pub const GPT: (&'static str, &'static str) = ("openai-gpt/merges.txt", "https://cdn.huggingface.co/openai-gpt-merges.txt");
|
||||
pub const GPT: (&'static str, &'static str) = (
|
||||
"openai-gpt/merges.txt",
|
||||
"https://cdn.huggingface.co/openai-gpt-merges.txt",
|
||||
);
|
||||
}
|
||||
|
||||
/// # GPT Base model
|
||||
@ -82,11 +94,11 @@ impl OpenAiGptModel {
|
||||
/// # Example
|
||||
///
|
||||
/// ```no_run
|
||||
/// use tch::{nn, Device};
|
||||
/// use rust_bert::Config;
|
||||
/// use std::path::Path;
|
||||
/// use rust_bert::gpt2::Gpt2Config;
|
||||
/// use rust_bert::openai_gpt::OpenAiGptModel;
|
||||
/// use rust_bert::Config;
|
||||
/// use std::path::Path;
|
||||
/// use tch::{nn, Device};
|
||||
///
|
||||
/// let config_path = Path::new("path/to/config.json");
|
||||
/// let device = Device::Cpu;
|
||||
@ -94,30 +106,46 @@ impl OpenAiGptModel {
|
||||
/// let config = Gpt2Config::from_file(config_path);
|
||||
/// let gpt2: OpenAiGptModel = OpenAiGptModel::new(&(&p.root() / "gpt"), &config);
|
||||
/// ```
|
||||
///
|
||||
pub fn new(p: &nn::Path, config: &Gpt2Config) -> OpenAiGptModel {
|
||||
let tokens_embed = embedding(&(p / "tokens_embed"), config.vocab_size, config.n_embd, Default::default());
|
||||
let positions_embed = embedding(&(p / "positions_embed"), config.n_positions, config.n_embd, Default::default());
|
||||
let tokens_embed = embedding(
|
||||
&(p / "tokens_embed"),
|
||||
config.vocab_size,
|
||||
config.n_embd,
|
||||
Default::default(),
|
||||
);
|
||||
let positions_embed = embedding(
|
||||
&(p / "positions_embed"),
|
||||
config.n_positions,
|
||||
config.n_embd,
|
||||
Default::default(),
|
||||
);
|
||||
|
||||
let embd_pdrop = match config.embd_pdrop {
|
||||
Some(value) => value,
|
||||
None => 0.1
|
||||
None => 0.1,
|
||||
};
|
||||
let drop = Dropout::new(embd_pdrop);
|
||||
let mut h: Vec<Block> = vec!();
|
||||
let mut h: Vec<Block> = vec![];
|
||||
let h_path = &(p / "h");
|
||||
for layer_index in 0..config.n_layer {
|
||||
h.push(Block::new(&(h_path / layer_index), config, true));
|
||||
};
|
||||
}
|
||||
let output_attentions = match config.output_attentions {
|
||||
Some(value) => value,
|
||||
None => false
|
||||
None => false,
|
||||
};
|
||||
let output_hidden_states = match config.output_hidden_states {
|
||||
Some(value) => value,
|
||||
None => false
|
||||
None => false,
|
||||
};
|
||||
OpenAiGptModel { tokens_embed, positions_embed, drop, h, output_hidden_states, output_attentions }
|
||||
OpenAiGptModel {
|
||||
tokens_embed,
|
||||
positions_embed,
|
||||
drop,
|
||||
h,
|
||||
output_hidden_states,
|
||||
output_attentions,
|
||||
}
|
||||
}
|
||||
|
||||
/// Forward pass through the model
|
||||
@ -156,64 +184,83 @@ impl OpenAiGptModel {
|
||||
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
|
||||
/// let attention_mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
||||
/// let token_type_ids = Tensor::ones(&[batch_size, sequence_length], (Int64, device));
|
||||
/// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true);
|
||||
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
|
||||
/// .expand(&[batch_size, sequence_length], true);
|
||||
///
|
||||
/// let (output, hidden_states, attentions) = no_grad(|| {
|
||||
/// gpt_model
|
||||
/// .forward_t(&Some(input_tensor),
|
||||
/// .forward_t(
|
||||
/// &Some(input_tensor),
|
||||
/// &Some(attention_mask),
|
||||
/// &Some(token_type_ids),
|
||||
/// &Some(position_ids),
|
||||
/// &None,
|
||||
/// false).unwrap()
|
||||
/// false,
|
||||
/// )
|
||||
/// .unwrap()
|
||||
/// });
|
||||
///
|
||||
/// ```
|
||||
///
|
||||
pub fn forward_t(&self,
|
||||
pub fn forward_t(
|
||||
&self,
|
||||
input_ids: &Option<Tensor>,
|
||||
attention_mask: &Option<Tensor>,
|
||||
token_type_ids: &Option<Tensor>,
|
||||
position_ids: &Option<Tensor>,
|
||||
input_embeds: &Option<Tensor>,
|
||||
train: bool) -> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
|
||||
train: bool,
|
||||
) -> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
|
||||
let (input_embeddings, seq_length) = match input_ids {
|
||||
Some(input_value) => match input_embeds {
|
||||
Some(_) => { return Err("Only one of input ids or input embeddings may be set"); }
|
||||
None => (input_value.apply(&self.tokens_embed), *input_value.size().last().unwrap())
|
||||
Some(_) => {
|
||||
return Err("Only one of input ids or input embeddings may be set");
|
||||
}
|
||||
None => (
|
||||
input_value.apply(&self.tokens_embed),
|
||||
*input_value.size().last().unwrap(),
|
||||
),
|
||||
},
|
||||
None => match input_embeds {
|
||||
Some(embeds) => (embeds.copy(), embeds.size()[1]),
|
||||
None => { return Err("At least one of input ids or input embeddings must be set"); }
|
||||
None => {
|
||||
return Err("At least one of input ids or input embeddings must be set");
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
let position_ids = match position_ids {
|
||||
Some(value) => value.copy(),
|
||||
None => Tensor::arange(seq_length, (Int64, input_embeddings.device())).unsqueeze(0)
|
||||
None => Tensor::arange(seq_length, (Int64, input_embeddings.device())).unsqueeze(0),
|
||||
};
|
||||
|
||||
let attention_mask: Option<Tensor> = match attention_mask {
|
||||
Some(value) => {
|
||||
Some(
|
||||
Some(value) => Some(
|
||||
(value
|
||||
.view((input_embeddings.size()[0], -1))
|
||||
.unsqueeze(1)
|
||||
.unsqueeze(2)
|
||||
- 1.0
|
||||
) * 10000.0)
|
||||
}
|
||||
None => None
|
||||
- 1.0)
|
||||
* 10000.0,
|
||||
),
|
||||
None => None,
|
||||
};
|
||||
|
||||
let position_embeds = position_ids.apply(&self.positions_embed);
|
||||
let token_type_embeds = match token_type_ids {
|
||||
Some(value) => value.apply(&self.tokens_embed),
|
||||
None => Tensor::zeros_like(&position_embeds)
|
||||
None => Tensor::zeros_like(&position_embeds),
|
||||
};
|
||||
let mut hidden_state: Tensor =
|
||||
(input_embeddings + position_embeds + token_type_embeds).apply_t(&self.drop, train);
|
||||
let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states {
|
||||
Some(vec![])
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let mut all_attentions: Option<Vec<Tensor>> = if self.output_attentions {
|
||||
Some(vec![])
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let mut hidden_state: Tensor = (input_embeddings + position_embeds + token_type_embeds).apply_t(&self.drop, train);
|
||||
let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states { Some(vec!()) } else { None };
|
||||
let mut all_attentions: Option<Vec<Tensor>> = if self.output_attentions { Some(vec!()) } else { None };
|
||||
|
||||
let mut layers = self.h.iter();
|
||||
loop {
|
||||
@ -229,9 +276,9 @@ impl OpenAiGptModel {
|
||||
attentions.push(temp.1.as_ref().unwrap().copy());
|
||||
};
|
||||
}
|
||||
None => break
|
||||
};
|
||||
None => break,
|
||||
};
|
||||
}
|
||||
|
||||
Ok((hidden_state, all_hidden_states, all_attentions))
|
||||
}
|
||||
@ -258,11 +305,11 @@ impl OpenAIGPTLMHeadModel {
|
||||
/// # Example
|
||||
///
|
||||
/// ```no_run
|
||||
/// use tch::{nn, Device};
|
||||
/// use rust_bert::Config;
|
||||
/// use std::path::Path;
|
||||
/// use rust_bert::gpt2::Gpt2Config;
|
||||
/// use rust_bert::openai_gpt::OpenAIGPTLMHeadModel;
|
||||
/// use rust_bert::Config;
|
||||
/// use std::path::Path;
|
||||
/// use tch::{nn, Device};
|
||||
///
|
||||
/// let config_path = Path::new("path/to/config.json");
|
||||
/// let device = Device::Cpu;
|
||||
@ -270,11 +317,18 @@ impl OpenAIGPTLMHeadModel {
|
||||
/// let config = Gpt2Config::from_file(config_path);
|
||||
/// let gpt2: OpenAIGPTLMHeadModel = OpenAIGPTLMHeadModel::new(&(&p.root() / "gpt"), &config);
|
||||
/// ```
|
||||
///
|
||||
pub fn new(p: &nn::Path, config: &Gpt2Config) -> OpenAIGPTLMHeadModel {
|
||||
let transformer = OpenAiGptModel::new(&p, config);
|
||||
let lm_head = linear_no_bias(&(p / "lm_head"), config.n_embd, config.vocab_size, Default::default());
|
||||
OpenAIGPTLMHeadModel { transformer, lm_head }
|
||||
let lm_head = linear_no_bias(
|
||||
&(p / "lm_head"),
|
||||
config.n_embd,
|
||||
config.vocab_size,
|
||||
Default::default(),
|
||||
);
|
||||
OpenAIGPTLMHeadModel {
|
||||
transformer,
|
||||
lm_head,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -336,10 +390,9 @@ impl LMHeadModel for OpenAIGPTLMHeadModel {
|
||||
/// &None,
|
||||
/// false).unwrap()
|
||||
/// });
|
||||
///
|
||||
/// ```
|
||||
///
|
||||
fn forward_t(&self,
|
||||
fn forward_t(
|
||||
&self,
|
||||
input_ids: &Option<Tensor>,
|
||||
_layer_past: Cache,
|
||||
attention_mask: &Option<Tensor>,
|
||||
@ -348,17 +401,33 @@ impl LMHeadModel for OpenAIGPTLMHeadModel {
|
||||
input_embeds: &Option<Tensor>,
|
||||
_encoder_outputs: Option<&Tensor>,
|
||||
_decoder_input_ids: &Option<Tensor>,
|
||||
train: bool) -> Result<(Tensor, Option<Tensor>, Cache, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
|
||||
let (output,
|
||||
all_hidden_states,
|
||||
all_attentions) = self.transformer.forward_t(input_ids,
|
||||
train: bool,
|
||||
) -> Result<
|
||||
(
|
||||
Tensor,
|
||||
Option<Tensor>,
|
||||
Cache,
|
||||
Option<Vec<Tensor>>,
|
||||
Option<Vec<Tensor>>,
|
||||
),
|
||||
&'static str,
|
||||
> {
|
||||
let (output, all_hidden_states, all_attentions) = self.transformer.forward_t(
|
||||
input_ids,
|
||||
attention_mask,
|
||||
token_type_ids,
|
||||
position_ids,
|
||||
input_embeds,
|
||||
train)?;
|
||||
train,
|
||||
)?;
|
||||
|
||||
let lm_logits = output.apply(&self.lm_head);
|
||||
Ok((lm_logits, None, Cache::None, all_hidden_states, all_attentions))
|
||||
Ok((
|
||||
lm_logits,
|
||||
None,
|
||||
Cache::None,
|
||||
all_hidden_states,
|
||||
all_attentions,
|
||||
))
|
||||
}
|
||||
}
|
@ -13,9 +13,9 @@
|
||||
// limitations under the License.
|
||||
|
||||
use crate::gpt2::attention::Attention;
|
||||
use tch::{Tensor, nn};
|
||||
use crate::gpt2::transformer::MLP;
|
||||
use crate::gpt2::Gpt2Config;
|
||||
use tch::{nn, Tensor};
|
||||
|
||||
pub struct Block {
|
||||
ln_1: nn::LayerNorm,
|
||||
@ -26,17 +26,29 @@ pub struct Block {
|
||||
|
||||
impl Block {
|
||||
pub fn new(p: &nn::Path, config: &Gpt2Config, scale: bool) -> Block {
|
||||
let layer_norm_config = nn::LayerNormConfig { eps: config.layer_norm_epsilon, ..Default::default() };
|
||||
let layer_norm_config = nn::LayerNormConfig {
|
||||
eps: config.layer_norm_epsilon,
|
||||
..Default::default()
|
||||
};
|
||||
let ln_1 = nn::layer_norm(p / "ln_1", vec![config.n_embd], layer_norm_config);
|
||||
let ln_2 = nn::layer_norm(p / "ln_2", vec![config.n_embd], layer_norm_config);
|
||||
let attn = Attention::new(&(p / "attn"), config, scale);
|
||||
let mlp = MLP::new(&(p / "mlp"), config);
|
||||
|
||||
Block { ln_1, attn, ln_2, mlp }
|
||||
Block {
|
||||
ln_1,
|
||||
attn,
|
||||
ln_2,
|
||||
mlp,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn forward_t(&self, x: &Tensor, attention_mask: &Option<Tensor>, train: bool)
|
||||
-> (Tensor, Option<Tensor>) {
|
||||
pub fn forward_t(
|
||||
&self,
|
||||
x: &Tensor,
|
||||
attention_mask: &Option<Tensor>,
|
||||
train: bool,
|
||||
) -> (Tensor, Option<Tensor>) {
|
||||
let (output, _, attentions) = self.attn.forward_t(x, &None, attention_mask, train);
|
||||
let x = (x + output).apply(&self.ln_1);
|
||||
let m = self.mlp.forward_t(&x, train);
|
||||
|
@ -19,13 +19,13 @@
|
||||
//!
|
||||
use crate::bert::BertConfig;
|
||||
use crate::distilbert::DistilBertConfig;
|
||||
use rust_tokenizers::{BertTokenizer, RobertaTokenizer, TokenizedInput, TruncationStrategy};
|
||||
use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::Tokenizer;
|
||||
use std::path::Path;
|
||||
use crate::Config;
|
||||
use std::collections::HashMap;
|
||||
use serde::{Serialize, Deserialize};
|
||||
use crate::electra::ElectraConfig;
|
||||
use crate::Config;
|
||||
use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::Tokenizer;
|
||||
use rust_tokenizers::{BertTokenizer, RobertaTokenizer, TokenizedInput, TruncationStrategy};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::path::Path;
|
||||
|
||||
#[derive(Clone, Copy, Serialize, Deserialize)]
|
||||
/// # Identifies the type of model
|
||||
@ -60,25 +60,42 @@ impl ConfigOption {
|
||||
match model_type {
|
||||
ModelType::Bert | ModelType::Roberta => ConfigOption::Bert(BertConfig::from_file(path)),
|
||||
ModelType::DistilBert => ConfigOption::DistilBert(DistilBertConfig::from_file(path)),
|
||||
ModelType::Electra => ConfigOption::Electra(ElectraConfig::from_file(path))
|
||||
ModelType::Electra => ConfigOption::Electra(ElectraConfig::from_file(path)),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_label_mapping(self) -> HashMap<i64, String> {
|
||||
match self {
|
||||
Self::Bert(config) => config.id2label.expect("No label dictionary (id2label) provided in configuration file"),
|
||||
Self::DistilBert(config) => config.id2label.expect("No label dictionary (id2label) provided in configuration file"),
|
||||
Self::Electra(config) => config.id2label.expect("No label dictionary (id2label) provided in configuration file"),
|
||||
Self::Bert(config) => config
|
||||
.id2label
|
||||
.expect("No label dictionary (id2label) provided in configuration file"),
|
||||
Self::DistilBert(config) => config
|
||||
.id2label
|
||||
.expect("No label dictionary (id2label) provided in configuration file"),
|
||||
Self::Electra(config) => config
|
||||
.id2label
|
||||
.expect("No label dictionary (id2label) provided in configuration file"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TokenizerOption {
|
||||
/// Interface method to load a tokenizer from file
|
||||
pub fn from_file(model_type: ModelType, vocab_path: &str, merges_path: Option<&str>, lower_case: bool) -> Self {
|
||||
pub fn from_file(
|
||||
model_type: ModelType,
|
||||
vocab_path: &str,
|
||||
merges_path: Option<&str>,
|
||||
lower_case: bool,
|
||||
) -> Self {
|
||||
match model_type {
|
||||
ModelType::Bert | ModelType::DistilBert | ModelType::Electra => TokenizerOption::Bert(BertTokenizer::from_file(vocab_path, lower_case)),
|
||||
ModelType::Roberta => TokenizerOption::Roberta(RobertaTokenizer::from_file(vocab_path, merges_path.expect("No merges specified!"), lower_case)),
|
||||
ModelType::Bert | ModelType::DistilBert | ModelType::Electra => {
|
||||
TokenizerOption::Bert(BertTokenizer::from_file(vocab_path, lower_case))
|
||||
}
|
||||
ModelType::Roberta => TokenizerOption::Roberta(RobertaTokenizer::from_file(
|
||||
vocab_path,
|
||||
merges_path.expect("No merges specified!"),
|
||||
lower_case,
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
@ -86,15 +103,25 @@ impl TokenizerOption {
|
||||
pub fn model_type(&self) -> ModelType {
|
||||
match *self {
|
||||
Self::Bert(_) => ModelType::Bert,
|
||||
Self::Roberta(_) => ModelType::Roberta
|
||||
Self::Roberta(_) => ModelType::Roberta,
|
||||
}
|
||||
}
|
||||
|
||||
/// Interface method to encode a list of texts, delegating to the wrapped tokenizer's `encode_list`
|
||||
pub fn encode_list(&self, text_list: Vec<&str>, max_len: usize, truncation_strategy: &TruncationStrategy, stride: usize) -> Vec<TokenizedInput> {
|
||||
pub fn encode_list(
|
||||
&self,
|
||||
text_list: Vec<&str>,
|
||||
max_len: usize,
|
||||
truncation_strategy: &TruncationStrategy,
|
||||
stride: usize,
|
||||
) -> Vec<TokenizedInput> {
|
||||
match *self {
|
||||
Self::Bert(ref tokenizer) => tokenizer.encode_list(text_list, max_len, truncation_strategy, stride),
|
||||
Self::Roberta(ref tokenizer) => tokenizer.encode_list(text_list, max_len, truncation_strategy, stride)
|
||||
Self::Bert(ref tokenizer) => {
|
||||
tokenizer.encode_list(text_list, max_len, truncation_strategy, stride)
|
||||
}
|
||||
Self::Roberta(ref tokenizer) => {
|
||||
tokenizer.encode_list(text_list, max_len, truncation_strategy, stride)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -6,14 +6,14 @@
|
||||
//! Extractive question answering from a given question and context. DistilBERT model finetuned on SQuAD (Stanford Question Answering Dataset)
|
||||
//!
|
||||
//! ```no_run
|
||||
//! use rust_bert::pipelines::question_answering::{QuestionAnsweringModel, QaInput};
|
||||
//! use rust_bert::pipelines::question_answering::{QaInput, QuestionAnsweringModel};
|
||||
//! # fn main() -> failure::Fallible<()> {
|
||||
//! let qa_model = QuestionAnsweringModel::new(Default::default())?;
|
||||
//!
|
||||
//! let question = String::from("Where does Amy live ?");
|
||||
//! let context = String::from("Amy lives in Amsterdam");
|
||||
//!
|
||||
//! let answers = qa_model.predict(&vec!(QaInput { question, context }), 1, 32);
|
||||
//! let answers = qa_model.predict(&vec![QaInput { question, context }], 1, 32);
|
||||
//! # Ok(())
|
||||
//! # }
|
||||
//! ```
|
||||
@ -22,15 +22,12 @@
|
||||
//! ```no_run
|
||||
//! # use rust_bert::pipelines::question_answering::Answer;
|
||||
//! # let output =
|
||||
//! [
|
||||
//! Answer {
|
||||
//! [Answer {
|
||||
//! score: 0.9976,
|
||||
//! start: 13,
|
||||
//! end: 21,
|
||||
//! answer: "Amsterdam"
|
||||
//!# .to_owned()
|
||||
//! }
|
||||
//! ]
|
||||
//! answer: String::from("Amsterdam"),
|
||||
//! }]
|
||||
//! # ;
|
||||
//! ```
|
||||
//!
|
||||
@ -48,9 +45,10 @@
|
||||
//! ```no_run
|
||||
//! # fn main() -> failure::Fallible<()> {
|
||||
//! # use rust_bert::pipelines::generation::LanguageGenerator;
|
||||
//! use rust_bert::pipelines::translation::{TranslationModel, TranslationConfig, Language};
|
||||
//! use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
|
||||
//! use tch::Device;
|
||||
//! let translation_config = TranslationConfig::new(Language::EnglishToFrench, Device::cuda_if_available());
|
||||
//! let translation_config =
|
||||
//! TranslationConfig::new(Language::EnglishToFrench, Device::cuda_if_available());
|
||||
//! let mut model = TranslationModel::new(translation_config)?;
|
||||
//!
|
||||
//! let input = ["This is a sentence to be translated"];
|
||||
@ -129,7 +127,7 @@
|
||||
//! let mut model = GPT2Generator::new(Default::default())?;
|
||||
//! let input_context_1 = "The dog";
|
||||
//! let input_context_2 = "The cat was";
|
||||
//! let output = model.generate(Some(vec!(input_context_1, input_context_2)), None);
|
||||
//! let output = model.generate(Some(vec![input_context_1, input_context_2]), None);
|
||||
//! # Ok(())
|
||||
//! # }
|
||||
//! ```
|
||||
@ -170,9 +168,18 @@
|
||||
//! # use rust_bert::pipelines::sentiment::SentimentPolarity::{Positive, Negative};
|
||||
//! # let output =
|
||||
//! [
|
||||
//! Sentiment { polarity: Positive, score: 0.998 },
|
||||
//! Sentiment { polarity: Negative, score: 0.992 },
|
||||
//! Sentiment { polarity: Positive, score: 0.999 }
|
||||
//! Sentiment {
|
||||
//! polarity: Positive,
|
||||
//! score: 0.998,
|
||||
//! },
|
||||
//! Sentiment {
|
||||
//! polarity: Negative,
|
||||
//! score: 0.992,
|
||||
//! },
|
||||
//! Sentiment {
|
||||
//! polarity: Positive,
|
||||
//! score: 0.999,
|
||||
//! },
|
||||
//! ]
|
||||
//! # ;
|
||||
//! ```
|
||||
@ -185,7 +192,7 @@
|
||||
//! let ner_model = NERModel::new(Default::default())?;
|
||||
//! let input = [
|
||||
//! "My name is Amy. I live in Paris.",
|
||||
//! "Paris is a city in France."
|
||||
//! "Paris is a city in France.",
|
||||
//! ];
|
||||
//! let output = ner_model.predict(&input);
|
||||
//! # Ok(())
|
||||
@ -197,21 +204,37 @@
|
||||
//! # use rust_bert::pipelines::ner::Entity;
|
||||
//! # let output =
|
||||
//! [
|
||||
//! Entity { word: String::from("Amy"), score: 0.9986, label: String::from("I-PER") },
|
||||
//! Entity { word: String::from("Paris"), score: 0.9985, label: String::from("I-LOC") },
|
||||
//! Entity { word: String::from("Paris"), score: 0.9988, label: String::from("I-LOC") },
|
||||
//! Entity { word: String::from("France"), score: 0.9993, label: String::from("I-LOC") },
|
||||
//! Entity {
|
||||
//! word: String::from("Amy"),
|
||||
//! score: 0.9986,
|
||||
//! label: String::from("I-PER"),
|
||||
//! },
|
||||
//! Entity {
|
||||
//! word: String::from("Paris"),
|
||||
//! score: 0.9985,
|
||||
//! label: String::from("I-LOC"),
|
||||
//! },
|
||||
//! Entity {
|
||||
//! word: String::from("Paris"),
|
||||
//! score: 0.9988,
|
||||
//! label: String::from("I-LOC"),
|
||||
//! },
|
||||
//! Entity {
|
||||
//! word: String::from("France"),
|
||||
//! score: 0.9993,
|
||||
//! label: String::from("I-LOC"),
|
||||
//! },
|
||||
//! ]
|
||||
//! # ;
|
||||
//! ```
|
||||
//!
|
||||
|
||||
pub mod sentiment;
|
||||
pub mod common;
|
||||
pub mod token_classification;
|
||||
pub mod sequence_classification;
|
||||
pub mod generation;
|
||||
pub mod ner;
|
||||
pub mod question_answering;
|
||||
pub mod generation;
|
||||
pub mod sentiment;
|
||||
pub mod sequence_classification;
|
||||
pub mod summarization;
|
||||
pub mod token_classification;
|
||||
pub mod translation;
|
||||
|
@ -25,7 +25,7 @@
|
||||
//!
|
||||
//! let input = [
|
||||
//! "My name is Amy. I live in Paris.",
|
||||
//! "Paris is a city in France."
|
||||
//! "Paris is a city in France.",
|
||||
//! ];
|
||||
//! let output = ner_model.predict(&input);
|
||||
//! # Ok(())
|
||||
@ -37,16 +37,31 @@
|
||||
//! # use rust_bert::pipelines::ner::Entity;
|
||||
//! # let output =
|
||||
//! [
|
||||
//! Entity { word: String::from("Amy"), score: 0.9986, label: String::from("I-PER") },
|
||||
//! Entity { word: String::from("Paris"), score: 0.9985, label: String::from("I-LOC") },
|
||||
//! Entity { word: String::from("Paris"), score: 0.9988, label: String::from("I-LOC") },
|
||||
//! Entity { word: String::from("France"), score: 0.9993, label: String::from("I-LOC") },
|
||||
//! Entity {
|
||||
//! word: String::from("Amy"),
|
||||
//! score: 0.9986,
|
||||
//! label: String::from("I-PER"),
|
||||
//! },
|
||||
//! Entity {
|
||||
//! word: String::from("Paris"),
|
||||
//! score: 0.9985,
|
||||
//! label: String::from("I-LOC"),
|
||||
//! },
|
||||
//! Entity {
|
||||
//! word: String::from("Paris"),
|
||||
//! score: 0.9988,
|
||||
//! label: String::from("I-LOC"),
|
||||
//! },
|
||||
//! Entity {
|
||||
//! word: String::from("France"),
|
||||
//! score: 0.9993,
|
||||
//! label: String::from("I-LOC"),
|
||||
//! },
|
||||
//! ]
|
||||
//! # ;
|
||||
//! ```
|
||||
|
||||
use crate::pipelines::token_classification::{TokenClassificationModel, TokenClassificationConfig};
|
||||
|
||||
use crate::pipelines::token_classification::{TokenClassificationConfig, TokenClassificationModel};
|
||||
|
||||
#[derive(Debug)]
|
||||
/// # Entity generated by a `NERModel`
|
||||
@ -64,7 +79,7 @@ type NERConfig = TokenClassificationConfig;
|
||||
|
||||
/// # NERModel to extract named entities
|
||||
pub struct NERModel {
|
||||
token_classification_model: TokenClassificationModel
|
||||
token_classification_model: TokenClassificationModel,
|
||||
}
|
||||
|
||||
impl NERModel {
|
||||
@ -84,10 +99,11 @@ impl NERModel {
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
///
|
||||
pub fn new(ner_config: NERConfig) -> failure::Fallible<NERModel> {
|
||||
let model = TokenClassificationModel::new(ner_config)?;
|
||||
Ok(NERModel { token_classification_model: model })
|
||||
Ok(NERModel {
|
||||
token_classification_model: model,
|
||||
})
|
||||
}
|
||||
|
||||
/// Extract entities from a text
|
||||
@ -109,24 +125,22 @@ impl NERModel {
|
||||
/// let ner_model = NERModel::new(Default::default())?;
|
||||
/// let input = [
|
||||
/// "My name is Amy. I live in Paris.",
|
||||
/// "Paris is a city in France."
|
||||
/// "Paris is a city in France.",
|
||||
/// ];
|
||||
/// let output = ner_model.predict(&input);
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
///
|
||||
pub fn predict(&self, input: &[&str]) -> Vec<Entity> {
|
||||
self.token_classification_model
|
||||
.predict(input, true, false)
|
||||
.into_iter()
|
||||
.filter(|token| token.label != "O")
|
||||
.map(|token| {
|
||||
Entity {
|
||||
.map(|token| Entity {
|
||||
word: token.text,
|
||||
score: token.score,
|
||||
label: token.label,
|
||||
}
|
||||
}).collect()
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
@ -17,7 +17,7 @@
|
||||
//! The dependencies will be downloaded to the user's home directory, under ~/.cache/.rustbert/distilbert-qa
|
||||
//!
|
||||
//! ```no_run
|
||||
//! use rust_bert::pipelines::question_answering::{QuestionAnsweringModel, QaInput};
|
||||
//! use rust_bert::pipelines::question_answering::{QaInput, QuestionAnsweringModel};
|
||||
//!
|
||||
//! # fn main() -> failure::Fallible<()> {
|
||||
//! let qa_model = QuestionAnsweringModel::new(Default::default())?;
|
||||
@ -25,7 +25,7 @@
|
||||
//! let question = String::from("Where does Amy live ?");
|
||||
//! let context = String::from("Amy lives in Amsterdam");
|
||||
//!
|
||||
//! let answers = qa_model.predict(&vec!(QaInput { question, context }), 1, 32);
|
||||
//! let answers = qa_model.predict(&vec![QaInput { question, context }], 1, 32);
|
||||
//! # Ok(())
|
||||
//! # }
|
||||
//! ```
|
||||
@ -34,31 +34,31 @@
|
||||
//! ```no_run
|
||||
//! # use rust_bert::pipelines::question_answering::Answer;
|
||||
//! # let output =
|
||||
//! [
|
||||
//! Answer {
|
||||
//! [Answer {
|
||||
//! score: 0.9976,
|
||||
//! start: 13,
|
||||
//! end: 21,
|
||||
//! answer: "Amsterdam"
|
||||
//!# .to_owned()
|
||||
//! }
|
||||
//! ]
|
||||
//! answer: String::from("Amsterdam"),
|
||||
//! }]
|
||||
//! # ;
|
||||
//! ```
|
||||
|
||||
use rust_tokenizers::{BertTokenizer, Tokenizer, TruncationStrategy, TokenizedInput};
|
||||
use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::Mask;
|
||||
use tch::{Device, Tensor, no_grad};
|
||||
use std::path::PathBuf;
|
||||
use rust_tokenizers::tokenization_utils::truncate_sequences;
|
||||
use std::collections::HashMap;
|
||||
use std::cmp::min;
|
||||
use tch::nn::VarStore;
|
||||
use tch::kind::Kind::Float;
|
||||
use std::fs;
|
||||
use crate::common::resources::{download_resource, RemoteResource, Resource};
|
||||
use crate::distilbert::{
|
||||
DistilBertConfig, DistilBertConfigResources, DistilBertForQuestionAnswering,
|
||||
DistilBertModelResources, DistilBertVocabResources,
|
||||
};
|
||||
use crate::Config;
|
||||
use crate::distilbert::{DistilBertForQuestionAnswering, DistilBertConfig, DistilBertConfigResources, DistilBertModelResources, DistilBertVocabResources};
|
||||
use crate::common::resources::{Resource, RemoteResource, download_resource};
|
||||
use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::Mask;
|
||||
use rust_tokenizers::tokenization_utils::truncate_sequences;
|
||||
use rust_tokenizers::{BertTokenizer, TokenizedInput, Tokenizer, TruncationStrategy};
|
||||
use std::cmp::min;
|
||||
use std::collections::HashMap;
|
||||
use std::fs;
|
||||
use std::path::PathBuf;
|
||||
use tch::kind::Kind::Float;
|
||||
use tch::nn::VarStore;
|
||||
use tch::{no_grad, Device, Tensor};
|
||||
|
||||
/// # Input for Question Answering
|
||||
/// Includes a context (containing the answer) and question strings
|
||||
@ -84,7 +84,6 @@ struct QaFeature {
|
||||
pub token_to_orig_map: HashMap<i64, i64>,
|
||||
pub p_mask: Vec<i8>,
|
||||
pub example_index: i64,
|
||||
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
@ -102,34 +101,38 @@ pub struct Answer {
|
||||
|
||||
impl PartialEq for Answer {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
(self.start == other.start) &&
|
||||
(self.end == other.end) &&
|
||||
(self.answer == other.answer)
|
||||
(self.start == other.start) && (self.end == other.end) && (self.answer == other.answer)
|
||||
}
|
||||
}
|
||||
|
||||
/// Removes repeated items from `vector` in place, keeping the first occurrence of each
/// value and preserving the relative order of the retained items.
///
/// Items are compared with `PartialEq` only, so every item is checked against the unique
/// items found before it. The de-duplication is done inside `vector` itself: no element
/// is cloned and no auxiliary vector is allocated.
///
/// Returns the same `vector` reference so that calls can be chained.
fn remove_duplicates<T: PartialEq + Clone>(vector: &mut Vec<T>) -> &mut Vec<T> {
    // Invariant: `vector[..kept]` holds the unique items seen so far, in original order.
    let mut kept = 0;
    for index in 0..vector.len() {
        if !vector[..kept].contains(&vector[index]) {
            // Everything in `kept..index` is a known duplicate, so moving one of those
            // to `index` loses nothing and extends the unique prefix by one item.
            vector.swap(kept, index);
            kept += 1;
        }
    }
    // Drop the duplicates that were pushed past the unique prefix.
    vector.truncate(kept);
    vector
}
|
||||
|
||||
|
||||
impl QaExample {
|
||||
pub fn new(question: &str, context: &str) -> QaExample {
|
||||
let question = question.to_owned();
|
||||
let (doc_tokens, char_to_word_offset) = QaExample::split_context(context);
|
||||
QaExample { question, context: context.to_owned(), doc_tokens, char_to_word_offset }
|
||||
QaExample {
|
||||
question,
|
||||
context: context.to_owned(),
|
||||
doc_tokens,
|
||||
char_to_word_offset,
|
||||
}
|
||||
}
|
||||
|
||||
fn split_context(context: &str) -> (Vec<String>, Vec<i64>) {
|
||||
let mut doc_tokens: Vec<String> = vec!();
|
||||
let mut char_to_word_offset: Vec<i64> = vec!();
|
||||
let mut doc_tokens: Vec<String> = vec![];
|
||||
let mut char_to_word_offset: Vec<i64> = vec![];
|
||||
let max_length = context.len();
|
||||
let mut current_word = String::with_capacity(max_length);
|
||||
let mut previous_whitespace = false;
|
||||
@ -158,11 +161,11 @@ impl QaExample {
|
||||
}
|
||||
|
||||
/// Returns `true` when `character` is one of the whitespace characters recognised here:
/// space, tab, carriage return, line feed or U+202F (narrow no-break space).
fn is_whitespace(character: &char) -> bool {
    // A membership test replaces the former chain of non-short-circuit `|` comparisons
    // and spells the U+202F code point as a character literal instead of a cast.
    [' ', '\t', '\r', '\n', '\u{202F}'].contains(character)
}
|
||||
}
|
||||
|
||||
@ -182,9 +185,15 @@ pub struct QuestionAnsweringConfig {
|
||||
impl Default for QuestionAnsweringConfig {
|
||||
fn default() -> QuestionAnsweringConfig {
|
||||
QuestionAnsweringConfig {
|
||||
model_resource: Resource::Remote(RemoteResource::from_pretrained(DistilBertModelResources::DISTIL_BERT_SQUAD)),
|
||||
config_resource: Resource::Remote(RemoteResource::from_pretrained(DistilBertConfigResources::DISTIL_BERT_SQUAD)),
|
||||
vocab_resource: Resource::Remote(RemoteResource::from_pretrained(DistilBertVocabResources::DISTIL_BERT_SQUAD)),
|
||||
model_resource: Resource::Remote(RemoteResource::from_pretrained(
|
||||
DistilBertModelResources::DISTIL_BERT_SQUAD,
|
||||
)),
|
||||
config_resource: Resource::Remote(RemoteResource::from_pretrained(
|
||||
DistilBertConfigResources::DISTIL_BERT_SQUAD,
|
||||
)),
|
||||
vocab_resource: Resource::Remote(RemoteResource::from_pretrained(
|
||||
DistilBertVocabResources::DISTIL_BERT_SQUAD,
|
||||
)),
|
||||
device: Device::cuda_if_available(),
|
||||
}
|
||||
}
|
||||
@ -220,16 +229,23 @@ impl QuestionAnsweringModel {
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
///
|
||||
pub fn new(question_answering_config: QuestionAnsweringConfig) -> failure::Fallible<QuestionAnsweringModel> {
|
||||
pub fn new(
|
||||
question_answering_config: QuestionAnsweringConfig,
|
||||
) -> failure::Fallible<QuestionAnsweringModel> {
|
||||
let config_path = download_resource(&question_answering_config.config_resource)?;
|
||||
let vocab_path = download_resource(&question_answering_config.vocab_resource)?;
|
||||
let weights_path = download_resource(&question_answering_config.model_resource)?;
|
||||
let device = question_answering_config.device;
|
||||
|
||||
let tokenizer = BertTokenizer::from_file(vocab_path.to_str().unwrap(), false);
|
||||
let pad_idx = *Tokenizer::vocab(&tokenizer).special_values.get("[PAD]").expect("[PAD] token not found in vocabulary");
|
||||
let sep_idx = *Tokenizer::vocab(&tokenizer).special_values.get("[SEP]").expect("[SEP] token not found in vocabulary");
|
||||
let pad_idx = *Tokenizer::vocab(&tokenizer)
|
||||
.special_values
|
||||
.get("[PAD]")
|
||||
.expect("[PAD] token not found in vocabulary");
|
||||
let sep_idx = *Tokenizer::vocab(&tokenizer)
|
||||
.special_values
|
||||
.get("[SEP]")
|
||||
.expect("[SEP] token not found in vocabulary");
|
||||
let mut var_store = VarStore::new(device);
|
||||
let mut config = DistilBertConfig::from_file(config_path);
|
||||
// The config for the current pre-trained question answering model indicates position embeddings which does not seem accurate
|
||||
@ -249,7 +265,6 @@ impl QuestionAnsweringModel {
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
/// Perform extractive question answering given a list of `QaInputs`
|
||||
///
|
||||
/// # Arguments
|
||||
@ -265,7 +280,7 @@ impl QuestionAnsweringModel {
|
||||
///
|
||||
/// ```no_run
|
||||
/// # fn main() -> failure::Fallible<()> {
|
||||
/// use rust_bert::pipelines::question_answering::{QuestionAnsweringModel, QaInput};
|
||||
/// use rust_bert::pipelines::question_answering::{QaInput, QuestionAnsweringModel};
|
||||
///
|
||||
/// let qa_model = QuestionAnsweringModel::new(Default::default())?;
|
||||
///
|
||||
@ -274,15 +289,25 @@ impl QuestionAnsweringModel {
|
||||
/// let question_2 = String::from("Where does Eric live");
|
||||
/// let context_2 = String::from("While Amy lives in Amsterdam, Eric is in The Hague.");
|
||||
///
|
||||
/// let qa_input_1 = QaInput { question: question_1, context: context_1 };
|
||||
/// let qa_input_2 = QaInput { question: question_2, context: context_2 };
|
||||
/// let qa_input_1 = QaInput {
|
||||
/// question: question_1,
|
||||
/// context: context_1,
|
||||
/// };
|
||||
/// let qa_input_2 = QaInput {
|
||||
/// question: question_2,
|
||||
/// context: context_2,
|
||||
/// };
|
||||
/// let answers = qa_model.predict(&[qa_input_1, qa_input_2], 1, 32);
|
||||
///
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
///
|
||||
pub fn predict(&self, qa_inputs: &[QaInput], top_k: i64, batch_size: usize) -> Vec<Vec<Answer>> {
|
||||
pub fn predict(
|
||||
&self,
|
||||
qa_inputs: &[QaInput],
|
||||
top_k: i64,
|
||||
batch_size: usize,
|
||||
) -> Vec<Vec<Answer>> {
|
||||
let examples: Vec<QaExample> = qa_inputs
|
||||
.iter()
|
||||
.map(|qa_input| QaExample::new(&qa_input.question, &qa_input.context))
|
||||
@ -290,7 +315,15 @@ impl QuestionAnsweringModel {
|
||||
let features: Vec<QaFeature> = examples
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(example_index, qa_example)| self.generate_features(&qa_example, self.max_seq_len, self.doc_stride, self.max_query_length, example_index as i64))
|
||||
.map(|(example_index, qa_example)| {
|
||||
self.generate_features(
|
||||
&qa_example,
|
||||
self.max_seq_len,
|
||||
self.doc_stride,
|
||||
self.max_query_length,
|
||||
example_index as i64,
|
||||
)
|
||||
})
|
||||
.flatten()
|
||||
.collect();
|
||||
|
||||
@ -310,28 +343,36 @@ impl QuestionAnsweringModel {
|
||||
}
|
||||
|
||||
let input_ids = Tensor::stack(&input_ids, 0).to(self.var_store.device());
|
||||
let attention_masks = Tensor::stack(&attention_masks, 0).to(self.var_store.device());
|
||||
let attention_masks =
|
||||
Tensor::stack(&attention_masks, 0).to(self.var_store.device());
|
||||
|
||||
let (start_logits, end_logits, _, _) = self.distilbert_qa.forward_t(Some(input_ids), Some(attention_masks), None, false).unwrap();
|
||||
let (start_logits, end_logits, _, _) = self
|
||||
.distilbert_qa
|
||||
.forward_t(Some(input_ids), Some(attention_masks), None, false)
|
||||
.unwrap();
|
||||
|
||||
let start_logits = start_logits.detach();
|
||||
let end_logits = end_logits.detach();
|
||||
let example_index_to_feature_end_position: Vec<(usize, i64)> = batch_features
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(feature_index, feature)| (feature.example_index as usize, feature_index as i64 + 1))
|
||||
.map(|(feature_index, feature)| {
|
||||
(feature.example_index as usize, feature_index as i64 + 1)
|
||||
})
|
||||
.collect();
|
||||
|
||||
let mut feature_id_start = 0;
|
||||
|
||||
for (example_id, max_feature_id) in example_index_to_feature_end_position {
|
||||
let mut answers: Vec<Answer> = vec!();
|
||||
let mut answers: Vec<Answer> = vec![];
|
||||
let example = &examples[example_id];
|
||||
for feature_idx in feature_id_start..max_feature_id {
|
||||
let feature = &batch_features[feature_idx as usize];
|
||||
let start = start_logits.get(feature_idx);
|
||||
let end = end_logits.get(feature_idx);
|
||||
let p_mask = (Tensor::of_slice(&feature.p_mask) - 1).abs().to_device(start.device());
|
||||
let p_mask = (Tensor::of_slice(&feature.p_mask) - 1)
|
||||
.abs()
|
||||
.to_device(start.device());
|
||||
|
||||
let start: Tensor = start.exp() / start.exp().sum(Float) * &p_mask;
|
||||
let end: Tensor = end.exp() / end.exp().sum(Float) * &p_mask;
|
||||
@ -343,33 +384,42 @@ impl QuestionAnsweringModel {
|
||||
let end_pos = feature.token_to_orig_map[&ends[idx]] as usize;
|
||||
let answer = example.doc_tokens[start_pos..end_pos + 1].join(" ");
|
||||
|
||||
let start = example.char_to_word_offset
|
||||
let start = example
|
||||
.char_to_word_offset
|
||||
.iter()
|
||||
.position(|&v| v as usize == start_pos)
|
||||
.unwrap();
|
||||
|
||||
let end = example.char_to_word_offset
|
||||
let end = example
|
||||
.char_to_word_offset
|
||||
.iter()
|
||||
.rposition(|&v| v as usize == end_pos)
|
||||
.unwrap();
|
||||
|
||||
answers.push(Answer { score: scores[idx], start, end, answer });
|
||||
answers.push(Answer {
|
||||
score: scores[idx],
|
||||
start,
|
||||
end,
|
||||
answer,
|
||||
});
|
||||
}
|
||||
}
|
||||
feature_id_start = max_feature_id;
|
||||
let example_answers = example_top_k_answers_map.entry(example_id).or_insert(vec!());
|
||||
let example_answers = example_top_k_answers_map
|
||||
.entry(example_id)
|
||||
.or_insert(vec![]);
|
||||
example_answers.extend(answers);
|
||||
}
|
||||
});
|
||||
start = end;
|
||||
}
|
||||
let mut all_answers = vec!();
|
||||
let mut all_answers = vec![];
|
||||
for example_id in 0..examples.len() {
|
||||
if let Some(answers) = example_top_k_answers_map.get_mut(&example_id) {
|
||||
remove_duplicates(answers).sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
|
||||
all_answers.push(answers[..min(answers.len(), top_k as usize)].to_vec());
|
||||
} else {
|
||||
all_answers.push(vec!());
|
||||
all_answers.push(vec![]);
|
||||
}
|
||||
}
|
||||
all_answers
|
||||
@ -379,7 +429,10 @@ impl QuestionAnsweringModel {
|
||||
let outer = start.unsqueeze(-1).matmul(&end.unsqueeze(0));
|
||||
let start_dim = start.size()[0];
|
||||
let end_dim = end.size()[0];
|
||||
let candidates = outer.triu(0).tril(self.max_answer_len as i64 - 1).flatten(0, -1);
|
||||
let candidates = outer
|
||||
.triu(0)
|
||||
.tril(self.max_answer_len as i64 - 1)
|
||||
.flatten(0, -1);
|
||||
let idx_sort = if top_k == 1 {
|
||||
candidates.argmax(0, true)
|
||||
} else if candidates.size()[0] < top_k {
|
||||
@ -387,9 +440,9 @@ impl QuestionAnsweringModel {
|
||||
} else {
|
||||
candidates.argsort(0, true).slice(0, 0, top_k, 1)
|
||||
};
|
||||
let mut start: Vec<i64> = vec!();
|
||||
let mut end: Vec<i64> = vec!();
|
||||
let mut scores: Vec<f64> = vec!();
|
||||
let mut start: Vec<i64> = vec![];
|
||||
let mut end: Vec<i64> = vec![];
|
||||
let mut scores: Vec<f64> = vec![];
|
||||
for flat_index_position in 0..idx_sort.size()[0] {
|
||||
let flat_index = idx_sort.int64_value(&[flat_index_position]);
|
||||
scores.push(candidates.double_value(&[flat_index]));
|
||||
@ -399,10 +452,16 @@ impl QuestionAnsweringModel {
|
||||
(start, end, scores)
|
||||
}
|
||||
|
||||
|
||||
fn generate_features(&self, qa_example: &QaExample, max_seq_length: usize, doc_stride: usize, max_query_length: usize, example_index: i64) -> Vec<QaFeature> {
|
||||
let mut tok_to_orig_index: Vec<i64> = vec!();
|
||||
let mut all_doc_tokens: Vec<String> = vec!();
|
||||
fn generate_features(
|
||||
&self,
|
||||
qa_example: &QaExample,
|
||||
max_seq_length: usize,
|
||||
doc_stride: usize,
|
||||
max_query_length: usize,
|
||||
example_index: i64,
|
||||
) -> Vec<QaFeature> {
|
||||
let mut tok_to_orig_index: Vec<i64> = vec![];
|
||||
let mut all_doc_tokens: Vec<String> = vec![];
|
||||
|
||||
for (idx, token) in qa_example.doc_tokens.iter().enumerate() {
|
||||
let sub_tokens = self.tokenizer.tokenize(token);
|
||||
@ -414,28 +473,61 @@ impl QuestionAnsweringModel {
|
||||
|
||||
let truncated_query = self.prepare_query(&qa_example.question, max_query_length);
|
||||
|
||||
let sequence_added_tokens = self.tokenizer.build_input_with_special_tokens(vec!(), None, vec!(), None, vec!(), None, vec!(), None).0.len();
|
||||
let sequence_pair_added_tokens = self.tokenizer.build_input_with_special_tokens(vec!(), Some(vec!()), vec!(), Some(vec!()), vec!(), Some(vec!()), vec!(), Some(vec!())).0.len();
|
||||
let sequence_added_tokens = self
|
||||
.tokenizer
|
||||
.build_input_with_special_tokens(vec![], None, vec![], None, vec![], None, vec![], None)
|
||||
.0
|
||||
.len();
|
||||
let sequence_pair_added_tokens = self
|
||||
.tokenizer
|
||||
.build_input_with_special_tokens(
|
||||
vec![],
|
||||
Some(vec![]),
|
||||
vec![],
|
||||
Some(vec![]),
|
||||
vec![],
|
||||
Some(vec![]),
|
||||
vec![],
|
||||
Some(vec![]),
|
||||
)
|
||||
.0
|
||||
.len();
|
||||
|
||||
let mut spans: Vec<QaFeature> = vec!();
|
||||
let mut spans: Vec<QaFeature> = vec![];
|
||||
|
||||
let mut remaining_tokens = self.tokenizer.convert_tokens_to_ids(&all_doc_tokens);
|
||||
while (spans.len() * doc_stride as usize) < all_doc_tokens.len() {
|
||||
let (encoded_span, attention_mask) = self.encode_qa_pair(&truncated_query, &remaining_tokens, max_seq_length, doc_stride, sequence_pair_added_tokens);
|
||||
let (encoded_span, attention_mask) = self.encode_qa_pair(
|
||||
&truncated_query,
|
||||
&remaining_tokens,
|
||||
max_seq_length,
|
||||
doc_stride,
|
||||
sequence_pair_added_tokens,
|
||||
);
|
||||
|
||||
let paragraph_len = min(
|
||||
all_doc_tokens.len() - spans.len() * doc_stride,
|
||||
max_seq_length - truncated_query.len() - sequence_pair_added_tokens);
|
||||
max_seq_length - truncated_query.len() - sequence_pair_added_tokens,
|
||||
);
|
||||
|
||||
let mut token_to_orig_map = HashMap::new();
|
||||
for i in 0..paragraph_len {
|
||||
let index = truncated_query.len() + sequence_added_tokens + i;
|
||||
token_to_orig_map.insert(index as i64, tok_to_orig_index[spans.len() * doc_stride + i] as i64);
|
||||
token_to_orig_map.insert(
|
||||
index as i64,
|
||||
tok_to_orig_index[spans.len() * doc_stride + i] as i64,
|
||||
);
|
||||
}
|
||||
|
||||
let p_mask = self.get_mask(&encoded_span);
|
||||
|
||||
let qa_feature = QaFeature { input_ids: encoded_span.token_ids, attention_mask, token_to_orig_map, p_mask, example_index };
|
||||
let qa_feature = QaFeature {
|
||||
input_ids: encoded_span.token_ids,
|
||||
attention_mask,
|
||||
token_to_orig_map,
|
||||
p_mask,
|
||||
example_index,
|
||||
};
|
||||
|
||||
spans.push(qa_feature);
|
||||
if encoded_span.num_truncated_tokens == 0 {
|
||||
@ -447,59 +539,81 @@ impl QuestionAnsweringModel {
|
||||
}
|
||||
|
||||
fn prepare_query(&self, query: &str, max_query_length: usize) -> Vec<i64> {
|
||||
let truncated_query = self.tokenizer.convert_tokens_to_ids(&self.tokenizer.tokenize(&query));
|
||||
let num_query_tokens_to_remove = if truncated_query.len() > max_query_length as usize { truncated_query.len() - max_query_length } else { 0 };
|
||||
let (truncated_query, _, _, _, _, _, _, _, _, _) = truncate_sequences(truncated_query,
|
||||
let truncated_query = self
|
||||
.tokenizer
|
||||
.convert_tokens_to_ids(&self.tokenizer.tokenize(&query));
|
||||
let num_query_tokens_to_remove = if truncated_query.len() > max_query_length as usize {
|
||||
truncated_query.len() - max_query_length
|
||||
} else {
|
||||
0
|
||||
};
|
||||
let (truncated_query, _, _, _, _, _, _, _, _, _) = truncate_sequences(
|
||||
truncated_query,
|
||||
None,
|
||||
vec!(),
|
||||
vec![],
|
||||
None,
|
||||
vec!(),
|
||||
vec![],
|
||||
None,
|
||||
vec!(),
|
||||
vec![],
|
||||
None,
|
||||
num_query_tokens_to_remove,
|
||||
&TruncationStrategy::OnlyFirst,
|
||||
0).unwrap();
|
||||
0,
|
||||
)
|
||||
.unwrap();
|
||||
truncated_query
|
||||
}
|
||||
|
||||
fn encode_qa_pair(&self,
|
||||
fn encode_qa_pair(
|
||||
&self,
|
||||
truncated_query: &Vec<i64>,
|
||||
spans_token_ids: &Vec<i64>,
|
||||
max_seq_length: usize,
|
||||
doc_stride: usize,
|
||||
sequence_pair_added_tokens: usize) -> (TokenizedInput, Vec<i64>) {
|
||||
sequence_pair_added_tokens: usize,
|
||||
) -> (TokenizedInput, Vec<i64>) {
|
||||
let len_1 = truncated_query.len();
|
||||
let len_2 = spans_token_ids.len();
|
||||
let total_len = len_1 + len_2 + sequence_pair_added_tokens;
|
||||
let num_truncated_tokens = if total_len > max_seq_length { total_len - max_seq_length } else { 0 };
|
||||
let num_truncated_tokens = if total_len > max_seq_length {
|
||||
total_len - max_seq_length
|
||||
} else {
|
||||
0
|
||||
};
|
||||
|
||||
let (truncated_query, truncated_context, _, _, _, _, _, _, overflowing_tokens, _)
|
||||
= truncate_sequences(truncated_query.clone(),
|
||||
let (truncated_query, truncated_context, _, _, _, _, _, _, overflowing_tokens, _) =
|
||||
truncate_sequences(
|
||||
truncated_query.clone(),
|
||||
Some(spans_token_ids.clone()),
|
||||
vec!(),
|
||||
vec![],
|
||||
None,
|
||||
vec!(),
|
||||
vec![],
|
||||
None,
|
||||
vec!(),
|
||||
vec![],
|
||||
None,
|
||||
num_truncated_tokens,
|
||||
&TruncationStrategy::OnlySecond,
|
||||
max_seq_length - doc_stride - len_1 - sequence_pair_added_tokens).unwrap();
|
||||
max_seq_length - doc_stride - len_1 - sequence_pair_added_tokens,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let (mut token_ids,
|
||||
let (
|
||||
mut token_ids,
|
||||
mut segment_ids,
|
||||
special_tokens_mask,
|
||||
mut token_offsets,
|
||||
mut reference_offsets,
|
||||
mut mask) = self.tokenizer.build_input_with_special_tokens(truncated_query,
|
||||
mut mask,
|
||||
) = self.tokenizer.build_input_with_special_tokens(
|
||||
truncated_query,
|
||||
truncated_context,
|
||||
vec!(),
|
||||
vec![],
|
||||
None,
|
||||
vec!(),
|
||||
vec![],
|
||||
None,
|
||||
vec!(),
|
||||
None);
|
||||
vec![],
|
||||
None,
|
||||
);
|
||||
let mut attention_mask = vec![1; token_ids.len()];
|
||||
if token_ids.len() < max_seq_length {
|
||||
token_ids.append(&mut vec![self.pad_idx; max_seq_length - token_ids.len()]);
|
||||
@ -509,18 +623,32 @@ impl QuestionAnsweringModel {
|
||||
reference_offsets.append(&mut vec![vec!(); max_seq_length - token_offsets.len()]);
|
||||
mask.append(&mut vec![Mask::Special; max_seq_length - mask.len()]);
|
||||
}
|
||||
(TokenizedInput { token_ids, segment_ids, special_tokens_mask, overflowing_tokens, num_truncated_tokens, token_offsets, reference_offsets, mask }, attention_mask)
|
||||
(
|
||||
TokenizedInput {
|
||||
token_ids,
|
||||
segment_ids,
|
||||
special_tokens_mask,
|
||||
overflowing_tokens,
|
||||
num_truncated_tokens,
|
||||
token_offsets,
|
||||
reference_offsets,
|
||||
mask,
|
||||
},
|
||||
attention_mask,
|
||||
)
|
||||
}
|
||||
|
||||
fn get_mask(&self, encoded_span: &TokenizedInput) -> Vec<i8> {
|
||||
let sep_indices: Vec<usize> = encoded_span.token_ids
|
||||
let sep_indices: Vec<usize> = encoded_span
|
||||
.token_ids
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter(|(_, &value)| value == self.sep_idx)
|
||||
.map(|(position, _)| position)
|
||||
.collect();
|
||||
|
||||
let mut p_mask: Vec<i8> = encoded_span.segment_ids
|
||||
let mut p_mask: Vec<i8> = encoded_span
|
||||
.segment_ids
|
||||
.iter()
|
||||
.map(|v| min(v, &1i8))
|
||||
.map(|&v| 1i8 - v)
|
||||
@ -534,10 +662,13 @@ impl QuestionAnsweringModel {
|
||||
|
||||
pub fn squad_processor(file_path: PathBuf) -> Vec<QaInput> {
|
||||
let file = fs::File::open(file_path).expect("unable to open file");
|
||||
let json: serde_json::Value = serde_json::from_reader(file).expect("JSON not properly formatted");
|
||||
let json: serde_json::Value =
|
||||
serde_json::from_reader(file).expect("JSON not properly formatted");
|
||||
let data = json
|
||||
.get("data").expect("SQuAD file does not contain data field")
|
||||
.as_array().expect("Data array not properly formatted");
|
||||
.get("data")
|
||||
.expect("SQuAD file does not contain data field")
|
||||
.as_array()
|
||||
.expect("Data array not properly formatted");
|
||||
|
||||
let mut qa_inputs: Vec<QaInput> = Vec::with_capacity(data.len());
|
||||
for qa_input in data.iter() {
|
||||
@ -548,8 +679,17 @@ pub fn squad_processor(file_path: PathBuf) -> Vec<QaInput> {
|
||||
let context = paragraph.get("context").unwrap().as_str().unwrap();
|
||||
let qas = paragraph.get("qas").unwrap().as_array().unwrap();
|
||||
for qa in qas.iter() {
|
||||
let question = qa.as_object().unwrap().get("question").unwrap().as_str().unwrap();
|
||||
qa_inputs.push(QaInput { question: question.to_owned(), context: context.to_owned() });
|
||||
let question = qa
|
||||
.as_object()
|
||||
.unwrap()
|
||||
.get("question")
|
||||
.unwrap()
|
||||
.as_str()
|
||||
.unwrap();
|
||||
qa_inputs.push(QaInput {
|
||||
question: question.to_owned(),
|
||||
context: context.to_owned(),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -38,18 +38,29 @@
|
||||
//! # use rust_bert::pipelines::sentiment::SentimentPolarity::{Positive, Negative};
|
||||
//! # let output =
|
||||
//! [
|
||||
//! Sentiment { polarity: Positive, score: 0.998 },
|
||||
//! Sentiment { polarity: Negative, score: 0.992 },
|
||||
//! Sentiment { polarity: Positive, score: 0.999 }
|
||||
//! Sentiment {
|
||||
//! polarity: Positive,
|
||||
//! score: 0.998,
|
||||
//! },
|
||||
//! Sentiment {
|
||||
//! polarity: Negative,
|
||||
//! score: 0.992,
|
||||
//! },
|
||||
//! Sentiment {
|
||||
//! polarity: Positive,
|
||||
//! score: 0.999,
|
||||
//! },
|
||||
//! ]
|
||||
//! # ;
|
||||
//! ```
|
||||
|
||||
use std::path::PathBuf;
|
||||
use std::fs;
|
||||
use crate::pipelines::sequence_classification::{
|
||||
SequenceClassificationConfig, SequenceClassificationModel,
|
||||
};
|
||||
use serde::Deserialize;
|
||||
use std::error::Error;
|
||||
use crate::pipelines::sequence_classification::{SequenceClassificationConfig, SequenceClassificationModel};
|
||||
use std::fs;
|
||||
use std::path::PathBuf;
|
||||
|
||||
#[derive(Debug, PartialEq)]
|
||||
/// Enum with the possible sentiment polarities. Note that the pre-trained SST2 model does not include neutral sentiment.
|
||||
@ -71,7 +82,7 @@ type SentimentConfig = SequenceClassificationConfig;
|
||||
|
||||
/// # SentimentClassifier to perform sentiment analysis
|
||||
pub struct SentimentModel {
|
||||
sequence_classification_model: SequenceClassificationModel
|
||||
sequence_classification_model: SequenceClassificationModel,
|
||||
}
|
||||
|
||||
impl SentimentModel {
|
||||
@ -91,10 +102,11 @@ impl SentimentModel {
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
///
|
||||
pub fn new(sentiment_config: SentimentConfig) -> failure::Fallible<SentimentModel> {
|
||||
let sequence_classification_model = SequenceClassificationModel::new(sentiment_config)?;
|
||||
Ok(SentimentModel { sequence_classification_model })
|
||||
Ok(SentimentModel {
|
||||
sequence_classification_model,
|
||||
})
|
||||
}
|
||||
|
||||
/// Extract sentiment from an array of text inputs
|
||||
@ -124,14 +136,20 @@ impl SentimentModel {
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
///
|
||||
pub fn predict(&self, input: &[&str]) -> Vec<Sentiment> {
|
||||
let labels = self.sequence_classification_model.predict(input);
|
||||
let mut sentiments = Vec::with_capacity(labels.len());
|
||||
for label in labels {
|
||||
let polarity = if label.id == 1 { SentimentPolarity::Positive } else { SentimentPolarity::Negative };
|
||||
sentiments.push(Sentiment { polarity, score: label.score })
|
||||
let polarity = if label.id == 1 {
|
||||
SentimentPolarity::Positive
|
||||
} else {
|
||||
SentimentPolarity::Negative
|
||||
};
|
||||
sentiments.push(Sentiment {
|
||||
polarity,
|
||||
score: label.score,
|
||||
})
|
||||
}
|
||||
sentiments
|
||||
}
|
||||
}
|
||||
|
@ -56,17 +56,21 @@
|
||||
//! # ;
|
||||
//! ```
|
||||
//!
|
||||
use tch::nn::VarStore;
|
||||
use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::{TokenizedInput, TruncationStrategy};
|
||||
use std::collections::HashMap;
|
||||
use tch::{Tensor, no_grad, Device, Kind};
|
||||
use crate::bert::BertForSequenceClassification;
|
||||
use crate::common::resources::{download_resource, RemoteResource, Resource};
|
||||
use crate::distilbert::{
|
||||
DistilBertConfigResources, DistilBertModelClassifier, DistilBertModelResources,
|
||||
DistilBertVocabResources,
|
||||
};
|
||||
use crate::pipelines::common::{ConfigOption, ModelType, TokenizerOption};
|
||||
use crate::roberta::RobertaForSequenceClassification;
|
||||
use crate::distilbert::{DistilBertModelResources, DistilBertConfigResources, DistilBertVocabResources, DistilBertModelClassifier};
|
||||
use crate::common::resources::{Resource, RemoteResource, download_resource};
|
||||
use serde::{Serialize, Deserialize};
|
||||
use crate::pipelines::common::{ModelType, ConfigOption, TokenizerOption};
|
||||
|
||||
use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::{
|
||||
TokenizedInput, TruncationStrategy,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use tch::nn::VarStore;
|
||||
use tch::{no_grad, Device, Kind, Tensor};
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
/// # Label generated by a `SequenceClassificationModel`
|
||||
@ -112,8 +116,14 @@ impl SequenceClassificationConfig {
|
||||
/// * vocab - The `Resource` pointing to the tokenizer's vocabulary to load (e.g. vocab.txt/vocab.json)
|
||||
/// * merges - An optional `Resource` (`Option<Resource>`) pointing to the tokenizer's merge file to load (e.g. merges.txt), needed only for Roberta.
|
||||
/// * lower_case - A `bool` indicating whether the tokeniser should lower case all input (in case of a lower-cased model)
|
||||
///
|
||||
pub fn new(model_type: ModelType, model_resource: Resource, config_resource: Resource, vocab_resource: Resource, merges_resource: Option<Resource>, lower_case: bool) -> SequenceClassificationConfig {
|
||||
pub fn new(
|
||||
model_type: ModelType,
|
||||
model_resource: Resource,
|
||||
config_resource: Resource,
|
||||
vocab_resource: Resource,
|
||||
merges_resource: Option<Resource>,
|
||||
lower_case: bool,
|
||||
) -> SequenceClassificationConfig {
|
||||
SequenceClassificationConfig {
|
||||
model_type,
|
||||
model_resource,
|
||||
@ -131,9 +141,15 @@ impl Default for SequenceClassificationConfig {
|
||||
fn default() -> SequenceClassificationConfig {
|
||||
SequenceClassificationConfig {
|
||||
model_type: ModelType::DistilBert,
|
||||
model_resource: Resource::Remote(RemoteResource::from_pretrained(DistilBertModelResources::DISTIL_BERT_SST2)),
|
||||
config_resource: Resource::Remote(RemoteResource::from_pretrained(DistilBertConfigResources::DISTIL_BERT_SST2)),
|
||||
vocab_resource: Resource::Remote(RemoteResource::from_pretrained(DistilBertVocabResources::DISTIL_BERT_SST2)),
|
||||
model_resource: Resource::Remote(RemoteResource::from_pretrained(
|
||||
DistilBertModelResources::DISTIL_BERT_SST2,
|
||||
)),
|
||||
config_resource: Resource::Remote(RemoteResource::from_pretrained(
|
||||
DistilBertConfigResources::DISTIL_BERT_SST2,
|
||||
)),
|
||||
vocab_resource: Resource::Remote(RemoteResource::from_pretrained(
|
||||
DistilBertVocabResources::DISTIL_BERT_SST2,
|
||||
)),
|
||||
merges_resource: None,
|
||||
lower_case: true,
|
||||
device: Device::cuda_if_available(),
|
||||
@ -160,31 +176,38 @@ impl SequenceClassificationOption {
|
||||
/// * `p` - `tch::nn::Path` path to the model file to load (e.g. model.ot)
|
||||
/// * `config` - A configuration (the model type of the configuration must be compatible with the value for
|
||||
/// `model_type`)
|
||||
///
|
||||
pub fn new(model_type: ModelType, p: &tch::nn::Path, config: &ConfigOption) -> Self {
|
||||
match model_type {
|
||||
ModelType::Bert => {
|
||||
if let ConfigOption::Bert(config) = config {
|
||||
SequenceClassificationOption::Bert(BertForSequenceClassification::new(p, config))
|
||||
SequenceClassificationOption::Bert(BertForSequenceClassification::new(
|
||||
p, config,
|
||||
))
|
||||
} else {
|
||||
panic!("You can only supply a BertConfig for Bert!");
|
||||
}
|
||||
}
|
||||
ModelType::DistilBert => {
|
||||
if let ConfigOption::DistilBert(config) = config {
|
||||
SequenceClassificationOption::DistilBert(DistilBertModelClassifier::new(p, config))
|
||||
SequenceClassificationOption::DistilBert(DistilBertModelClassifier::new(
|
||||
p, config,
|
||||
))
|
||||
} else {
|
||||
panic!("You can only supply a DistilBertConfig for DistilBert!");
|
||||
}
|
||||
}
|
||||
ModelType::Roberta => {
|
||||
if let ConfigOption::Bert(config) = config {
|
||||
SequenceClassificationOption::Roberta(RobertaForSequenceClassification::new(p, config))
|
||||
SequenceClassificationOption::Roberta(RobertaForSequenceClassification::new(
|
||||
p, config,
|
||||
))
|
||||
} else {
|
||||
panic!("You can only supply a BertConfig for Roberta!");
|
||||
}
|
||||
}
|
||||
ModelType::Electra => { panic!("SequenceClassification not implemented for Electra!"); }
|
||||
ModelType::Electra => {
|
||||
panic!("SequenceClassification not implemented for Electra!");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -193,27 +216,44 @@ impl SequenceClassificationOption {
|
||||
match *self {
|
||||
Self::Bert(_) => ModelType::Bert,
|
||||
Self::Roberta(_) => ModelType::Roberta,
|
||||
Self::DistilBert(_) => ModelType::DistilBert
|
||||
Self::DistilBert(_) => ModelType::DistilBert,
|
||||
}
|
||||
}
|
||||
|
||||
/// Interface method to forward_t() of the particular models.
///
/// Forwards the arguments to the wrapped model's own `forward_t` and returns its output tuple.
/// The DistilBert variant is only given `input_ids`, `mask`, `input_embeds` and `train`
/// (`token_type_ids` and `position_ids` are not passed on), and its result is unwrapped with
/// `expect`, so a failure there panics with "Error in distilbert forward_t".
|
||||
pub fn forward_t(&self,
|
||||
pub fn forward_t(
|
||||
&self,
|
||||
input_ids: Option<Tensor>,
|
||||
mask: Option<Tensor>,
|
||||
token_type_ids: Option<Tensor>,
|
||||
position_ids: Option<Tensor>,
|
||||
input_embeds: Option<Tensor>,
|
||||
train: bool) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
|
||||
train: bool,
|
||||
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
|
||||
// NOTE(review): the three single-line arms and the multi-line arms below appear to be the
// pre- and post-rustfmt renderings of the same three match arms.
match *self {
|
||||
Self::Bert(ref model) => model.forward_t(input_ids, mask, token_type_ids, position_ids, input_embeds, train),
|
||||
Self::DistilBert(ref model) => model.forward_t(input_ids, mask, input_embeds, train).expect("Error in distilbert forward_t"),
|
||||
Self::Roberta(ref model) => model.forward_t(input_ids, mask, token_type_ids, position_ids, input_embeds, train),
|
||||
Self::Bert(ref model) => model.forward_t(
|
||||
input_ids,
|
||||
mask,
|
||||
token_type_ids,
|
||||
position_ids,
|
||||
input_embeds,
|
||||
train,
|
||||
),
|
||||
Self::DistilBert(ref model) => model
|
||||
.forward_t(input_ids, mask, input_embeds, train)
|
||||
.expect("Error in distilbert forward_t"),
|
||||
Self::Roberta(ref model) => model.forward_t(
|
||||
input_ids,
|
||||
mask,
|
||||
token_type_ids,
|
||||
position_ids,
|
||||
input_embeds,
|
||||
train,
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// # SequenceClassificationModel for Classification (e.g. Sentiment Analysis)
|
||||
pub struct SequenceClassificationModel {
|
||||
tokenizer: TokenizerOption,
|
||||
@ -239,8 +279,9 @@ impl SequenceClassificationModel {
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
///
|
||||
pub fn new(config: SequenceClassificationConfig) -> failure::Fallible<SequenceClassificationModel> {
|
||||
pub fn new(
|
||||
config: SequenceClassificationConfig,
|
||||
) -> failure::Fallible<SequenceClassificationModel> {
|
||||
let config_path = download_resource(&config.config_resource)?;
|
||||
let vocab_path = download_resource(&config.vocab_resource)?;
|
||||
let weights_path = download_resource(&config.model_resource)?;
|
||||
@ -251,31 +292,44 @@ impl SequenceClassificationModel {
|
||||
};
|
||||
let device = config.device;
|
||||
|
||||
let tokenizer = TokenizerOption::from_file(config.model_type, vocab_path.to_str().unwrap(), merges_path.map(|path| path.to_str().unwrap()), config.lower_case);
|
||||
let tokenizer = TokenizerOption::from_file(
|
||||
config.model_type,
|
||||
vocab_path.to_str().unwrap(),
|
||||
merges_path.map(|path| path.to_str().unwrap()),
|
||||
config.lower_case,
|
||||
);
|
||||
let mut var_store = VarStore::new(device);
|
||||
let model_config = ConfigOption::from_file(config.model_type, config_path);
|
||||
let sequence_classifier = SequenceClassificationOption::new(config.model_type, &var_store.root(), &model_config);
|
||||
let sequence_classifier =
|
||||
SequenceClassificationOption::new(config.model_type, &var_store.root(), &model_config);
|
||||
let label_mapping = model_config.get_label_mapping();
|
||||
var_store.load(weights_path)?;
|
||||
Ok(SequenceClassificationModel { tokenizer, sequence_classifier, label_mapping, var_store })
|
||||
Ok(SequenceClassificationModel {
|
||||
tokenizer,
|
||||
sequence_classifier,
|
||||
label_mapping,
|
||||
var_store,
|
||||
})
|
||||
}
|
||||
|
||||
/// Tokenizes `input` and packs the resulting token ids into one tensor.
///
/// Each id sequence is extended with `0` up to the length of the longest encoded sequence, the
/// sequences are turned into tensors and stacked along dimension 0, and the result is moved to
/// the device of `self.var_store`. The `unwrap` on `max()` panics if no sequence was encoded.
fn prepare_for_model(&self, input: Vec<&str>) -> Tensor {
|
||||
// NOTE(review): 128 is presumably the maximum encoded length and the trailing 0 the stride —
// confirm against the tokenizer's `encode_list` signature.
let tokenized_input: Vec<TokenizedInput> = self.tokenizer.encode_list(input.to_vec(),
|
||||
128,
|
||||
&TruncationStrategy::LongestFirst,
|
||||
0);
|
||||
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
|
||||
let tokenized_input_tensors: Vec<tch::Tensor> = tokenized_input.
|
||||
iter().
|
||||
map(|input| input.token_ids.clone()).
|
||||
map(|mut input| {
|
||||
let tokenized_input: Vec<TokenizedInput> =
|
||||
self.tokenizer
|
||||
.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
||||
let max_len = tokenized_input
|
||||
.iter()
|
||||
.map(|input| input.token_ids.len())
|
||||
.max()
|
||||
.unwrap();
|
||||
let tokenized_input_tensors: Vec<tch::Tensor> = tokenized_input
|
||||
.iter()
|
||||
.map(|input| input.token_ids.clone())
|
||||
.map(|mut input| {
|
||||
// Pad with 0 so every row has `max_len` entries (0 is presumably the padding id — TODO confirm).
input.extend(vec![0; max_len - input.len()]);
|
||||
input
|
||||
}).
|
||||
map(|input|
|
||||
Tensor::of_slice(&(input))).
|
||||
collect::<Vec<_>>();
|
||||
})
|
||||
.map(|input| Tensor::of_slice(&(input)))
|
||||
.collect::<Vec<_>>();
|
||||
Tensor::stack(tokenized_input_tensors.as_slice(), 0).to(self.var_store.device())
|
||||
}
|
||||
|
||||
@ -308,23 +362,30 @@ impl SequenceClassificationModel {
|
||||
pub fn predict(&self, input: &[&str]) -> Vec<Label> {
|
||||
let input_tensor = self.prepare_for_model(input.to_vec());
|
||||
let output = no_grad(|| {
|
||||
let (output, _, _) = self.sequence_classifier
|
||||
.forward_t(Some(input_tensor.copy()),
|
||||
let (output, _, _) = self.sequence_classifier.forward_t(
|
||||
Some(input_tensor.copy()),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
false);
|
||||
false,
|
||||
);
|
||||
output.softmax(-1, Kind::Float).detach().to(Device::Cpu)
|
||||
});
|
||||
let label_indices = output.as_ref().argmax(-1, true).squeeze1(1);
|
||||
let scores = output.gather(1, &label_indices.unsqueeze(-1), false).squeeze1(1);
|
||||
let scores = output
|
||||
.gather(1, &label_indices.unsqueeze(-1), false)
|
||||
.squeeze1(1);
|
||||
let label_indices = label_indices.iter::<i64>().unwrap().collect::<Vec<i64>>();
|
||||
let scores = scores.iter::<f64>().unwrap().collect::<Vec<f64>>();
|
||||
|
||||
let mut labels: Vec<Label> = vec!();
|
||||
let mut labels: Vec<Label> = vec![];
|
||||
for sentence_idx in 0..label_indices.len() {
|
||||
let label_string = self.label_mapping.get(&label_indices[sentence_idx]).unwrap().clone();
|
||||
let label_string = self
|
||||
.label_mapping
|
||||
.get(&label_indices[sentence_idx])
|
||||
.unwrap()
|
||||
.clone();
|
||||
let label = Label {
|
||||
text: label_string,
|
||||
score: scores[sentence_idx],
|
||||
@ -366,28 +427,31 @@ impl SequenceClassificationModel {
|
||||
pub fn predict_multilabel(&self, input: &[&str], threshold: f64) -> Vec<Vec<Label>> {
|
||||
let input_tensor = self.prepare_for_model(input.to_vec());
|
||||
let output = no_grad(|| {
|
||||
let (output, _, _) = self.sequence_classifier
|
||||
.forward_t(Some(input_tensor.copy()),
|
||||
let (output, _, _) = self.sequence_classifier.forward_t(
|
||||
Some(input_tensor.copy()),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
false);
|
||||
false,
|
||||
);
|
||||
output.sigmoid().detach().to(Device::Cpu)
|
||||
});
|
||||
let label_indices = output.as_ref().ge(threshold).nonzero();
|
||||
|
||||
let mut labels: Vec<Vec<Label>> = vec!();
|
||||
let mut sequence_labels: Vec<Label> = vec!();
|
||||
let mut labels: Vec<Vec<Label>> = vec![];
|
||||
let mut sequence_labels: Vec<Label> = vec![];
|
||||
|
||||
for sentence_idx in 0..label_indices.size()[0] {
|
||||
|
||||
let label_index_tensor = label_indices.get(sentence_idx);
|
||||
let sentence_label = label_index_tensor.iter::<i64>().unwrap().collect::<Vec<i64>>();
|
||||
let sentence_label = label_index_tensor
|
||||
.iter::<i64>()
|
||||
.unwrap()
|
||||
.collect::<Vec<i64>>();
|
||||
let (sentence, id) = (sentence_label[0], sentence_label[1]);
|
||||
if sentence as usize > labels.len() {
|
||||
labels.push(sequence_labels);
|
||||
sequence_labels = vec!();
|
||||
sequence_labels = vec![];
|
||||
}
|
||||
let score = output.double_value(sentence_label.as_slice());
|
||||
let label_string = self.label_mapping.get(&id).unwrap().to_owned();
|
||||
@ -398,7 +462,6 @@ impl SequenceClassificationModel {
|
||||
sentence: sentence as usize,
|
||||
};
|
||||
sequence_labels.push(label);
|
||||
|
||||
}
|
||||
if sequence_labels.len() > 0 {
|
||||
labels.push(sequence_labels);
|
||||
|
@ -11,7 +11,6 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
|
||||
//! # Summarization pipeline
|
||||
//! Abstractive summarization of texts based on the BART encoder-decoder architecture
|
||||
//! Includes techniques such as beam search, top-k and nucleus sampling, temperature setting and repetition penalty.
|
||||
@ -63,10 +62,12 @@
|
||||
//! # ;
|
||||
//! ```
|
||||
|
||||
use crate::bart::{
|
||||
BartConfigResources, BartMergesResources, BartModelResources, BartVocabResources,
|
||||
};
|
||||
use crate::common::resources::{RemoteResource, Resource};
|
||||
use crate::pipelines::generation::{BartGenerator, GenerateConfig, LanguageGenerator};
|
||||
use tch::Device;
|
||||
use crate::common::resources::{Resource, RemoteResource};
|
||||
use crate::bart::{BartModelResources, BartConfigResources, BartVocabResources, BartMergesResources};
|
||||
|
||||
/// # Configuration for text summarization
|
||||
/// Contains information regarding the model to load, mirrors the GenerationConfig, with a
|
||||
@ -111,10 +112,18 @@ pub struct SummarizationConfig {
|
||||
impl Default for SummarizationConfig {
|
||||
fn default() -> SummarizationConfig {
|
||||
SummarizationConfig {
|
||||
model_resource: Resource::Remote(RemoteResource::from_pretrained(BartModelResources::BART_CNN)),
|
||||
config_resource: Resource::Remote(RemoteResource::from_pretrained(BartConfigResources::BART_CNN)),
|
||||
vocab_resource: Resource::Remote(RemoteResource::from_pretrained(BartVocabResources::BART_CNN)),
|
||||
merges_resource: Resource::Remote(RemoteResource::from_pretrained(BartMergesResources::BART_CNN)),
|
||||
model_resource: Resource::Remote(RemoteResource::from_pretrained(
|
||||
BartModelResources::BART_CNN,
|
||||
)),
|
||||
config_resource: Resource::Remote(RemoteResource::from_pretrained(
|
||||
BartConfigResources::BART_CNN,
|
||||
)),
|
||||
vocab_resource: Resource::Remote(RemoteResource::from_pretrained(
|
||||
BartVocabResources::BART_CNN,
|
||||
)),
|
||||
merges_resource: Resource::Remote(RemoteResource::from_pretrained(
|
||||
BartMergesResources::BART_CNN,
|
||||
)),
|
||||
min_length: 56,
|
||||
max_length: 142,
|
||||
do_sample: false,
|
||||
@ -134,7 +143,7 @@ impl Default for SummarizationConfig {
|
||||
|
||||
/// # SummarizationModel to perform summarization
|
||||
pub struct SummarizationModel {
|
||||
// Generator that `summarize` delegates to. NOTE(review): the two field lines below differ
// only by the trailing comma and appear to be the pre- and post-rustfmt form of one field.
model: BartGenerator
|
||||
model: BartGenerator,
|
||||
}
|
||||
|
||||
impl SummarizationModel {
|
||||
@ -154,9 +163,7 @@ impl SummarizationModel {
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
///
|
||||
pub fn new(summarization_config: SummarizationConfig)
|
||||
-> failure::Fallible<SummarizationModel> {
|
||||
pub fn new(summarization_config: SummarizationConfig) -> failure::Fallible<SummarizationModel> {
|
||||
let generate_config = GenerateConfig {
|
||||
model_resource: summarization_config.model_resource,
|
||||
config_resource: summarization_config.config_resource,
|
||||
@ -226,7 +233,6 @@ impl SummarizationModel {
|
||||
/// # }
|
||||
/// ```
|
||||
/// (New sample credits: [WikiNews](https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b))
|
||||
///
|
||||
pub fn summarize(&self, texts: &[&str]) -> Vec<String> {
|
||||
// Copies the slice of inputs into a `Vec` and hands it to the wrapped generator; the second
// argument to `generate` is left as `None`.
self.model.generate(Some(texts.to_vec()), None)
|
||||
}
|
||||
|
@ -47,34 +47,87 @@
|
||||
//! ```no_run
|
||||
//! # use rust_bert::pipelines::token_classification::Token;
|
||||
//! use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::Mask::Special;
|
||||
//! use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::{Offset, Mask};
|
||||
//! use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::{Mask, Offset};
|
||||
//! # let output =
|
||||
//! [
|
||||
//! Token { text: String::from("[CLS]"), score: 0.9995001554489136, label: String::from("O"), label_index: 0, sentence: 0, index: 0, word_index: 0, offset: None, mask: Special },
|
||||
//! Token { text: String::from("My"), score: 0.9980450868606567, label: String::from("O"), label_index: 0, sentence: 0, index: 1, word_index: 1, offset: Some(Offset { begin: 0, end: 2 }), mask: Mask::None },
|
||||
//! Token { text: String::from("name"), score: 0.9995062351226807, label: String::from("O"), label_index: 0, sentence: 0, index: 2, word_index: 2, offset: Some(Offset { begin: 3, end: 7 }), mask: Mask::None },
|
||||
//! Token { text: String::from("is"), score: 0.9997343420982361, label: String::from("O"), label_index: 0, sentence: 0, index: 3, word_index: 3, offset: Some(Offset { begin: 8, end: 10 }), mask: Mask::None },
|
||||
//! Token { text: String::from("Amélie"), score: 0.9913727683112525, label: String::from("I-PER"), label_index: 4, sentence: 0, index: 4, word_index: 4, offset: Some(Offset { begin: 11, end: 17 }), mask: Mask::None }
|
||||
//! // ...
|
||||
//! Token {
|
||||
//! text: String::from("[CLS]"),
|
||||
//! score: 0.9995001554489136,
|
||||
//! label: String::from("O"),
|
||||
//! label_index: 0,
|
||||
//! sentence: 0,
|
||||
//! index: 0,
|
||||
//! word_index: 0,
|
||||
//! offset: None,
|
||||
//! mask: Special,
|
||||
//! },
|
||||
//! Token {
|
||||
//! text: String::from("My"),
|
||||
//! score: 0.9980450868606567,
|
||||
//! label: String::from("O"),
|
||||
//! label_index: 0,
|
||||
//! sentence: 0,
|
||||
//! index: 1,
|
||||
//! word_index: 1,
|
||||
//! offset: Some(Offset { begin: 0, end: 2 }),
|
||||
//! mask: Mask::None,
|
||||
//! },
|
||||
//! Token {
|
||||
//! text: String::from("name"),
|
||||
//! score: 0.9995062351226807,
|
||||
//! label: String::from("O"),
|
||||
//! label_index: 0,
|
||||
//! sentence: 0,
|
||||
//! index: 2,
|
||||
//! word_index: 2,
|
||||
//! offset: Some(Offset { begin: 3, end: 7 }),
|
||||
//! mask: Mask::None,
|
||||
//! },
|
||||
//! Token {
|
||||
//! text: String::from("is"),
|
||||
//! score: 0.9997343420982361,
|
||||
//! label: String::from("O"),
|
||||
//! label_index: 0,
|
||||
//! sentence: 0,
|
||||
//! index: 3,
|
||||
//! word_index: 3,
|
||||
//! offset: Some(Offset { begin: 8, end: 10 }),
|
||||
//! mask: Mask::None,
|
||||
//! },
|
||||
//! Token {
|
||||
//! text: String::from("Amélie"),
|
||||
//! score: 0.9913727683112525,
|
||||
//! label: String::from("I-PER"),
|
||||
//! label_index: 4,
|
||||
//! sentence: 0,
|
||||
//! index: 4,
|
||||
//! word_index: 4,
|
||||
//! offset: Some(Offset { begin: 11, end: 17 }),
|
||||
//! mask: Mask::None,
|
||||
//! }, // ...
|
||||
//! ]
|
||||
//! # ;
|
||||
//! ```
|
||||
|
||||
use tch::nn::VarStore;
|
||||
use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::{Tokenizer, TokenizedInput, TruncationStrategy, Mask, Offset, ConsolidatableTokens, ConsolidatedTokenIterator, TokenTrait};
|
||||
use std::collections::HashMap;
|
||||
use tch::{Tensor, no_grad, Device};
|
||||
use tch::kind::Kind::Float;
|
||||
use crate::bert::{BertForTokenClassification, BertModelResources, BertConfigResources, BertVocabResources};
|
||||
use crate::roberta::RobertaForTokenClassification;
|
||||
use crate::bert::{
|
||||
BertConfigResources, BertForTokenClassification, BertModelResources, BertVocabResources,
|
||||
};
|
||||
use crate::common::resources::{download_resource, RemoteResource, Resource};
|
||||
use crate::distilbert::DistilBertForTokenClassification;
|
||||
use crate::common::resources::{Resource, RemoteResource, download_resource};
|
||||
use crate::pipelines::common::{ModelType, ConfigOption, TokenizerOption};
|
||||
use crate::electra::ElectraForTokenClassification;
|
||||
use crate::pipelines::common::{ConfigOption, ModelType, TokenizerOption};
|
||||
use crate::roberta::RobertaForTokenClassification;
|
||||
use itertools::Itertools;
|
||||
use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::{
|
||||
ConsolidatableTokens, ConsolidatedTokenIterator, Mask, Offset, TokenTrait, TokenizedInput,
|
||||
Tokenizer, TruncationStrategy,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::cmp::min;
|
||||
use serde::{Serialize, Deserialize};
|
||||
|
||||
use std::collections::HashMap;
|
||||
use tch::kind::Kind::Float;
|
||||
use tch::nn::VarStore;
|
||||
use tch::{no_grad, Device, Tensor};
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
/// # Token generated by a `TokenClassificationModel`
|
||||
@ -140,7 +193,6 @@ pub enum LabelAggregationOption {
|
||||
Custom(Box<dyn Fn(&[Token]) -> (i64, String)>),
|
||||
}
|
||||
|
||||
|
||||
/// # Configuration for TokenClassificationModel
|
||||
/// Contains information regarding the model to load and device to place the model on.
|
||||
pub struct TokenClassificationConfig {
|
||||
@ -173,14 +225,15 @@ impl TokenClassificationConfig {
|
||||
/// * vocab - The `Resource` pointing to the tokenizer's vocabulary to load (e.g. vocab.txt/vocab.json)
|
||||
/// * merges - An optional `Resource` (`Option<Resource>`) pointing to the tokenizer's merges file to load (e.g. merges.txt), needed only for Roberta.
|
||||
/// * lower_case - A `bool` indicating whether the tokenizer should lower case all input (in case of a lower-cased model)
|
||||
///
|
||||
pub fn new(model_type: ModelType,
|
||||
pub fn new(
|
||||
model_type: ModelType,
|
||||
model_resource: Resource,
|
||||
config_resource: Resource,
|
||||
vocab_resource: Resource,
|
||||
merges_resource: Option<Resource>,
|
||||
lower_case: bool,
|
||||
label_aggregation_function: LabelAggregationOption) -> TokenClassificationConfig {
|
||||
label_aggregation_function: LabelAggregationOption,
|
||||
) -> TokenClassificationConfig {
|
||||
TokenClassificationConfig {
|
||||
model_type,
|
||||
model_resource,
|
||||
@ -199,9 +252,15 @@ impl Default for TokenClassificationConfig {
|
||||
fn default() -> TokenClassificationConfig {
|
||||
TokenClassificationConfig {
|
||||
model_type: ModelType::Bert,
|
||||
model_resource: Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT_NER)),
|
||||
config_resource: Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT_NER)),
|
||||
vocab_resource: Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT_NER)),
|
||||
model_resource: Resource::Remote(RemoteResource::from_pretrained(
|
||||
BertModelResources::BERT_NER,
|
||||
)),
|
||||
config_resource: Resource::Remote(RemoteResource::from_pretrained(
|
||||
BertConfigResources::BERT_NER,
|
||||
)),
|
||||
vocab_resource: Resource::Remote(RemoteResource::from_pretrained(
|
||||
BertVocabResources::BERT_NER,
|
||||
)),
|
||||
merges_resource: None,
|
||||
lower_case: false,
|
||||
device: Device::cuda_if_available(),
|
||||
@ -231,7 +290,6 @@ impl TokenClassificationOption {
|
||||
/// * `p` - `tch::nn::Path` path to the model file to load (e.g. model.ot)
|
||||
/// * `config` - A configuration (the model type of the configuration must be compatible with the value for
|
||||
/// `model_type`)
|
||||
///
|
||||
pub fn new(model_type: ModelType, p: &tch::nn::Path, config: &ConfigOption) -> Self {
|
||||
match model_type {
|
||||
ModelType::Bert => {
|
||||
@ -243,21 +301,27 @@ impl TokenClassificationOption {
|
||||
}
|
||||
ModelType::DistilBert => {
|
||||
if let ConfigOption::DistilBert(config) = config {
|
||||
TokenClassificationOption::DistilBert(DistilBertForTokenClassification::new(p, config))
|
||||
TokenClassificationOption::DistilBert(DistilBertForTokenClassification::new(
|
||||
p, config,
|
||||
))
|
||||
} else {
|
||||
panic!("You can only supply a DistilBertConfig for DistilBert!");
|
||||
}
|
||||
}
|
||||
ModelType::Roberta => {
|
||||
if let ConfigOption::Bert(config) = config {
|
||||
TokenClassificationOption::Roberta(RobertaForTokenClassification::new(p, config))
|
||||
TokenClassificationOption::Roberta(RobertaForTokenClassification::new(
|
||||
p, config,
|
||||
))
|
||||
} else {
|
||||
panic!("You can only supply a BertConfig for Roberta!");
|
||||
}
|
||||
}
|
||||
ModelType::Electra => {
|
||||
if let ConfigOption::Electra(config) = config {
|
||||
TokenClassificationOption::Electra(ElectraForTokenClassification::new(p, config))
|
||||
TokenClassificationOption::Electra(ElectraForTokenClassification::new(
|
||||
p, config,
|
||||
))
|
||||
} else {
|
||||
panic!("You can only supply a BertConfig for Roberta!");
|
||||
}
|
||||
@ -271,27 +335,51 @@ impl TokenClassificationOption {
|
||||
Self::Bert(_) => ModelType::Bert,
|
||||
Self::Roberta(_) => ModelType::Roberta,
|
||||
Self::DistilBert(_) => ModelType::DistilBert,
|
||||
Self::Electra(_) => ModelType::Electra
|
||||
Self::Electra(_) => ModelType::Electra,
|
||||
}
|
||||
}
|
||||
|
||||
/// Forwards the arguments to the wrapped model's own `forward_t` and returns its output tuple.
///
/// The DistilBert variant is only given `input_ids`, `mask`, `input_embeds` and `train`
/// (`token_type_ids` and `position_ids` are not passed on), and its result is unwrapped with
/// `expect`, so a failure there panics with "Error in distilbert forward_t".
fn forward_t(&self,
|
||||
fn forward_t(
|
||||
&self,
|
||||
input_ids: Option<Tensor>,
|
||||
mask: Option<Tensor>,
|
||||
token_type_ids: Option<Tensor>,
|
||||
position_ids: Option<Tensor>,
|
||||
input_embeds: Option<Tensor>,
|
||||
train: bool) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
|
||||
train: bool,
|
||||
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
|
||||
// NOTE(review): the four single-line arms and the multi-line arms below appear to be the
// pre- and post-rustfmt renderings of the same four match arms.
match *self {
|
||||
Self::Bert(ref model) => model.forward_t(input_ids, mask, token_type_ids, position_ids, input_embeds, train),
|
||||
Self::DistilBert(ref model) => model.forward_t(input_ids, mask, input_embeds, train).expect("Error in distilbert forward_t"),
|
||||
Self::Roberta(ref model) => model.forward_t(input_ids, mask, token_type_ids, position_ids, input_embeds, train),
|
||||
Self::Electra(ref model) => model.forward_t(input_ids, mask, token_type_ids, position_ids, input_embeds, train),
|
||||
Self::Bert(ref model) => model.forward_t(
|
||||
input_ids,
|
||||
mask,
|
||||
token_type_ids,
|
||||
position_ids,
|
||||
input_embeds,
|
||||
train,
|
||||
),
|
||||
Self::DistilBert(ref model) => model
|
||||
.forward_t(input_ids, mask, input_embeds, train)
|
||||
.expect("Error in distilbert forward_t"),
|
||||
Self::Roberta(ref model) => model.forward_t(
|
||||
input_ids,
|
||||
mask,
|
||||
token_type_ids,
|
||||
position_ids,
|
||||
input_embeds,
|
||||
train,
|
||||
),
|
||||
Self::Electra(ref model) => model.forward_t(
|
||||
input_ids,
|
||||
mask,
|
||||
token_type_ids,
|
||||
position_ids,
|
||||
input_embeds,
|
||||
train,
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// # TokenClassificationModel for Named Entity Recognition or Part-of-Speech tagging
|
||||
pub struct TokenClassificationModel {
|
||||
tokenizer: TokenizerOption,
|
||||
@ -318,7 +406,6 @@ impl TokenClassificationModel {
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
///
|
||||
pub fn new(config: TokenClassificationConfig) -> failure::Fallible<TokenClassificationModel> {
|
||||
let config_path = download_resource(&config.config_resource)?;
|
||||
let vocab_path = download_resource(&config.vocab_resource)?;
|
||||
@ -331,32 +418,49 @@ impl TokenClassificationModel {
|
||||
let device = config.device;
|
||||
let label_aggregation_function = config.label_aggregation_function;
|
||||
|
||||
let tokenizer = TokenizerOption::from_file(config.model_type, vocab_path.to_str().unwrap(), merges_path.map(|path| path.to_str().unwrap()), config.lower_case);
|
||||
let tokenizer = TokenizerOption::from_file(
|
||||
config.model_type,
|
||||
vocab_path.to_str().unwrap(),
|
||||
merges_path.map(|path| path.to_str().unwrap()),
|
||||
config.lower_case,
|
||||
);
|
||||
let mut var_store = VarStore::new(device);
|
||||
let model_config = ConfigOption::from_file(config.model_type, config_path);
|
||||
let token_sequence_classifier = TokenClassificationOption::new(config.model_type, &var_store.root(), &model_config);
|
||||
let token_sequence_classifier =
|
||||
TokenClassificationOption::new(config.model_type, &var_store.root(), &model_config);
|
||||
let label_mapping = model_config.get_label_mapping();
|
||||
var_store.load(weights_path)?;
|
||||
Ok(TokenClassificationModel { tokenizer, token_sequence_classifier, label_mapping, var_store, label_aggregation_function })
|
||||
Ok(TokenClassificationModel {
|
||||
tokenizer,
|
||||
token_sequence_classifier,
|
||||
label_mapping,
|
||||
var_store,
|
||||
label_aggregation_function,
|
||||
})
|
||||
}
|
||||
|
||||
/// Tokenizes `input` and returns both the tokenizer output and the packed token-id tensor.
///
/// Each id sequence is extended with `0` up to the length of the longest encoded sequence, the
/// sequences are turned into tensors and stacked along dimension 0, and the stacked tensor is
/// moved to the device of `self.var_store`. The `unwrap` on `max()` panics if no sequence was
/// encoded.
fn prepare_for_model(&self, input: Vec<&str>) -> (Vec<TokenizedInput>, Tensor) {
|
||||
// NOTE(review): 128 is presumably the maximum encoded length and the trailing 0 the stride —
// confirm against the tokenizer's `encode_list` signature.
let tokenized_input: Vec<TokenizedInput> = self.tokenizer.encode_list(input.to_vec(),
|
||||
128,
|
||||
&TruncationStrategy::LongestFirst,
|
||||
0);
|
||||
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
|
||||
let tokenized_input_tensors: Vec<tch::Tensor> = tokenized_input.
|
||||
iter().
|
||||
map(|input| input.token_ids.clone()).
|
||||
map(|mut input| {
|
||||
let tokenized_input: Vec<TokenizedInput> =
|
||||
self.tokenizer
|
||||
.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
||||
let max_len = tokenized_input
|
||||
.iter()
|
||||
.map(|input| input.token_ids.len())
|
||||
.max()
|
||||
.unwrap();
|
||||
let tokenized_input_tensors: Vec<tch::Tensor> = tokenized_input
|
||||
.iter()
|
||||
.map(|input| input.token_ids.clone())
|
||||
.map(|mut input| {
|
||||
// Pad with 0 so every row has `max_len` entries (0 is presumably the padding id — TODO confirm).
input.extend(vec![0; max_len - input.len()]);
|
||||
input
|
||||
}).
|
||||
map(|input|
|
||||
Tensor::of_slice(&(input))).
|
||||
collect::<Vec<_>>();
|
||||
(tokenized_input, Tensor::stack(tokenized_input_tensors.as_slice(), 0).to(self.var_store.device()))
|
||||
})
|
||||
.map(|input| Tensor::of_slice(&(input)))
|
||||
.collect::<Vec<_>>();
|
||||
// The tokenizer output is returned alongside the tensor so callers keep access to it.
(
|
||||
tokenized_input,
|
||||
Tensor::stack(tokenized_input_tensors.as_slice(), 0).to(self.var_store.device()),
|
||||
)
|
||||
}
|
||||
|
||||
/// Classify tokens in a text sequence
|
||||
@ -380,27 +484,33 @@ impl TokenClassificationModel {
|
||||
/// let ner_model = TokenClassificationModel::new(Default::default())?;
|
||||
/// let input = [
|
||||
/// "My name is Amy. I live in Paris.",
|
||||
/// "Paris is a city in France."
|
||||
/// "Paris is a city in France.",
|
||||
/// ];
|
||||
/// let output = ner_model.predict(&input, true, true);
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
pub fn predict(&self, input: &[&str], consolidate_sub_tokens: bool, return_special: bool) -> Vec<Token> {
|
||||
pub fn predict(
|
||||
&self,
|
||||
input: &[&str],
|
||||
consolidate_sub_tokens: bool,
|
||||
return_special: bool,
|
||||
) -> Vec<Token> {
|
||||
let (tokenized_input, input_tensor) = self.prepare_for_model(input.to_vec());
|
||||
let (output, _, _) = no_grad(|| {
|
||||
self.token_sequence_classifier
|
||||
.forward_t(Some(input_tensor.copy()),
|
||||
self.token_sequence_classifier.forward_t(
|
||||
Some(input_tensor.copy()),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
false)
|
||||
false,
|
||||
)
|
||||
});
|
||||
let output = output.detach().to(Device::Cpu);
|
||||
let score: Tensor = output.exp() / output.exp().sum1(&[-1], true, Float);
|
||||
let labels_idx = &score.argmax(-1, true);
|
||||
let mut tokens: Vec<Token> = vec!();
|
||||
let mut tokens: Vec<Token> = vec![];
|
||||
for sentence_idx in 0..labels_idx.size()[0] {
|
||||
let labels = labels_idx.get(sentence_idx);
|
||||
let sentence_tokens = &tokenized_input[sentence_idx as usize];
|
||||
@ -415,7 +525,16 @@ impl TokenClassificationModel {
|
||||
word_idx += 1;
|
||||
}
|
||||
let token = {
|
||||
self.decode_token(&original_chars, sentence_tokens, &input_tensor, &labels, &score, sentence_idx, position_idx as i64, word_idx - 1)
|
||||
self.decode_token(
|
||||
&original_chars,
|
||||
sentence_tokens,
|
||||
&input_tensor,
|
||||
&labels,
|
||||
&score,
|
||||
sentence_idx,
|
||||
position_idx as i64,
|
||||
word_idx - 1,
|
||||
)
|
||||
};
|
||||
tokens.push(token);
|
||||
}
|
||||
@ -426,8 +545,17 @@ impl TokenClassificationModel {
|
||||
tokens
|
||||
}
|
||||
|
||||
fn decode_token(&self, original_sentence_chars: &Vec<char>, sentence_tokens: &TokenizedInput, input_tensor: &Tensor,
|
||||
labels: &Tensor, score: &Tensor, sentence_idx: i64, position_idx: i64, word_index: u16) -> Token {
|
||||
fn decode_token(
|
||||
&self,
|
||||
original_sentence_chars: &Vec<char>,
|
||||
sentence_tokens: &TokenizedInput,
|
||||
input_tensor: &Tensor,
|
||||
labels: &Tensor,
|
||||
score: &Tensor,
|
||||
sentence_idx: i64,
|
||||
position_idx: i64,
|
||||
word_index: u16,
|
||||
) -> Token {
|
||||
let label_id = labels.int64_value(&[position_idx as i64]);
|
||||
let token_id = input_tensor.int64_value(&[sentence_idx, position_idx as i64]);
|
||||
|
||||
@ -435,13 +563,19 @@ impl TokenClassificationModel {
|
||||
|
||||
let text = match offsets {
|
||||
None => match self.tokenizer {
|
||||
TokenizerOption::Bert(ref tokenizer) => Tokenizer::decode(tokenizer, vec!(token_id), false, false),
|
||||
TokenizerOption::Roberta(ref tokenizer) => Tokenizer::decode(tokenizer, vec!(token_id), false, false),
|
||||
TokenizerOption::Bert(ref tokenizer) => {
|
||||
Tokenizer::decode(tokenizer, vec![token_id], false, false)
|
||||
}
|
||||
TokenizerOption::Roberta(ref tokenizer) => {
|
||||
Tokenizer::decode(tokenizer, vec![token_id], false, false)
|
||||
}
|
||||
},
|
||||
Some(offsets) => {
|
||||
let (start_char, end_char) = (offsets.begin as usize, offsets.end as usize);
|
||||
let end_char = min(end_char, original_sentence_chars.len());
|
||||
let text = original_sentence_chars[start_char..end_char].iter().collect();
|
||||
let text = original_sentence_chars[start_char..end_char]
|
||||
.iter()
|
||||
.collect();
|
||||
text
|
||||
}
|
||||
};
|
||||
@ -449,7 +583,11 @@ impl TokenClassificationModel {
|
||||
Token {
|
||||
text,
|
||||
score: score.double_value(&[sentence_idx, position_idx, label_id]),
|
||||
label: self.label_mapping.get(&label_id).expect("Index out of vocabulary bounds.").to_owned(),
|
||||
label: self
|
||||
.label_mapping
|
||||
.get(&label_id)
|
||||
.expect("Index out of vocabulary bounds.")
|
||||
.to_owned(),
|
||||
label_index: label_id,
|
||||
sentence: sentence_idx as usize,
|
||||
index: position_idx as u16,
|
||||
@ -459,24 +597,29 @@ impl TokenClassificationModel {
|
||||
}
|
||||
}
|
||||
|
||||
fn consolidate_tokens(&self, tokens: &mut Vec<Token>, label_aggregation_function: &LabelAggregationOption) {
|
||||
let mut tokens_to_replace = vec!();
|
||||
fn consolidate_tokens(
|
||||
&self,
|
||||
tokens: &mut Vec<Token>,
|
||||
label_aggregation_function: &LabelAggregationOption,
|
||||
) {
|
||||
let mut tokens_to_replace = vec![];
|
||||
let mut token_iter = tokens.iter_consolidate_tokens();
|
||||
let mut cursor = 0;
|
||||
|
||||
while let Some(sub_tokens) = token_iter.next() {
|
||||
if sub_tokens.len() > 1 {
|
||||
let (label_index, label) = self.consolidate_labels(sub_tokens, label_aggregation_function);
|
||||
let (label_index, label) =
|
||||
self.consolidate_labels(sub_tokens, label_aggregation_function);
|
||||
let sentence = (&sub_tokens[0]).sentence;
|
||||
let index = (&sub_tokens[0]).index;
|
||||
let word_index = (&sub_tokens[0]).word_index;
|
||||
let offset_start = match &sub_tokens.first().unwrap().offset {
|
||||
Some(offset) => Some(offset.begin),
|
||||
None => None
|
||||
None => None,
|
||||
};
|
||||
let offset_end = match &sub_tokens.last().unwrap().offset {
|
||||
Some(offset) => Some(offset.end),
|
||||
None => None
|
||||
None => None,
|
||||
};
|
||||
let offset = if offset_start.is_some() & offset_end.is_some() {
|
||||
Some(Offset::new(offset_start.unwrap(), offset_end.unwrap()))
|
||||
@ -513,7 +656,11 @@ impl TokenClassificationModel {
|
||||
}
|
||||
}
|
||||
|
||||
fn consolidate_labels(&self, tokens: &[Token], aggregation: &LabelAggregationOption) -> (i64, String) {
|
||||
fn consolidate_labels(
|
||||
&self,
|
||||
tokens: &[Token],
|
||||
aggregation: &LabelAggregationOption,
|
||||
) -> (i64, String) {
|
||||
match aggregation {
|
||||
LabelAggregationOption::First => {
|
||||
let token = tokens.first().unwrap();
|
||||
@ -524,22 +671,17 @@ impl TokenClassificationModel {
|
||||
(token.label_index, token.label.clone())
|
||||
}
|
||||
LabelAggregationOption::Mode => {
|
||||
let counts = tokens
|
||||
.iter()
|
||||
.fold(
|
||||
HashMap::new(),
|
||||
|mut m, c| {
|
||||
let counts = tokens.iter().fold(HashMap::new(), |mut m, c| {
|
||||
*m.entry((c.label_index, c.label.as_str())).or_insert(0) += 1;
|
||||
m
|
||||
},
|
||||
);
|
||||
});
|
||||
counts
|
||||
.into_iter()
|
||||
.max_by(|a, b| a.1.cmp(&b.1))
|
||||
.map(|((label_index, label), _)| (label_index, label.to_owned()))
|
||||
.unwrap()
|
||||
}
|
||||
LabelAggregationOption::Custom(function) => function(tokens)
|
||||
LabelAggregationOption::Custom(function) => function(tokens),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -11,7 +11,6 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
|
||||
//! # Translation pipeline
|
||||
//! Translation based on the Marian encoder-decoder architecture
|
||||
//! Include techniques such as beam search, top-k and nucleus sampling, temperature setting and repetition penalty.
|
||||
@ -35,9 +34,10 @@
|
||||
//! ```no_run
|
||||
//! # fn main() -> failure::Fallible<()> {
|
||||
//! # use rust_bert::pipelines::generation::LanguageGenerator;
|
||||
//! use rust_bert::pipelines::translation::{TranslationModel, TranslationConfig, Language};
|
||||
//! use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
|
||||
//! use tch::Device;
|
||||
//! let translation_config = TranslationConfig::new(Language::EnglishToFrench, Device::cuda_if_available());
|
||||
//! let translation_config =
|
||||
//! TranslationConfig::new(Language::EnglishToFrench, Device::cuda_if_available());
|
||||
//! let mut model = TranslationModel::new(translation_config)?;
|
||||
//!
|
||||
//! let input = ["This is a sentence to be translated"];
|
||||
@ -54,10 +54,13 @@
|
||||
//! # ;
|
||||
//! ```
|
||||
|
||||
use crate::pipelines::generation::{MarianGenerator, GenerateConfig, LanguageGenerator};
|
||||
use crate::common::resources::{RemoteResource, Resource};
|
||||
use crate::marian::{
|
||||
MarianConfigResources, MarianModelResources, MarianPrefix, MarianSpmResources,
|
||||
MarianVocabResources,
|
||||
};
|
||||
use crate::pipelines::generation::{GenerateConfig, LanguageGenerator, MarianGenerator};
|
||||
use tch::Device;
|
||||
use crate::common::resources::{Resource, RemoteResource};
|
||||
use crate::marian::{MarianModelResources, MarianConfigResources, MarianVocabResources, MarianSpmResources, MarianPrefix};
|
||||
|
||||
/// Pretrained languages available for direct use
|
||||
pub enum Language {
|
||||
@ -84,47 +87,244 @@ pub enum Language {
|
||||
struct RemoteTranslationResources;
|
||||
|
||||
impl RemoteTranslationResources {
|
||||
pub const ENGLISH2FRENCH: ((&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), Option<&'static str>) =
|
||||
(MarianModelResources::ENGLISH2ROMANCE, MarianConfigResources::ENGLISH2ROMANCE, MarianVocabResources::ENGLISH2ROMANCE, MarianSpmResources::ENGLISH2ROMANCE, MarianPrefix::ENGLISH2FRENCH);
|
||||
pub const ENGLISH2CATALAN: ((&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), Option<&'static str>) =
|
||||
(MarianModelResources::ENGLISH2ROMANCE, MarianConfigResources::ENGLISH2ROMANCE, MarianVocabResources::ENGLISH2ROMANCE, MarianSpmResources::ENGLISH2ROMANCE, MarianPrefix::ENGLISH2CATALAN);
|
||||
pub const ENGLISH2SPANISH: ((&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), Option<&'static str>) =
|
||||
(MarianModelResources::ENGLISH2ROMANCE, MarianConfigResources::ENGLISH2ROMANCE, MarianVocabResources::ENGLISH2ROMANCE, MarianSpmResources::ENGLISH2ROMANCE, MarianPrefix::ENGLISH2SPANISH);
|
||||
pub const ENGLISH2PORTUGUESE: ((&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), Option<&'static str>) =
|
||||
(MarianModelResources::ENGLISH2ROMANCE, MarianConfigResources::ENGLISH2ROMANCE, MarianVocabResources::ENGLISH2ROMANCE, MarianSpmResources::ENGLISH2ROMANCE, MarianPrefix::ENGLISH2PORTUGUESE);
|
||||
pub const ENGLISH2ITALIAN: ((&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), Option<&'static str>) =
|
||||
(MarianModelResources::ENGLISH2ROMANCE, MarianConfigResources::ENGLISH2ROMANCE, MarianVocabResources::ENGLISH2ROMANCE, MarianSpmResources::ENGLISH2ROMANCE, MarianPrefix::ENGLISH2ITALIAN);
|
||||
pub const ENGLISH2ROMANIAN: ((&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), Option<&'static str>) =
|
||||
(MarianModelResources::ENGLISH2ROMANCE, MarianConfigResources::ENGLISH2ROMANCE, MarianVocabResources::ENGLISH2ROMANCE, MarianSpmResources::ENGLISH2ROMANCE, MarianPrefix::ENGLISH2ROMANIAN);
|
||||
pub const ENGLISH2GERMAN: ((&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), Option<&'static str>) =
|
||||
(MarianModelResources::ENGLISH2GERMAN, MarianConfigResources::ENGLISH2GERMAN, MarianVocabResources::ENGLISH2GERMAN, MarianSpmResources::ENGLISH2GERMAN, MarianPrefix::ENGLISH2GERMAN);
|
||||
pub const ENGLISH2RUSSIAN: ((&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), Option<&'static str>) =
|
||||
(MarianModelResources::ENGLISH2RUSSIAN, MarianConfigResources::ENGLISH2RUSSIAN, MarianVocabResources::ENGLISH2RUSSIAN, MarianSpmResources::ENGLISH2RUSSIAN, MarianPrefix::ENGLISH2RUSSIAN);
|
||||
pub const ENGLISH2FRENCH: (
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
Option<&'static str>,
|
||||
) = (
|
||||
MarianModelResources::ENGLISH2ROMANCE,
|
||||
MarianConfigResources::ENGLISH2ROMANCE,
|
||||
MarianVocabResources::ENGLISH2ROMANCE,
|
||||
MarianSpmResources::ENGLISH2ROMANCE,
|
||||
MarianPrefix::ENGLISH2FRENCH,
|
||||
);
|
||||
pub const ENGLISH2CATALAN: (
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
Option<&'static str>,
|
||||
) = (
|
||||
MarianModelResources::ENGLISH2ROMANCE,
|
||||
MarianConfigResources::ENGLISH2ROMANCE,
|
||||
MarianVocabResources::ENGLISH2ROMANCE,
|
||||
MarianSpmResources::ENGLISH2ROMANCE,
|
||||
MarianPrefix::ENGLISH2CATALAN,
|
||||
);
|
||||
pub const ENGLISH2SPANISH: (
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
Option<&'static str>,
|
||||
) = (
|
||||
MarianModelResources::ENGLISH2ROMANCE,
|
||||
MarianConfigResources::ENGLISH2ROMANCE,
|
||||
MarianVocabResources::ENGLISH2ROMANCE,
|
||||
MarianSpmResources::ENGLISH2ROMANCE,
|
||||
MarianPrefix::ENGLISH2SPANISH,
|
||||
);
|
||||
pub const ENGLISH2PORTUGUESE: (
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
Option<&'static str>,
|
||||
) = (
|
||||
MarianModelResources::ENGLISH2ROMANCE,
|
||||
MarianConfigResources::ENGLISH2ROMANCE,
|
||||
MarianVocabResources::ENGLISH2ROMANCE,
|
||||
MarianSpmResources::ENGLISH2ROMANCE,
|
||||
MarianPrefix::ENGLISH2PORTUGUESE,
|
||||
);
|
||||
pub const ENGLISH2ITALIAN: (
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
Option<&'static str>,
|
||||
) = (
|
||||
MarianModelResources::ENGLISH2ROMANCE,
|
||||
MarianConfigResources::ENGLISH2ROMANCE,
|
||||
MarianVocabResources::ENGLISH2ROMANCE,
|
||||
MarianSpmResources::ENGLISH2ROMANCE,
|
||||
MarianPrefix::ENGLISH2ITALIAN,
|
||||
);
|
||||
pub const ENGLISH2ROMANIAN: (
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
Option<&'static str>,
|
||||
) = (
|
||||
MarianModelResources::ENGLISH2ROMANCE,
|
||||
MarianConfigResources::ENGLISH2ROMANCE,
|
||||
MarianVocabResources::ENGLISH2ROMANCE,
|
||||
MarianSpmResources::ENGLISH2ROMANCE,
|
||||
MarianPrefix::ENGLISH2ROMANIAN,
|
||||
);
|
||||
pub const ENGLISH2GERMAN: (
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
Option<&'static str>,
|
||||
) = (
|
||||
MarianModelResources::ENGLISH2GERMAN,
|
||||
MarianConfigResources::ENGLISH2GERMAN,
|
||||
MarianVocabResources::ENGLISH2GERMAN,
|
||||
MarianSpmResources::ENGLISH2GERMAN,
|
||||
MarianPrefix::ENGLISH2GERMAN,
|
||||
);
|
||||
pub const ENGLISH2RUSSIAN: (
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
Option<&'static str>,
|
||||
) = (
|
||||
MarianModelResources::ENGLISH2RUSSIAN,
|
||||
MarianConfigResources::ENGLISH2RUSSIAN,
|
||||
MarianVocabResources::ENGLISH2RUSSIAN,
|
||||
MarianSpmResources::ENGLISH2RUSSIAN,
|
||||
MarianPrefix::ENGLISH2RUSSIAN,
|
||||
);
|
||||
|
||||
pub const FRENCH2ENGLISH: ((&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), Option<&'static str>) =
|
||||
(MarianModelResources::ROMANCE2ENGLISH, MarianConfigResources::ROMANCE2ENGLISH, MarianVocabResources::ROMANCE2ENGLISH, MarianSpmResources::ROMANCE2ENGLISH, MarianPrefix::FRENCH2ENGLISH);
|
||||
pub const CATALAN2ENGLISH: ((&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), Option<&'static str>) =
|
||||
(MarianModelResources::ROMANCE2ENGLISH, MarianConfigResources::ROMANCE2ENGLISH, MarianVocabResources::ROMANCE2ENGLISH, MarianSpmResources::ROMANCE2ENGLISH, MarianPrefix::CATALAN2ENGLISH);
|
||||
pub const SPANISH2ENGLISH: ((&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), Option<&'static str>) =
|
||||
(MarianModelResources::ROMANCE2ENGLISH, MarianConfigResources::ROMANCE2ENGLISH, MarianVocabResources::ROMANCE2ENGLISH, MarianSpmResources::ROMANCE2ENGLISH, MarianPrefix::SPANISH2ENGLISH);
|
||||
pub const PORTUGUESE2ENGLISH: ((&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), Option<&'static str>) =
|
||||
(MarianModelResources::ROMANCE2ENGLISH, MarianConfigResources::ROMANCE2ENGLISH, MarianVocabResources::ROMANCE2ENGLISH, MarianSpmResources::ROMANCE2ENGLISH, MarianPrefix::PORTUGUESE2ENGLISH);
|
||||
pub const ITALIAN2ENGLISH: ((&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), Option<&'static str>) =
|
||||
(MarianModelResources::ROMANCE2ENGLISH, MarianConfigResources::ROMANCE2ENGLISH, MarianVocabResources::ROMANCE2ENGLISH, MarianSpmResources::ROMANCE2ENGLISH, MarianPrefix::ITALIAN2ENGLISH);
|
||||
pub const ROMANIAN2ENGLISH: ((&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), Option<&'static str>) =
|
||||
(MarianModelResources::ROMANCE2ENGLISH, MarianConfigResources::ROMANCE2ENGLISH, MarianVocabResources::ROMANCE2ENGLISH, MarianSpmResources::ROMANCE2ENGLISH, MarianPrefix::ROMANIAN2ENGLISH);
|
||||
pub const GERMAN2ENGLISH: ((&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), Option<&'static str>) =
|
||||
(MarianModelResources::GERMAN2ENGLISH, MarianConfigResources::GERMAN2ENGLISH, MarianVocabResources::GERMAN2ENGLISH, MarianSpmResources::GERMAN2ENGLISH, MarianPrefix::GERMAN2ENGLISH);
|
||||
pub const RUSSIAN2ENGLISH: ((&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), Option<&'static str>) =
|
||||
(MarianModelResources::RUSSIAN2ENGLISH, MarianConfigResources::RUSSIAN2ENGLISH, MarianVocabResources::RUSSIAN2ENGLISH, MarianSpmResources::RUSSIAN2ENGLISH, MarianPrefix::RUSSIAN2ENGLISH);
|
||||
pub const FRENCH2ENGLISH: (
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
Option<&'static str>,
|
||||
) = (
|
||||
MarianModelResources::ROMANCE2ENGLISH,
|
||||
MarianConfigResources::ROMANCE2ENGLISH,
|
||||
MarianVocabResources::ROMANCE2ENGLISH,
|
||||
MarianSpmResources::ROMANCE2ENGLISH,
|
||||
MarianPrefix::FRENCH2ENGLISH,
|
||||
);
|
||||
pub const CATALAN2ENGLISH: (
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
Option<&'static str>,
|
||||
) = (
|
||||
MarianModelResources::ROMANCE2ENGLISH,
|
||||
MarianConfigResources::ROMANCE2ENGLISH,
|
||||
MarianVocabResources::ROMANCE2ENGLISH,
|
||||
MarianSpmResources::ROMANCE2ENGLISH,
|
||||
MarianPrefix::CATALAN2ENGLISH,
|
||||
);
|
||||
pub const SPANISH2ENGLISH: (
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
Option<&'static str>,
|
||||
) = (
|
||||
MarianModelResources::ROMANCE2ENGLISH,
|
||||
MarianConfigResources::ROMANCE2ENGLISH,
|
||||
MarianVocabResources::ROMANCE2ENGLISH,
|
||||
MarianSpmResources::ROMANCE2ENGLISH,
|
||||
MarianPrefix::SPANISH2ENGLISH,
|
||||
);
|
||||
pub const PORTUGUESE2ENGLISH: (
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
Option<&'static str>,
|
||||
) = (
|
||||
MarianModelResources::ROMANCE2ENGLISH,
|
||||
MarianConfigResources::ROMANCE2ENGLISH,
|
||||
MarianVocabResources::ROMANCE2ENGLISH,
|
||||
MarianSpmResources::ROMANCE2ENGLISH,
|
||||
MarianPrefix::PORTUGUESE2ENGLISH,
|
||||
);
|
||||
pub const ITALIAN2ENGLISH: (
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
Option<&'static str>,
|
||||
) = (
|
||||
MarianModelResources::ROMANCE2ENGLISH,
|
||||
MarianConfigResources::ROMANCE2ENGLISH,
|
||||
MarianVocabResources::ROMANCE2ENGLISH,
|
||||
MarianSpmResources::ROMANCE2ENGLISH,
|
||||
MarianPrefix::ITALIAN2ENGLISH,
|
||||
);
|
||||
pub const ROMANIAN2ENGLISH: (
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
Option<&'static str>,
|
||||
) = (
|
||||
MarianModelResources::ROMANCE2ENGLISH,
|
||||
MarianConfigResources::ROMANCE2ENGLISH,
|
||||
MarianVocabResources::ROMANCE2ENGLISH,
|
||||
MarianSpmResources::ROMANCE2ENGLISH,
|
||||
MarianPrefix::ROMANIAN2ENGLISH,
|
||||
);
|
||||
pub const GERMAN2ENGLISH: (
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
Option<&'static str>,
|
||||
) = (
|
||||
MarianModelResources::GERMAN2ENGLISH,
|
||||
MarianConfigResources::GERMAN2ENGLISH,
|
||||
MarianVocabResources::GERMAN2ENGLISH,
|
||||
MarianSpmResources::GERMAN2ENGLISH,
|
||||
MarianPrefix::GERMAN2ENGLISH,
|
||||
);
|
||||
pub const RUSSIAN2ENGLISH: (
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
Option<&'static str>,
|
||||
) = (
|
||||
MarianModelResources::RUSSIAN2ENGLISH,
|
||||
MarianConfigResources::RUSSIAN2ENGLISH,
|
||||
MarianVocabResources::RUSSIAN2ENGLISH,
|
||||
MarianSpmResources::RUSSIAN2ENGLISH,
|
||||
MarianPrefix::RUSSIAN2ENGLISH,
|
||||
);
|
||||
|
||||
pub const FRENCH2GERMAN: ((&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), Option<&'static str>) =
|
||||
(MarianModelResources::FRENCH2GERMAN, MarianConfigResources::FRENCH2GERMAN, MarianVocabResources::FRENCH2GERMAN, MarianSpmResources::FRENCH2GERMAN, MarianPrefix::FRENCH2GERMAN);
|
||||
pub const GERMAN2FRENCH: ((&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), Option<&'static str>) =
|
||||
(MarianModelResources::GERMAN2FRENCH, MarianConfigResources::GERMAN2FRENCH, MarianVocabResources::GERMAN2FRENCH, MarianSpmResources::GERMAN2FRENCH, MarianPrefix::GERMAN2FRENCH);
|
||||
pub const FRENCH2GERMAN: (
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
Option<&'static str>,
|
||||
) = (
|
||||
MarianModelResources::FRENCH2GERMAN,
|
||||
MarianConfigResources::FRENCH2GERMAN,
|
||||
MarianVocabResources::FRENCH2GERMAN,
|
||||
MarianSpmResources::FRENCH2GERMAN,
|
||||
MarianPrefix::FRENCH2GERMAN,
|
||||
);
|
||||
pub const GERMAN2FRENCH: (
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
Option<&'static str>,
|
||||
) = (
|
||||
MarianModelResources::GERMAN2FRENCH,
|
||||
MarianConfigResources::GERMAN2FRENCH,
|
||||
MarianVocabResources::GERMAN2FRENCH,
|
||||
MarianSpmResources::GERMAN2FRENCH,
|
||||
MarianPrefix::GERMAN2FRENCH,
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
/// # Configuration for text translation
|
||||
/// Contains information regarding the model to load, mirrors the GenerationConfig, with a
|
||||
/// different set of default parameters and sets the device to place the model on.
|
||||
@ -179,16 +379,17 @@ impl TranslationConfig {
|
||||
///
|
||||
/// ```no_run
|
||||
/// # fn main() -> failure::Fallible<()> {
|
||||
/// use rust_bert::pipelines::translation::{TranslationConfig, Language};
|
||||
/// use rust_bert::pipelines::translation::{Language, TranslationConfig};
|
||||
/// use tch::Device;
|
||||
///
|
||||
/// let translation_config = TranslationConfig::new(Language::FrenchToEnglish, Device::cuda_if_available());
|
||||
/// let translation_config =
|
||||
/// TranslationConfig::new(Language::FrenchToEnglish, Device::cuda_if_available());
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
///
|
||||
pub fn new(language: Language, device: Device) -> TranslationConfig {
|
||||
let (model_resource, config_resource, vocab_resource, merges_resource, prefix) = match language {
|
||||
let (model_resource, config_resource, vocab_resource, merges_resource, prefix) =
|
||||
match language {
|
||||
Language::EnglishToFrench => RemoteTranslationResources::ENGLISH2FRENCH,
|
||||
Language::EnglishToCatalan => RemoteTranslationResources::ENGLISH2CATALAN,
|
||||
Language::EnglishToSpanish => RemoteTranslationResources::ENGLISH2SPANISH,
|
||||
@ -216,7 +417,7 @@ impl TranslationConfig {
|
||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(merges_resource));
|
||||
let prefix = match prefix {
|
||||
Some(value) => Some(value.to_string()),
|
||||
None => None
|
||||
None => None,
|
||||
};
|
||||
TranslationConfig {
|
||||
model_resource,
|
||||
@ -255,31 +456,42 @@ impl TranslationConfig {
|
||||
/// ```no_run
|
||||
/// # fn main() -> failure::Fallible<()> {
|
||||
/// use rust_bert::pipelines::translation::TranslationConfig;
|
||||
/// use tch::Device;
|
||||
/// use rust_bert::resources::{Resource, LocalResource};
|
||||
/// use rust_bert::resources::{LocalResource, Resource};
|
||||
/// use std::path::PathBuf;
|
||||
/// use tch::Device;
|
||||
///
|
||||
/// let config_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/config.json") });
|
||||
/// let model_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/model.ot") });
|
||||
/// let vocab_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/vocab.json") });
|
||||
/// let sentence_piece_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/spiece.model") });
|
||||
/// let config_resource = Resource::Local(LocalResource {
|
||||
/// local_path: PathBuf::from("path/to/config.json"),
|
||||
/// });
|
||||
/// let model_resource = Resource::Local(LocalResource {
|
||||
/// local_path: PathBuf::from("path/to/model.ot"),
|
||||
/// });
|
||||
/// let vocab_resource = Resource::Local(LocalResource {
|
||||
/// local_path: PathBuf::from("path/to/vocab.json"),
|
||||
/// });
|
||||
/// let sentence_piece_resource = Resource::Local(LocalResource {
|
||||
/// local_path: PathBuf::from("path/to/spiece.model"),
|
||||
/// });
|
||||
///
|
||||
/// let translation_config = TranslationConfig::new_from_resources(model_resource,
|
||||
/// let translation_config = TranslationConfig::new_from_resources(
|
||||
/// model_resource,
|
||||
/// config_resource,
|
||||
/// vocab_resource,
|
||||
/// sentence_piece_resource,
|
||||
/// Some(">>fr<<".to_string()),
|
||||
/// Device::cuda_if_available());
|
||||
/// Device::cuda_if_available(),
|
||||
/// );
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
///
|
||||
pub fn new_from_resources(model_resource: Resource,
|
||||
pub fn new_from_resources(
|
||||
model_resource: Resource,
|
||||
config_resource: Resource,
|
||||
vocab_resource: Resource,
|
||||
sentence_piece_resource: Resource,
|
||||
prefix: Option<String>,
|
||||
device: Device) -> TranslationConfig {
|
||||
device: Device,
|
||||
) -> TranslationConfig {
|
||||
TranslationConfig {
|
||||
model_resource,
|
||||
config_resource,
|
||||
@ -320,17 +532,16 @@ impl TranslationModel {
|
||||
///
|
||||
/// ```no_run
|
||||
/// # fn main() -> failure::Fallible<()> {
|
||||
/// use rust_bert::pipelines::translation::{TranslationModel, TranslationConfig, Language};
|
||||
/// use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
|
||||
/// use tch::Device;
|
||||
///
|
||||
/// let translation_config = TranslationConfig::new(Language::FrenchToEnglish, Device::cuda_if_available());
|
||||
/// let translation_config =
|
||||
/// TranslationConfig::new(Language::FrenchToEnglish, Device::cuda_if_available());
|
||||
/// let mut summarization_model = TranslationModel::new(translation_config)?;
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
///
|
||||
pub fn new(translation_config: TranslationConfig)
|
||||
-> failure::Fallible<TranslationModel> {
|
||||
pub fn new(translation_config: TranslationConfig) -> failure::Fallible<TranslationModel> {
|
||||
let generate_config = GenerateConfig {
|
||||
model_resource: translation_config.model_resource,
|
||||
config_resource: translation_config.config_resource,
|
||||
@ -353,7 +564,10 @@ impl TranslationModel {
|
||||
|
||||
let model = MarianGenerator::new(generate_config)?;
|
||||
|
||||
Ok(TranslationModel { model, prefix: translation_config.prefix })
|
||||
Ok(TranslationModel {
|
||||
model,
|
||||
prefix: translation_config.prefix,
|
||||
})
|
||||
}
|
||||
|
||||
/// Translates texts provided
|
||||
@ -370,10 +584,11 @@ impl TranslationModel {
|
||||
/// ```no_run
|
||||
/// # fn main() -> failure::Fallible<()> {
|
||||
/// use rust_bert::pipelines::generation::LanguageGenerator;
|
||||
/// use rust_bert::pipelines::translation::{TranslationModel, TranslationConfig, Language};
|
||||
/// use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
|
||||
/// use tch::Device;
|
||||
///
|
||||
/// let translation_config = TranslationConfig::new(Language::EnglishToFrench, Device::cuda_if_available());
|
||||
/// let translation_config =
|
||||
/// TranslationConfig::new(Language::EnglishToFrench, Device::cuda_if_available());
|
||||
/// let model = TranslationModel::new(translation_config)?;
|
||||
///
|
||||
/// let input = ["This is a sentence to be translated"];
|
||||
@ -382,17 +597,17 @@ impl TranslationModel {
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
///
|
||||
pub fn translate(&self, texts: &[&str]) -> Vec<String> {
|
||||
match &self.prefix {
|
||||
Some(value) => {
|
||||
let texts: Vec<String> = texts
|
||||
.into_iter()
|
||||
.map(|&v| { format!("{} {}", value, v) })
|
||||
.map(|&v| format!("{} {}", value, v))
|
||||
.collect();
|
||||
self.model.generate(Some(texts.iter().map(AsRef::as_ref).collect()), None)
|
||||
self.model
|
||||
.generate(Some(texts.iter().map(AsRef::as_ref).collect()), None)
|
||||
}
|
||||
None => self.model.generate(Some(texts.to_vec()), None)
|
||||
None => self.model.generate(Some(texts.to_vec()), None),
|
||||
}
|
||||
}
|
||||
}
|
@ -11,10 +11,10 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use tch::{nn, Tensor, Kind};
|
||||
use crate::common::dropout::Dropout;
|
||||
use tch::nn::{EmbeddingConfig, embedding};
|
||||
use crate::bert::{BertConfig, BertEmbedding};
|
||||
use crate::common::dropout::Dropout;
|
||||
use tch::nn::{embedding, EmbeddingConfig};
|
||||
use tch::{nn, Kind, Tensor};
|
||||
|
||||
#[derive(Debug)]
|
||||
/// # BertEmbeddings implementation for RoBERTa model
|
||||
@ -36,8 +36,12 @@ impl RobertaEmbeddings {
|
||||
|
||||
fn create_position_ids_from_embeddings(&self, x: &Tensor) -> Tensor {
|
||||
let input_shape = x.size();
|
||||
let input_shape = vec!(input_shape[0], input_shape[1]);
|
||||
let position_ids: Tensor = Tensor::arange1(self.padding_index + 1, input_shape[0], (Kind::Int64, x.device()));
|
||||
let input_shape = vec![input_shape[0], input_shape[1]];
|
||||
let position_ids: Tensor = Tensor::arange1(
|
||||
self.padding_index + 1,
|
||||
input_shape[0],
|
||||
(Kind::Int64, x.device()),
|
||||
);
|
||||
position_ids.unsqueeze(0).expand(&input_shape, true)
|
||||
}
|
||||
}
|
||||
@ -54,10 +58,10 @@ impl BertEmbedding for RobertaEmbeddings {
|
||||
///
|
||||
/// ```no_run
|
||||
/// use rust_bert::bert::{BertConfig, BertEmbedding};
|
||||
/// use tch::{nn, Device};
|
||||
/// use rust_bert::roberta::RobertaEmbeddings;
|
||||
/// use rust_bert::Config;
|
||||
/// use std::path::Path;
|
||||
/// use rust_bert::roberta::RobertaEmbeddings;
|
||||
/// use tch::{nn, Device};
|
||||
///
|
||||
/// let config_path = Path::new("path/to/config.json");
|
||||
/// let device = Device::Cpu;
|
||||
@ -65,29 +69,48 @@ impl BertEmbedding for RobertaEmbeddings {
|
||||
/// let config = BertConfig::from_file(config_path);
|
||||
/// let robert_embeddings = RobertaEmbeddings::new(&(&p.root() / "bert_embeddings"), &config);
|
||||
/// ```
|
||||
///
|
||||
fn new(p: &nn::Path, config: &BertConfig) -> RobertaEmbeddings {
|
||||
let embedding_config = EmbeddingConfig { padding_idx: 1, ..Default::default() };
|
||||
let embedding_config = EmbeddingConfig {
|
||||
padding_idx: 1,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let word_embeddings: nn::Embedding = embedding(p / "word_embeddings",
|
||||
let word_embeddings: nn::Embedding = embedding(
|
||||
p / "word_embeddings",
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
embedding_config);
|
||||
embedding_config,
|
||||
);
|
||||
|
||||
let position_embeddings: nn::Embedding = embedding(p / "position_embeddings",
|
||||
let position_embeddings: nn::Embedding = embedding(
|
||||
p / "position_embeddings",
|
||||
config.max_position_embeddings,
|
||||
config.hidden_size,
|
||||
Default::default());
|
||||
Default::default(),
|
||||
);
|
||||
|
||||
let token_type_embeddings: nn::Embedding = embedding(p / "token_type_embeddings",
|
||||
let token_type_embeddings: nn::Embedding = embedding(
|
||||
p / "token_type_embeddings",
|
||||
config.type_vocab_size,
|
||||
config.hidden_size,
|
||||
Default::default());
|
||||
Default::default(),
|
||||
);
|
||||
|
||||
let layer_norm_config = nn::LayerNormConfig { eps: 1e-12, ..Default::default() };
|
||||
let layer_norm: nn::LayerNorm = nn::layer_norm(p / "LayerNorm", vec![config.hidden_size], layer_norm_config);
|
||||
let layer_norm_config = nn::LayerNormConfig {
|
||||
eps: 1e-12,
|
||||
..Default::default()
|
||||
};
|
||||
let layer_norm: nn::LayerNorm =
|
||||
nn::layer_norm(p / "LayerNorm", vec![config.hidden_size], layer_norm_config);
|
||||
let dropout: Dropout = Dropout::new(config.hidden_dropout_prob);
|
||||
RobertaEmbeddings { word_embeddings, position_embeddings, token_type_embeddings, layer_norm, dropout, padding_index: 1 }
|
||||
RobertaEmbeddings {
|
||||
word_embeddings,
|
||||
position_embeddings,
|
||||
token_type_embeddings,
|
||||
layer_norm,
|
||||
dropout,
|
||||
padding_index: 1,
|
||||
}
|
||||
}
|
||||
|
||||
/// Forward pass through the embedding layer.
|
||||
@ -123,53 +146,67 @@ impl BertEmbedding for RobertaEmbeddings {
|
||||
/// let (batch_size, sequence_length) = (64, 128);
|
||||
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
|
||||
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
||||
/// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true);
|
||||
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
|
||||
/// .expand(&[batch_size, sequence_length], true);
|
||||
///
|
||||
/// let embedded_output = no_grad(|| {
|
||||
/// roberta_embeddings
|
||||
/// .forward_t(Some(input_tensor),
|
||||
/// .forward_t(
|
||||
/// Some(input_tensor),
|
||||
/// Some(token_type_ids),
|
||||
/// Some(position_ids),
|
||||
/// None,
|
||||
/// false).unwrap()
|
||||
/// false,
|
||||
/// )
|
||||
/// .unwrap()
|
||||
/// });
|
||||
/// ```
|
||||
///
|
||||
fn forward_t(&self,
|
||||
fn forward_t(
|
||||
&self,
|
||||
input_ids: Option<Tensor>,
|
||||
token_type_ids: Option<Tensor>,
|
||||
position_ids: Option<Tensor>,
|
||||
input_embeds: Option<Tensor>,
|
||||
train: bool) -> Result<Tensor, &'static str> {
|
||||
train: bool,
|
||||
) -> Result<Tensor, &'static str> {
|
||||
let (input_embeddings, input_shape) = match &input_ids {
|
||||
Some(input_value) => match &input_embeds {
|
||||
Some(_) => { return Err("Only one of input ids or input embeddings may be set"); }
|
||||
None => (input_value.apply_t(&self.word_embeddings, train), input_value.size())
|
||||
Some(_) => {
|
||||
return Err("Only one of input ids or input embeddings may be set");
|
||||
}
|
||||
None => (
|
||||
input_value.apply_t(&self.word_embeddings, train),
|
||||
input_value.size(),
|
||||
),
|
||||
},
|
||||
None => match &input_embeds {
|
||||
Some(embeds) => (embeds.copy(), vec!(embeds.size()[0], embeds.size()[1])),
|
||||
None => { return Err("Only one of input ids or input embeddings may be set"); }
|
||||
Some(embeds) => (embeds.copy(), vec![embeds.size()[0], embeds.size()[1]]),
|
||||
None => {
|
||||
return Err("Only one of input ids or input embeddings may be set");
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
|
||||
let position_ids = match position_ids {
|
||||
Some(value) => value,
|
||||
None => match input_ids {
|
||||
Some(value) => self.create_position_ids_from_input_ids(&value),
|
||||
None => self.create_position_ids_from_embeddings(&input_embeds.unwrap())
|
||||
}
|
||||
None => self.create_position_ids_from_embeddings(&input_embeds.unwrap()),
|
||||
},
|
||||
};
|
||||
|
||||
let token_type_ids = match token_type_ids {
|
||||
Some(value) => value,
|
||||
None => Tensor::zeros(&input_shape, (Kind::Int64, input_embeddings.device()))
|
||||
None => Tensor::zeros(&input_shape, (Kind::Int64, input_embeddings.device())),
|
||||
};
|
||||
|
||||
let position_embeddings = position_ids.apply(&self.position_embeddings);
|
||||
let token_type_embeddings = token_type_ids.apply(&self.token_type_embeddings);
|
||||
|
||||
let input_embeddings: Tensor = input_embeddings + position_embeddings + token_type_embeddings;
|
||||
Ok(input_embeddings.apply(&self.layer_norm).apply_t(&self.dropout, train))
|
||||
let input_embeddings: Tensor =
|
||||
input_embeddings + position_embeddings + token_type_embeddings;
|
||||
Ok(input_embeddings
|
||||
.apply(&self.layer_norm)
|
||||
.apply_t(&self.dropout, train))
|
||||
}
|
||||
}
|
@ -25,14 +25,22 @@
|
||||
//! use tch::{nn, Device};
|
||||
//! # use std::path::PathBuf;
|
||||
//! use rust_bert::bert::BertConfig;
|
||||
//! use rust_bert::Config;
|
||||
//! use rust_bert::resources::{download_resource, LocalResource, Resource};
|
||||
//! use rust_bert::roberta::RobertaForMaskedLM;
|
||||
//! use rust_bert::resources::{Resource, download_resource, LocalResource};
|
||||
//! use rust_bert::Config;
|
||||
//!
|
||||
//! let config_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/config.json")});
|
||||
//! let vocab_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/vocab.txt")});
|
||||
//! let merges_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/merges.txt")});
|
||||
//! let weights_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/model.ot")});
|
||||
//! let config_resource = Resource::Local(LocalResource {
|
||||
//! local_path: PathBuf::from("path/to/config.json"),
|
||||
//! });
|
||||
//! let vocab_resource = Resource::Local(LocalResource {
|
||||
//! local_path: PathBuf::from("path/to/vocab.txt"),
|
||||
//! });
|
||||
//! let merges_resource = Resource::Local(LocalResource {
|
||||
//! local_path: PathBuf::from("path/to/merges.txt"),
|
||||
//! });
|
||||
//! let weights_resource = Resource::Local(LocalResource {
|
||||
//! local_path: PathBuf::from("path/to/model.ot"),
|
||||
//! });
|
||||
//! let config_path = download_resource(&config_resource)?;
|
||||
//! let vocab_path = download_resource(&vocab_resource)?;
|
||||
//! let merges_path = download_resource(&merges_resource)?;
|
||||
@ -40,7 +48,11 @@
|
||||
//!
|
||||
//! let device = Device::cuda_if_available();
|
||||
//! let mut vs = nn::VarStore::new(device);
|
||||
//! let tokenizer: RobertaTokenizer = RobertaTokenizer::from_file(vocab_path.to_str().unwrap(), merges_path.to_str().unwrap(), true);
|
||||
//! let tokenizer: RobertaTokenizer = RobertaTokenizer::from_file(
|
||||
//! vocab_path.to_str().unwrap(),
|
||||
//! merges_path.to_str().unwrap(),
|
||||
//! true,
|
||||
//! );
|
||||
//! let config = BertConfig::from_file(config_path);
|
||||
//! let bert_model = RobertaForMaskedLM::new(&vs.root(), &config);
|
||||
//! vs.load(weights_path)?;
|
||||
@ -49,10 +61,12 @@
|
||||
//! # }
|
||||
//! ```
|
||||
|
||||
|
||||
mod embeddings;
|
||||
mod roberta;
|
||||
|
||||
pub use roberta::{RobertaModelResources, RobertaConfigResources, RobertaVocabResources, RobertaMergesResources,
|
||||
RobertaForMaskedLM, RobertaForMultipleChoice, RobertaForTokenClassification, RobertaForQuestionAnswering, RobertaForSequenceClassification};
|
||||
pub use embeddings::RobertaEmbeddings;
|
||||
pub use roberta::{
|
||||
RobertaConfigResources, RobertaForMaskedLM, RobertaForMultipleChoice,
|
||||
RobertaForQuestionAnswering, RobertaForSequenceClassification, RobertaForTokenClassification,
|
||||
RobertaMergesResources, RobertaModelResources, RobertaVocabResources,
|
||||
};
|
||||
|
@ -11,13 +11,13 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use tch::{nn, Tensor};
|
||||
use crate::common::linear::{linear_no_bias, LinearNoBias};
|
||||
use tch::nn::Init;
|
||||
use crate::common::activations::_gelu;
|
||||
use crate::roberta::embeddings::RobertaEmbeddings;
|
||||
use crate::common::dropout::Dropout;
|
||||
use crate::bert::{BertConfig, BertModel};
|
||||
use crate::common::activations::_gelu;
|
||||
use crate::common::dropout::Dropout;
|
||||
use crate::common::linear::{linear_no_bias, LinearNoBias};
|
||||
use crate::roberta::embeddings::RobertaEmbeddings;
|
||||
use tch::nn::Init;
|
||||
use tch::{nn, Tensor};
|
||||
|
||||
/// # RoBERTa Pretrained model weight files
|
||||
pub struct RobertaModelResources;
|
||||
@ -33,22 +33,34 @@ pub struct RobertaMergesResources;
|
||||
|
||||
impl RobertaModelResources {
|
||||
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
|
||||
pub const ROBERTA: (&'static str, &'static str) = ("roberta/model.ot", "https://cdn.huggingface.co/roberta-base-rust_model.ot");
|
||||
pub const ROBERTA: (&'static str, &'static str) = (
|
||||
"roberta/model.ot",
|
||||
"https://cdn.huggingface.co/roberta-base-rust_model.ot",
|
||||
);
|
||||
}
|
||||
|
||||
impl RobertaConfigResources {
|
||||
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
|
||||
pub const ROBERTA: (&'static str, &'static str) = ("roberta/config.json", "https://cdn.huggingface.co/roberta-base-config.json");
|
||||
pub const ROBERTA: (&'static str, &'static str) = (
|
||||
"roberta/config.json",
|
||||
"https://cdn.huggingface.co/roberta-base-config.json",
|
||||
);
|
||||
}
|
||||
|
||||
impl RobertaVocabResources {
|
||||
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
|
||||
pub const ROBERTA: (&'static str, &'static str) = ("roberta/vocab.txt", "https://cdn.huggingface.co/roberta-base-vocab.json");
|
||||
pub const ROBERTA: (&'static str, &'static str) = (
|
||||
"roberta/vocab.txt",
|
||||
"https://cdn.huggingface.co/roberta-base-vocab.json",
|
||||
);
|
||||
}
|
||||
|
||||
impl RobertaMergesResources {
|
||||
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
|
||||
pub const ROBERTA: (&'static str, &'static str) = ("roberta/merges.txt", "https://cdn.huggingface.co/roberta-base-merges.txt");
|
||||
pub const ROBERTA: (&'static str, &'static str) = (
|
||||
"roberta/merges.txt",
|
||||
"https://cdn.huggingface.co/roberta-base-merges.txt",
|
||||
);
|
||||
}
|
||||
|
||||
pub struct RobertaLMHead {
|
||||
@ -60,17 +72,42 @@ pub struct RobertaLMHead {
|
||||
|
||||
impl RobertaLMHead {
|
||||
pub fn new(p: &nn::Path, config: &BertConfig) -> RobertaLMHead {
|
||||
let dense = nn::linear(p / "dense", config.hidden_size, config.hidden_size, Default::default());
|
||||
let layer_norm_config = nn::LayerNormConfig { eps: 1e-12, ..Default::default() };
|
||||
let layer_norm = nn::layer_norm(p / "layer_norm", vec![config.hidden_size], layer_norm_config);
|
||||
let decoder = linear_no_bias(&(p / "decoder"), config.hidden_size, config.vocab_size, Default::default());
|
||||
let dense = nn::linear(
|
||||
p / "dense",
|
||||
config.hidden_size,
|
||||
config.hidden_size,
|
||||
Default::default(),
|
||||
);
|
||||
let layer_norm_config = nn::LayerNormConfig {
|
||||
eps: 1e-12,
|
||||
..Default::default()
|
||||
};
|
||||
let layer_norm = nn::layer_norm(
|
||||
p / "layer_norm",
|
||||
vec![config.hidden_size],
|
||||
layer_norm_config,
|
||||
);
|
||||
let decoder = linear_no_bias(
|
||||
&(p / "decoder"),
|
||||
config.hidden_size,
|
||||
config.vocab_size,
|
||||
Default::default(),
|
||||
);
|
||||
let bias = p.var("bias", &[config.vocab_size], Init::KaimingUniform);
|
||||
|
||||
RobertaLMHead { dense, decoder, layer_norm, bias }
|
||||
RobertaLMHead {
|
||||
dense,
|
||||
decoder,
|
||||
layer_norm,
|
||||
bias,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn forward(&self, hidden_states: &Tensor) -> Tensor {
|
||||
(_gelu(&hidden_states.apply(&self.dense))).apply(&self.layer_norm).apply(&self.decoder) + &self.bias
|
||||
(_gelu(&hidden_states.apply(&self.dense)))
|
||||
.apply(&self.layer_norm)
|
||||
.apply(&self.decoder)
|
||||
+ &self.bias
|
||||
}
|
||||
}
|
||||
|
||||
@ -96,10 +133,10 @@ impl RobertaForMaskedLM {
|
||||
///
|
||||
/// ```no_run
|
||||
/// use rust_bert::bert::BertConfig;
|
||||
/// use tch::{nn, Device};
|
||||
/// use rust_bert::roberta::RobertaForMaskedLM;
|
||||
/// use rust_bert::Config;
|
||||
/// use std::path::Path;
|
||||
/// use rust_bert::roberta::RobertaForMaskedLM;
|
||||
/// use tch::{nn, Device};
|
||||
///
|
||||
/// let config_path = Path::new("path/to/config.json");
|
||||
/// let device = Device::Cpu;
|
||||
@ -107,7 +144,6 @@ impl RobertaForMaskedLM {
|
||||
/// let config = BertConfig::from_file(config_path);
|
||||
/// let roberta = RobertaForMaskedLM::new(&(&p.root() / "roberta"), &config);
|
||||
/// ```
|
||||
///
|
||||
pub fn new(p: &nn::Path, config: &BertConfig) -> RobertaForMaskedLM {
|
||||
let roberta = BertModel::<RobertaEmbeddings>::new(&(p / "roberta"), config);
|
||||
let lm_head = RobertaLMHead::new(&(p / "lm_head"), config);
|
||||
@ -153,23 +189,24 @@ impl RobertaForMaskedLM {
|
||||
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
|
||||
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
||||
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
||||
/// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true);
|
||||
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
|
||||
/// .expand(&[batch_size, sequence_length], true);
|
||||
///
|
||||
/// let (output, all_hidden_states, all_attentions) = no_grad(|| {
|
||||
/// roberta_model
|
||||
/// .forward_t(Some(input_tensor),
|
||||
/// roberta_model.forward_t(
|
||||
/// Some(input_tensor),
|
||||
/// Some(mask),
|
||||
/// Some(token_type_ids),
|
||||
/// Some(position_ids),
|
||||
/// None,
|
||||
/// &None,
|
||||
/// &None,
|
||||
/// false)
|
||||
/// false,
|
||||
/// )
|
||||
/// });
|
||||
///
|
||||
/// ```
|
||||
///
|
||||
pub fn forward_t(&self,
|
||||
pub fn forward_t(
|
||||
&self,
|
||||
input_ids: Option<Tensor>,
|
||||
mask: Option<Tensor>,
|
||||
token_type_ids: Option<Tensor>,
|
||||
@ -177,9 +214,21 @@ impl RobertaForMaskedLM {
|
||||
input_embeds: Option<Tensor>,
|
||||
encoder_hidden_states: &Option<Tensor>,
|
||||
encoder_mask: &Option<Tensor>,
|
||||
train: bool) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
|
||||
let (hidden_state, _, all_hidden_states, all_attentions) = self.roberta.forward_t(input_ids, mask, token_type_ids, position_ids,
|
||||
input_embeds, encoder_hidden_states, encoder_mask, train).unwrap();
|
||||
train: bool,
|
||||
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
|
||||
let (hidden_state, _, all_hidden_states, all_attentions) = self
|
||||
.roberta
|
||||
.forward_t(
|
||||
input_ids,
|
||||
mask,
|
||||
token_type_ids,
|
||||
position_ids,
|
||||
input_embeds,
|
||||
encoder_hidden_states,
|
||||
encoder_mask,
|
||||
train,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let prediction_scores = self.lm_head.forward(&hidden_state);
|
||||
(prediction_scores, all_hidden_states, all_attentions)
|
||||
@ -194,12 +243,30 @@ pub struct RobertaClassificationHead {
|
||||
|
||||
impl RobertaClassificationHead {
|
||||
pub fn new(p: &nn::Path, config: &BertConfig) -> RobertaClassificationHead {
|
||||
let dense = nn::linear(p / "dense", config.hidden_size, config.hidden_size, Default::default());
|
||||
let num_labels = config.id2label.as_ref().expect("num_labels not provided in configuration").len() as i64;
|
||||
let out_proj = nn::linear(p / "out_proj", config.hidden_size, num_labels, Default::default());
|
||||
let dense = nn::linear(
|
||||
p / "dense",
|
||||
config.hidden_size,
|
||||
config.hidden_size,
|
||||
Default::default(),
|
||||
);
|
||||
let num_labels = config
|
||||
.id2label
|
||||
.as_ref()
|
||||
.expect("num_labels not provided in configuration")
|
||||
.len() as i64;
|
||||
let out_proj = nn::linear(
|
||||
p / "out_proj",
|
||||
config.hidden_size,
|
||||
num_labels,
|
||||
Default::default(),
|
||||
);
|
||||
let dropout = Dropout::new(config.hidden_dropout_prob);
|
||||
|
||||
RobertaClassificationHead { dense, dropout, out_proj }
|
||||
RobertaClassificationHead {
|
||||
dense,
|
||||
dropout,
|
||||
out_proj,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn forward_t(&self, hidden_states: &Tensor, train: bool) -> Tensor {
|
||||
@ -235,10 +302,10 @@ impl RobertaForSequenceClassification {
|
||||
///
|
||||
/// ```no_run
|
||||
/// use rust_bert::bert::BertConfig;
|
||||
/// use tch::{nn, Device};
|
||||
/// use rust_bert::roberta::RobertaForSequenceClassification;
|
||||
/// use rust_bert::Config;
|
||||
/// use std::path::Path;
|
||||
/// use rust_bert::roberta::RobertaForSequenceClassification;
|
||||
/// use tch::{nn, Device};
|
||||
///
|
||||
/// let config_path = Path::new("path/to/config.json");
|
||||
/// let device = Device::Cpu;
|
||||
@ -246,12 +313,14 @@ impl RobertaForSequenceClassification {
|
||||
/// let config = BertConfig::from_file(config_path);
|
||||
/// let roberta = RobertaForSequenceClassification::new(&(&p.root() / "roberta"), &config);
|
||||
/// ```
|
||||
///
|
||||
pub fn new(p: &nn::Path, config: &BertConfig) -> RobertaForSequenceClassification {
|
||||
let roberta = BertModel::<RobertaEmbeddings>::new(&(p / "roberta"), config);
|
||||
let classifier = RobertaClassificationHead::new(&(p / "classifier"), config);
|
||||
|
||||
RobertaForSequenceClassification { roberta, classifier }
|
||||
RobertaForSequenceClassification {
|
||||
roberta,
|
||||
classifier,
|
||||
}
|
||||
}
|
||||
|
||||
/// Forward pass through the model
|
||||
@ -290,29 +359,42 @@ impl RobertaForSequenceClassification {
|
||||
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
|
||||
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
||||
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
||||
/// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true);
|
||||
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
|
||||
/// .expand(&[batch_size, sequence_length], true);
|
||||
///
|
||||
/// let (labels, all_hidden_states, all_attentions) = no_grad(|| {
|
||||
/// roberta_model
|
||||
/// .forward_t(Some(input_tensor),
|
||||
/// roberta_model.forward_t(
|
||||
/// Some(input_tensor),
|
||||
/// Some(mask),
|
||||
/// Some(token_type_ids),
|
||||
/// Some(position_ids),
|
||||
/// None,
|
||||
/// false)
|
||||
/// false,
|
||||
/// )
|
||||
/// });
|
||||
///
|
||||
/// ```
|
||||
///
|
||||
pub fn forward_t(&self,
|
||||
pub fn forward_t(
|
||||
&self,
|
||||
input_ids: Option<Tensor>,
|
||||
mask: Option<Tensor>,
|
||||
token_type_ids: Option<Tensor>,
|
||||
position_ids: Option<Tensor>,
|
||||
input_embeds: Option<Tensor>,
|
||||
train: bool) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
|
||||
let (hidden_state, _, all_hidden_states, all_attentions) = self.roberta.forward_t(input_ids, mask, token_type_ids, position_ids,
|
||||
input_embeds, &None, &None, train).unwrap();
|
||||
train: bool,
|
||||
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
|
||||
let (hidden_state, _, all_hidden_states, all_attentions) = self
|
||||
.roberta
|
||||
.forward_t(
|
||||
input_ids,
|
||||
mask,
|
||||
token_type_ids,
|
||||
position_ids,
|
||||
input_embeds,
|
||||
&None,
|
||||
&None,
|
||||
train,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let output = self.classifier.forward_t(&hidden_state, train);
|
||||
(output, all_hidden_states, all_attentions)
|
||||
@ -344,10 +426,10 @@ impl RobertaForMultipleChoice {
|
||||
///
|
||||
/// ```no_run
|
||||
/// use rust_bert::bert::BertConfig;
|
||||
/// use tch::{nn, Device};
|
||||
/// use rust_bert::roberta::RobertaForMultipleChoice;
|
||||
/// use rust_bert::Config;
|
||||
/// use std::path::Path;
|
||||
/// use rust_bert::roberta::RobertaForMultipleChoice;
|
||||
/// use tch::{nn, Device};
|
||||
///
|
||||
/// let config_path = Path::new("path/to/config.json");
|
||||
/// let device = Device::Cpu;
|
||||
@ -355,13 +437,16 @@ impl RobertaForMultipleChoice {
|
||||
/// let config = BertConfig::from_file(config_path);
|
||||
/// let roberta = RobertaForMultipleChoice::new(&(&p.root() / "roberta"), &config);
|
||||
/// ```
|
||||
///
|
||||
pub fn new(p: &nn::Path, config: &BertConfig) -> RobertaForMultipleChoice {
|
||||
let roberta = BertModel::<RobertaEmbeddings>::new(&(p / "roberta"), config);
|
||||
let dropout = Dropout::new(config.hidden_dropout_prob);
|
||||
let classifier = nn::linear(p / "classifier", config.hidden_size, 1, Default::default());
|
||||
|
||||
RobertaForMultipleChoice { roberta, dropout, classifier }
|
||||
RobertaForMultipleChoice {
|
||||
roberta,
|
||||
dropout,
|
||||
classifier,
|
||||
}
|
||||
}
|
||||
|
||||
/// Forward pass through the model
|
||||
@ -399,45 +484,61 @@ impl RobertaForMultipleChoice {
|
||||
/// let input_tensor = Tensor::rand(&[num_choices, sequence_length], (Int64, device));
|
||||
/// let mask = Tensor::zeros(&[num_choices, sequence_length], (Int64, device));
|
||||
/// let token_type_ids = Tensor::zeros(&[num_choices, sequence_length], (Int64, device));
|
||||
/// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[num_choices, sequence_length], true);
|
||||
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
|
||||
/// .expand(&[num_choices, sequence_length], true);
|
||||
///
|
||||
/// let (choices, all_hidden_states, all_attentions) = no_grad(|| {
|
||||
/// roberta_model
|
||||
/// .forward_t(input_tensor,
|
||||
/// roberta_model.forward_t(
|
||||
/// input_tensor,
|
||||
/// Some(mask),
|
||||
/// Some(token_type_ids),
|
||||
/// Some(position_ids),
|
||||
/// false)
|
||||
/// false,
|
||||
/// )
|
||||
/// });
|
||||
///
|
||||
/// ```
|
||||
///
|
||||
pub fn forward_t(&self,
|
||||
pub fn forward_t(
|
||||
&self,
|
||||
input_ids: Tensor,
|
||||
mask: Option<Tensor>,
|
||||
token_type_ids: Option<Tensor>,
|
||||
position_ids: Option<Tensor>,
|
||||
train: bool) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
|
||||
train: bool,
|
||||
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
|
||||
let num_choices = input_ids.size()[1];
|
||||
|
||||
let flat_input_ids = Some(input_ids.view((-1i64, *input_ids.size().last().unwrap())));
|
||||
let flat_position_ids = match position_ids {
|
||||
Some(value) => Some(value.view((-1i64, *value.size().last().unwrap()))),
|
||||
None => None
|
||||
None => None,
|
||||
};
|
||||
let flat_token_type_ids = match token_type_ids {
|
||||
Some(value) => Some(value.view((-1i64, *value.size().last().unwrap()))),
|
||||
None => None
|
||||
None => None,
|
||||
};
|
||||
let flat_mask = match mask {
|
||||
Some(value) => Some(value.view((-1i64, *value.size().last().unwrap()))),
|
||||
None => None
|
||||
None => None,
|
||||
};
|
||||
|
||||
let (_, pooled_output, all_hidden_states, all_attentions) = self.roberta.forward_t(flat_input_ids, flat_mask, flat_token_type_ids, flat_position_ids,
|
||||
None, &None, &None, train).unwrap();
|
||||
let (_, pooled_output, all_hidden_states, all_attentions) = self
|
||||
.roberta
|
||||
.forward_t(
|
||||
flat_input_ids,
|
||||
flat_mask,
|
||||
flat_token_type_ids,
|
||||
flat_position_ids,
|
||||
None,
|
||||
&None,
|
||||
&None,
|
||||
train,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let output = pooled_output.apply_t(&self.dropout, train).apply(&self.classifier).view((-1, num_choices));
|
||||
let output = pooled_output
|
||||
.apply_t(&self.dropout, train)
|
||||
.apply(&self.classifier)
|
||||
.view((-1, num_choices));
|
||||
(output, all_hidden_states, all_attentions)
|
||||
}
|
||||
}
|
||||
@ -466,10 +567,10 @@ impl RobertaForTokenClassification {
|
||||
///
|
||||
/// ```no_run
|
||||
/// use rust_bert::bert::BertConfig;
|
||||
/// use tch::{nn, Device};
|
||||
/// use rust_bert::roberta::RobertaForTokenClassification;
|
||||
/// use rust_bert::Config;
|
||||
/// use std::path::Path;
|
||||
/// use rust_bert::roberta::RobertaForTokenClassification;
|
||||
/// use tch::{nn, Device};
|
||||
///
|
||||
/// let config_path = Path::new("path/to/config.json");
|
||||
/// let device = Device::Cpu;
|
||||
@ -477,14 +578,26 @@ impl RobertaForTokenClassification {
|
||||
/// let config = BertConfig::from_file(config_path);
|
||||
/// let roberta = RobertaForTokenClassification::new(&(&p.root() / "roberta"), &config);
|
||||
/// ```
|
||||
///
|
||||
pub fn new(p: &nn::Path, config: &BertConfig) -> RobertaForTokenClassification {
|
||||
let roberta = BertModel::<RobertaEmbeddings>::new(&(p / "roberta"), config);
|
||||
let dropout = Dropout::new(config.hidden_dropout_prob);
|
||||
let num_labels = config.id2label.as_ref().expect("num_labels not provided in configuration").len() as i64;
|
||||
let classifier = nn::linear(p / "classifier", config.hidden_size, num_labels, Default::default());
|
||||
let num_labels = config
|
||||
.id2label
|
||||
.as_ref()
|
||||
.expect("num_labels not provided in configuration")
|
||||
.len() as i64;
|
||||
let classifier = nn::linear(
|
||||
p / "classifier",
|
||||
config.hidden_size,
|
||||
num_labels,
|
||||
Default::default(),
|
||||
);
|
||||
|
||||
RobertaForTokenClassification { roberta, dropout, classifier }
|
||||
RobertaForTokenClassification {
|
||||
roberta,
|
||||
dropout,
|
||||
classifier,
|
||||
}
|
||||
}
|
||||
|
||||
/// Forward pass through the model
|
||||
@ -523,31 +636,46 @@ impl RobertaForTokenClassification {
|
||||
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
|
||||
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
||||
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
||||
/// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true);
|
||||
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
|
||||
/// .expand(&[batch_size, sequence_length], true);
|
||||
///
|
||||
/// let (token_labels, all_hidden_states, all_attentions) = no_grad(|| {
|
||||
/// roberta_model
|
||||
/// .forward_t(Some(input_tensor),
|
||||
/// roberta_model.forward_t(
|
||||
/// Some(input_tensor),
|
||||
/// Some(mask),
|
||||
/// Some(token_type_ids),
|
||||
/// Some(position_ids),
|
||||
/// None,
|
||||
/// false)
|
||||
/// false,
|
||||
/// )
|
||||
/// });
|
||||
///
|
||||
/// ```
|
||||
///
|
||||
pub fn forward_t(&self,
|
||||
pub fn forward_t(
|
||||
&self,
|
||||
input_ids: Option<Tensor>,
|
||||
mask: Option<Tensor>,
|
||||
token_type_ids: Option<Tensor>,
|
||||
position_ids: Option<Tensor>,
|
||||
input_embeds: Option<Tensor>,
|
||||
train: bool) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
|
||||
let (hidden_state, _, all_hidden_states, all_attentions) = self.roberta.forward_t(input_ids, mask, token_type_ids, position_ids,
|
||||
input_embeds, &None, &None, train).unwrap();
|
||||
train: bool,
|
||||
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
|
||||
let (hidden_state, _, all_hidden_states, all_attentions) = self
|
||||
.roberta
|
||||
.forward_t(
|
||||
input_ids,
|
||||
mask,
|
||||
token_type_ids,
|
||||
position_ids,
|
||||
input_embeds,
|
||||
&None,
|
||||
&None,
|
||||
train,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let sequence_output = hidden_state.apply_t(&self.dropout, train).apply(&self.classifier);
|
||||
let sequence_output = hidden_state
|
||||
.apply_t(&self.dropout, train)
|
||||
.apply(&self.classifier);
|
||||
(sequence_output, all_hidden_states, all_attentions)
|
||||
}
|
||||
}
|
||||
@ -576,10 +704,10 @@ impl RobertaForQuestionAnswering {
|
||||
///
|
||||
/// ```no_run
|
||||
/// use rust_bert::bert::BertConfig;
|
||||
/// use tch::{nn, Device};
|
||||
/// use rust_bert::roberta::RobertaForQuestionAnswering;
|
||||
/// use rust_bert::Config;
|
||||
/// use std::path::Path;
|
||||
/// use rust_bert::roberta::RobertaForQuestionAnswering;
|
||||
/// use tch::{nn, Device};
|
||||
///
|
||||
/// let config_path = Path::new("path/to/config.json");
|
||||
/// let device = Device::Cpu;
|
||||
@ -587,13 +715,20 @@ impl RobertaForQuestionAnswering {
|
||||
/// let config = BertConfig::from_file(config_path);
|
||||
/// let roberta = RobertaForQuestionAnswering::new(&(&p.root() / "roberta"), &config);
|
||||
/// ```
|
||||
///
|
||||
pub fn new(p: &nn::Path, config: &BertConfig) -> RobertaForQuestionAnswering {
|
||||
let roberta = BertModel::<RobertaEmbeddings>::new(&(p / "roberta"), config);
|
||||
let num_labels = 2;
|
||||
let qa_outputs = nn::linear(p / "qa_outputs", config.hidden_size, num_labels, Default::default());
|
||||
let qa_outputs = nn::linear(
|
||||
p / "qa_outputs",
|
||||
config.hidden_size,
|
||||
num_labels,
|
||||
Default::default(),
|
||||
);
|
||||
|
||||
RobertaForQuestionAnswering { roberta, qa_outputs }
|
||||
RobertaForQuestionAnswering {
|
||||
roberta,
|
||||
qa_outputs,
|
||||
}
|
||||
}
|
||||
|
||||
/// Forward pass through the model
|
||||
@ -633,29 +768,42 @@ impl RobertaForQuestionAnswering {
|
||||
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
|
||||
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
||||
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
||||
/// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true);
|
||||
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
|
||||
/// .expand(&[batch_size, sequence_length], true);
|
||||
///
|
||||
/// let (start_scores, end_scores, all_hidden_states, all_attentions) = no_grad(|| {
|
||||
/// roberta_model
|
||||
/// .forward_t(Some(input_tensor),
|
||||
/// roberta_model.forward_t(
|
||||
/// Some(input_tensor),
|
||||
/// Some(mask),
|
||||
/// Some(token_type_ids),
|
||||
/// Some(position_ids),
|
||||
/// None,
|
||||
/// false)
|
||||
/// false,
|
||||
/// )
|
||||
/// });
|
||||
///
|
||||
/// ```
|
||||
///
|
||||
pub fn forward_t(&self,
|
||||
pub fn forward_t(
|
||||
&self,
|
||||
input_ids: Option<Tensor>,
|
||||
mask: Option<Tensor>,
|
||||
token_type_ids: Option<Tensor>,
|
||||
position_ids: Option<Tensor>,
|
||||
input_embeds: Option<Tensor>,
|
||||
train: bool) -> (Tensor, Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
|
||||
let (hidden_state, _, all_hidden_states, all_attentions) = self.roberta.forward_t(input_ids, mask, token_type_ids, position_ids,
|
||||
input_embeds, &None, &None, train).unwrap();
|
||||
train: bool,
|
||||
) -> (Tensor, Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
|
||||
let (hidden_state, _, all_hidden_states, all_attentions) = self
|
||||
.roberta
|
||||
.forward_t(
|
||||
input_ids,
|
||||
mask,
|
||||
token_type_ids,
|
||||
position_ids,
|
||||
input_embeds,
|
||||
&None,
|
||||
&None,
|
||||
train,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let sequence_output = hidden_state.apply(&self.qa_outputs);
|
||||
let logits = sequence_output.split(1, -1);
|
||||
|
317
tests/albert.rs
317
tests/albert.rs
@ -1,20 +1,29 @@
|
||||
extern crate failure;
|
||||
extern crate dirs;
|
||||
extern crate failure;
|
||||
|
||||
use tch::{Device, nn, Tensor, no_grad};
|
||||
use rust_tokenizers::{TruncationStrategy, Tokenizer, Vocab, AlbertTokenizer};
|
||||
use rust_bert::albert::{
|
||||
AlbertConfig, AlbertConfigResources, AlbertForMaskedLM, AlbertForMultipleChoice,
|
||||
AlbertForQuestionAnswering, AlbertForSequenceClassification, AlbertForTokenClassification,
|
||||
AlbertModelResources, AlbertVocabResources,
|
||||
};
|
||||
use rust_bert::resources::{download_resource, RemoteResource, Resource};
|
||||
use rust_bert::Config;
|
||||
use rust_bert::resources::{Resource, RemoteResource, download_resource};
|
||||
use rust_bert::albert::{AlbertConfigResources, AlbertVocabResources, AlbertModelResources, AlbertConfig, AlbertForMaskedLM, AlbertForSequenceClassification, AlbertForMultipleChoice, AlbertForTokenClassification, AlbertForQuestionAnswering};
|
||||
use rust_tokenizers::{AlbertTokenizer, Tokenizer, TruncationStrategy, Vocab};
|
||||
use std::collections::HashMap;
|
||||
|
||||
use tch::{nn, no_grad, Device, Tensor};
|
||||
|
||||
#[test]
|
||||
fn albert_masked_lm() -> failure::Fallible<()> {
|
||||
// Resources paths
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(AlbertConfigResources::ALBERT_BASE_V2));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(AlbertVocabResources::ALBERT_BASE_V2));
|
||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(AlbertModelResources::ALBERT_BASE_V2));
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
AlbertConfigResources::ALBERT_BASE_V2,
|
||||
));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
AlbertVocabResources::ALBERT_BASE_V2,
|
||||
));
|
||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
AlbertModelResources::ALBERT_BASE_V2,
|
||||
));
|
||||
let config_path = download_resource(&config_resource)?;
|
||||
let vocab_path = download_resource(&vocab_resource)?;
|
||||
let weights_path = download_resource(&weights_resource)?;
|
||||
@ -22,37 +31,38 @@ fn albert_masked_lm() -> failure::Fallible<()> {
|
||||
// Set-up masked LM model
|
||||
let device = Device::Cpu;
|
||||
let mut vs = nn::VarStore::new(device);
|
||||
let tokenizer: AlbertTokenizer = AlbertTokenizer::from_file(vocab_path.to_str().unwrap(), true, false);
|
||||
let tokenizer: AlbertTokenizer =
|
||||
AlbertTokenizer::from_file(vocab_path.to_str().unwrap(), true, false);
|
||||
let config = AlbertConfig::from_file(config_path);
|
||||
let albert_model = AlbertForMaskedLM::new(&vs.root(), &config);
|
||||
vs.load(weights_path)?;
|
||||
|
||||
// Define input
|
||||
let input = ["Looks like one [MASK] is missing", "It\'s like comparing [MASK] to apples"];
|
||||
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
||||
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
|
||||
let tokenized_input = tokenized_input.
|
||||
iter().
|
||||
map(|input| input.token_ids.clone()).
|
||||
map(|mut input| {
|
||||
let input = [
|
||||
"Looks like one [MASK] is missing",
|
||||
"It\'s like comparing [MASK] to apples",
|
||||
];
|
||||
let tokenized_input =
|
||||
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
||||
let max_len = tokenized_input
|
||||
.iter()
|
||||
.map(|input| input.token_ids.len())
|
||||
.max()
|
||||
.unwrap();
|
||||
let tokenized_input = tokenized_input
|
||||
.iter()
|
||||
.map(|input| input.token_ids.clone())
|
||||
.map(|mut input| {
|
||||
input.extend(vec![0; max_len - input.len()]);
|
||||
input
|
||||
}).
|
||||
map(|input|
|
||||
Tensor::of_slice(&(input))).
|
||||
collect::<Vec<_>>();
|
||||
})
|
||||
.map(|input| Tensor::of_slice(&(input)))
|
||||
.collect::<Vec<_>>();
|
||||
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
|
||||
|
||||
// Forward pass
|
||||
let (output, _, _) = no_grad(|| {
|
||||
albert_model
|
||||
.forward_t(Some(input_tensor),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
false)
|
||||
});
|
||||
let (output, _, _) =
|
||||
no_grad(|| albert_model.forward_t(Some(input_tensor), None, None, None, None, false));
|
||||
|
||||
// Print masked tokens
|
||||
let index_1 = output.get(0).get(4).argmax(0, false);
|
||||
@ -69,15 +79,20 @@ fn albert_masked_lm() -> failure::Fallible<()> {
|
||||
#[test]
|
||||
fn albert_for_sequence_classification() -> failure::Fallible<()> {
|
||||
// Resources paths
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(AlbertConfigResources::ALBERT_BASE_V2));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(AlbertVocabResources::ALBERT_BASE_V2));
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
AlbertConfigResources::ALBERT_BASE_V2,
|
||||
));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
AlbertVocabResources::ALBERT_BASE_V2,
|
||||
));
|
||||
let config_path = download_resource(&config_resource)?;
|
||||
let vocab_path = download_resource(&vocab_resource)?;
|
||||
|
||||
// Set-up model
|
||||
let device = Device::Cpu;
|
||||
let vs = nn::VarStore::new(device);
|
||||
let tokenizer: AlbertTokenizer = AlbertTokenizer::from_file(vocab_path.to_str().unwrap(), true, false);
|
||||
let tokenizer: AlbertTokenizer =
|
||||
AlbertTokenizer::from_file(vocab_path.to_str().unwrap(), true, false);
|
||||
let mut config = AlbertConfig::from_file(config_path);
|
||||
let mut dummy_label_mapping = HashMap::new();
|
||||
dummy_label_mapping.insert(0, String::from("Positive"));
|
||||
@ -88,37 +103,42 @@ fn albert_for_sequence_classification() -> failure::Fallible<()> {
|
||||
config.output_hidden_states = Some(true);
|
||||
let albert_model = AlbertForSequenceClassification::new(&vs.root(), &config);
|
||||
|
||||
|
||||
// Define input
|
||||
let input = ["Looks like one thing is missing", "It\'s like comparing oranges to apples"];
|
||||
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
||||
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
|
||||
let tokenized_input = tokenized_input.
|
||||
iter().
|
||||
map(|input| input.token_ids.clone()).
|
||||
map(|mut input| {
|
||||
let input = [
|
||||
"Looks like one thing is missing",
|
||||
"It\'s like comparing oranges to apples",
|
||||
];
|
||||
let tokenized_input =
|
||||
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
||||
let max_len = tokenized_input
|
||||
.iter()
|
||||
.map(|input| input.token_ids.len())
|
||||
.max()
|
||||
.unwrap();
|
||||
let tokenized_input = tokenized_input
|
||||
.iter()
|
||||
.map(|input| input.token_ids.clone())
|
||||
.map(|mut input| {
|
||||
input.extend(vec![0; max_len - input.len()]);
|
||||
input
|
||||
}).
|
||||
map(|input|
|
||||
Tensor::of_slice(&(input))).
|
||||
collect::<Vec<_>>();
|
||||
})
|
||||
.map(|input| Tensor::of_slice(&(input)))
|
||||
.collect::<Vec<_>>();
|
||||
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
|
||||
|
||||
// Forward pass
|
||||
let (output, all_hidden_states, all_attentions) = no_grad(|| {
|
||||
albert_model
|
||||
.forward_t(Some(input_tensor),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
false)
|
||||
});
|
||||
let (output, all_hidden_states, all_attentions) =
|
||||
no_grad(|| albert_model.forward_t(Some(input_tensor), None, None, None, None, false));
|
||||
|
||||
assert_eq!(output.size(), &[2, 3]);
|
||||
assert_eq!(config.num_hidden_layers as usize, all_hidden_states.unwrap().len());
|
||||
assert_eq!(config.num_hidden_layers as usize, all_attentions.unwrap().len());
|
||||
assert_eq!(
|
||||
config.num_hidden_layers as usize,
|
||||
all_hidden_states.unwrap().len()
|
||||
);
|
||||
assert_eq!(
|
||||
config.num_hidden_layers as usize,
|
||||
all_attentions.unwrap().len()
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@ -126,50 +146,66 @@ fn albert_for_sequence_classification() -> failure::Fallible<()> {
|
||||
#[test]
|
||||
fn albert_for_multiple_choice() -> failure::Fallible<()> {
|
||||
// Resources paths
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(AlbertConfigResources::ALBERT_BASE_V2));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(AlbertVocabResources::ALBERT_BASE_V2));
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
AlbertConfigResources::ALBERT_BASE_V2,
|
||||
));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
AlbertVocabResources::ALBERT_BASE_V2,
|
||||
));
|
||||
let config_path = download_resource(&config_resource)?;
|
||||
let vocab_path = download_resource(&vocab_resource)?;
|
||||
|
||||
// Set-up model
|
||||
let device = Device::Cpu;
|
||||
let vs = nn::VarStore::new(device);
|
||||
let tokenizer: AlbertTokenizer = AlbertTokenizer::from_file(vocab_path.to_str().unwrap(), true, false);
|
||||
let tokenizer: AlbertTokenizer =
|
||||
AlbertTokenizer::from_file(vocab_path.to_str().unwrap(), true, false);
|
||||
let mut config = AlbertConfig::from_file(config_path);
|
||||
config.output_attentions = Some(true);
|
||||
config.output_hidden_states = Some(true);
|
||||
let albert_model = AlbertForMultipleChoice::new(&vs.root(), &config);
|
||||
|
||||
// Define input
|
||||
let input = ["Looks like one thing is missing", "It\'s like comparing oranges to apples"];
|
||||
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
||||
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
|
||||
let tokenized_input = tokenized_input.
|
||||
iter().
|
||||
map(|input| input.token_ids.clone()).
|
||||
map(|mut input| {
|
||||
let input = [
|
||||
"Looks like one thing is missing",
|
||||
"It\'s like comparing oranges to apples",
|
||||
];
|
||||
let tokenized_input =
|
||||
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
||||
let max_len = tokenized_input
|
||||
.iter()
|
||||
.map(|input| input.token_ids.len())
|
||||
.max()
|
||||
.unwrap();
|
||||
let tokenized_input = tokenized_input
|
||||
.iter()
|
||||
.map(|input| input.token_ids.clone())
|
||||
.map(|mut input| {
|
||||
input.extend(vec![0; max_len - input.len()]);
|
||||
input
|
||||
}).
|
||||
map(|input|
|
||||
Tensor::of_slice(&(input))).
|
||||
collect::<Vec<_>>();
|
||||
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device).unsqueeze(0);
|
||||
})
|
||||
.map(|input| Tensor::of_slice(&(input)))
|
||||
.collect::<Vec<_>>();
|
||||
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0)
|
||||
.to(device)
|
||||
.unsqueeze(0);
|
||||
|
||||
// Forward pass
|
||||
let (output, all_hidden_states, all_attentions) = no_grad(|| {
|
||||
albert_model
|
||||
.forward_t(Some(input_tensor),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
false).unwrap()
|
||||
.forward_t(Some(input_tensor), None, None, None, None, false)
|
||||
.unwrap()
|
||||
});
|
||||
|
||||
assert_eq!(output.size(), &[1, 2]);
|
||||
assert_eq!(config.num_hidden_layers as usize, all_hidden_states.unwrap().len());
|
||||
assert_eq!(config.num_hidden_layers as usize, all_attentions.unwrap().len());
|
||||
assert_eq!(
|
||||
config.num_hidden_layers as usize,
|
||||
all_hidden_states.unwrap().len()
|
||||
);
|
||||
assert_eq!(
|
||||
config.num_hidden_layers as usize,
|
||||
all_attentions.unwrap().len()
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@ -177,15 +213,20 @@ fn albert_for_multiple_choice() -> failure::Fallible<()> {
|
||||
#[test]
|
||||
fn albert_for_token_classification() -> failure::Fallible<()> {
|
||||
// Resources paths
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(AlbertConfigResources::ALBERT_BASE_V2));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(AlbertVocabResources::ALBERT_BASE_V2));
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
AlbertConfigResources::ALBERT_BASE_V2,
|
||||
));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
AlbertVocabResources::ALBERT_BASE_V2,
|
||||
));
|
||||
let config_path = download_resource(&config_resource)?;
|
||||
let vocab_path = download_resource(&vocab_resource)?;
|
||||
|
||||
// Set-up model
|
||||
let device = Device::Cpu;
|
||||
let vs = nn::VarStore::new(device);
|
||||
let tokenizer: AlbertTokenizer = AlbertTokenizer::from_file(vocab_path.to_str().unwrap(), true, false);
|
||||
let tokenizer: AlbertTokenizer =
|
||||
AlbertTokenizer::from_file(vocab_path.to_str().unwrap(), true, false);
|
||||
let mut config = AlbertConfig::from_file(config_path);
|
||||
let mut dummy_label_mapping = HashMap::new();
|
||||
dummy_label_mapping.insert(0, String::from("O"));
|
||||
@ -197,37 +238,42 @@ fn albert_for_token_classification() -> failure::Fallible<()> {
|
||||
config.output_hidden_states = Some(true);
|
||||
let bert_model = AlbertForTokenClassification::new(&vs.root(), &config);
|
||||
|
||||
|
||||
// Define input
|
||||
let input = ["Looks like one thing is missing", "It\'s like comparing oranges to apples"];
|
||||
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
||||
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
|
||||
let tokenized_input = tokenized_input.
|
||||
iter().
|
||||
map(|input| input.token_ids.clone()).
|
||||
map(|mut input| {
|
||||
let input = [
|
||||
"Looks like one thing is missing",
|
||||
"It\'s like comparing oranges to apples",
|
||||
];
|
||||
let tokenized_input =
|
||||
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
||||
let max_len = tokenized_input
|
||||
.iter()
|
||||
.map(|input| input.token_ids.len())
|
||||
.max()
|
||||
.unwrap();
|
||||
let tokenized_input = tokenized_input
|
||||
.iter()
|
||||
.map(|input| input.token_ids.clone())
|
||||
.map(|mut input| {
|
||||
input.extend(vec![0; max_len - input.len()]);
|
||||
input
|
||||
}).
|
||||
map(|input|
|
||||
Tensor::of_slice(&(input))).
|
||||
collect::<Vec<_>>();
|
||||
})
|
||||
.map(|input| Tensor::of_slice(&(input)))
|
||||
.collect::<Vec<_>>();
|
||||
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
|
||||
|
||||
// Forward pass
|
||||
let (output, all_hidden_states, all_attentions) = no_grad(|| {
|
||||
bert_model
|
||||
.forward_t(Some(input_tensor),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
false)
|
||||
});
|
||||
let (output, all_hidden_states, all_attentions) =
|
||||
no_grad(|| bert_model.forward_t(Some(input_tensor), None, None, None, None, false));
|
||||
|
||||
assert_eq!(output.size(), &[2, 12, 4]);
|
||||
assert_eq!(config.num_hidden_layers as usize, all_hidden_states.unwrap().len());
|
||||
assert_eq!(config.num_hidden_layers as usize, all_attentions.unwrap().len());
|
||||
assert_eq!(
|
||||
config.num_hidden_layers as usize,
|
||||
all_hidden_states.unwrap().len()
|
||||
);
|
||||
assert_eq!(
|
||||
config.num_hidden_layers as usize,
|
||||
all_attentions.unwrap().len()
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@ -235,51 +281,62 @@ fn albert_for_token_classification() -> failure::Fallible<()> {
|
||||
#[test]
|
||||
fn albert_for_question_answering() -> failure::Fallible<()> {
|
||||
// Resources paths
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(AlbertConfigResources::ALBERT_BASE_V2));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(AlbertVocabResources::ALBERT_BASE_V2));
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
AlbertConfigResources::ALBERT_BASE_V2,
|
||||
));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
AlbertVocabResources::ALBERT_BASE_V2,
|
||||
));
|
||||
let config_path = download_resource(&config_resource)?;
|
||||
let vocab_path = download_resource(&vocab_resource)?;
|
||||
|
||||
// Set-up model
|
||||
let device = Device::Cpu;
|
||||
let vs = nn::VarStore::new(device);
|
||||
let tokenizer: AlbertTokenizer = AlbertTokenizer::from_file(vocab_path.to_str().unwrap(), true, false);
|
||||
let tokenizer: AlbertTokenizer =
|
||||
AlbertTokenizer::from_file(vocab_path.to_str().unwrap(), true, false);
|
||||
let mut config = AlbertConfig::from_file(config_path);
|
||||
config.output_attentions = Some(true);
|
||||
config.output_hidden_states = Some(true);
|
||||
let albert_model = AlbertForQuestionAnswering::new(&vs.root(), &config);
|
||||
|
||||
// Define input
|
||||
let input = ["Looks like one thing is missing", "It\'s like comparing oranges to apples"];
|
||||
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
||||
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
|
||||
let tokenized_input = tokenized_input.
|
||||
iter().
|
||||
map(|input| input.token_ids.clone()).
|
||||
map(|mut input| {
|
||||
let input = [
|
||||
"Looks like one thing is missing",
|
||||
"It\'s like comparing oranges to apples",
|
||||
];
|
||||
let tokenized_input =
|
||||
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
||||
let max_len = tokenized_input
|
||||
.iter()
|
||||
.map(|input| input.token_ids.len())
|
||||
.max()
|
||||
.unwrap();
|
||||
let tokenized_input = tokenized_input
|
||||
.iter()
|
||||
.map(|input| input.token_ids.clone())
|
||||
.map(|mut input| {
|
||||
input.extend(vec![0; max_len - input.len()]);
|
||||
input
|
||||
}).
|
||||
map(|input|
|
||||
Tensor::of_slice(&(input))).
|
||||
collect::<Vec<_>>();
|
||||
})
|
||||
.map(|input| Tensor::of_slice(&(input)))
|
||||
.collect::<Vec<_>>();
|
||||
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
|
||||
|
||||
// Forward pass
|
||||
let (start_scores, end_scores, all_hidden_states, all_attentions) = no_grad(|| {
|
||||
albert_model
|
||||
.forward_t(Some(input_tensor),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
false)
|
||||
});
|
||||
let (start_scores, end_scores, all_hidden_states, all_attentions) =
|
||||
no_grad(|| albert_model.forward_t(Some(input_tensor), None, None, None, None, false));
|
||||
|
||||
assert_eq!(start_scores.size(), &[2, 12]);
|
||||
assert_eq!(end_scores.size(), &[2, 12]);
|
||||
assert_eq!(config.num_hidden_layers as usize, all_hidden_states.unwrap().len());
|
||||
assert_eq!(config.num_hidden_layers as usize, all_attentions.unwrap().len());
|
||||
assert_eq!(
|
||||
config.num_hidden_layers as usize,
|
||||
all_hidden_states.unwrap().len()
|
||||
);
|
||||
assert_eq!(
|
||||
config.num_hidden_layers as usize,
|
||||
all_attentions.unwrap().len()
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
@ -1,18 +1,25 @@
|
||||
use tch::{Device, nn, Tensor};
|
||||
use rust_tokenizers::{TruncationStrategy, Tokenizer, RobertaTokenizer};
|
||||
use rust_bert::Config;
|
||||
use rust_bert::bart::{BartConfig, BartConfigResources, BartVocabResources, BartModelResources, BartMergesResources, BartModel};
|
||||
use rust_bert::bart::{
|
||||
BartConfig, BartConfigResources, BartMergesResources, BartModel, BartModelResources,
|
||||
BartVocabResources,
|
||||
};
|
||||
use rust_bert::pipelines::summarization::{SummarizationConfig, SummarizationModel};
|
||||
use rust_bert::resources::{Resource, RemoteResource, download_resource};
|
||||
use rust_bert::resources::{download_resource, RemoteResource, Resource};
|
||||
use rust_bert::Config;
|
||||
use rust_tokenizers::{RobertaTokenizer, Tokenizer, TruncationStrategy};
|
||||
use tch::{nn, Device, Tensor};
|
||||
|
||||
#[test]
|
||||
#[cfg_attr(not(feature = "all-tests"), ignore)]
|
||||
fn bart_lm_model() -> failure::Fallible<()> {
|
||||
// Resources paths
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(BartConfigResources::BART));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(BartVocabResources::BART));
|
||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(BartMergesResources::BART));
|
||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(BartModelResources::BART));
|
||||
let config_resource =
|
||||
Resource::Remote(RemoteResource::from_pretrained(BartConfigResources::BART));
|
||||
let vocab_resource =
|
||||
Resource::Remote(RemoteResource::from_pretrained(BartVocabResources::BART));
|
||||
let merges_resource =
|
||||
Resource::Remote(RemoteResource::from_pretrained(BartMergesResources::BART));
|
||||
let weights_resource =
|
||||
Resource::Remote(RemoteResource::from_pretrained(BartModelResources::BART));
|
||||
let config_path = download_resource(&config_resource)?;
|
||||
let vocab_path = download_resource(&vocab_resource)?;
|
||||
let merges_path = download_resource(&merges_resource)?;
|
||||
@ -21,36 +28,38 @@ fn bart_lm_model() -> failure::Fallible<()> {
|
||||
// Set-up masked LM model
|
||||
let device = Device::Cpu;
|
||||
let mut vs = nn::VarStore::new(device);
|
||||
let tokenizer: RobertaTokenizer = RobertaTokenizer::from_file(vocab_path.to_str().unwrap(), merges_path.to_str().unwrap(), false);
|
||||
let tokenizer: RobertaTokenizer = RobertaTokenizer::from_file(
|
||||
vocab_path.to_str().unwrap(),
|
||||
merges_path.to_str().unwrap(),
|
||||
false,
|
||||
);
|
||||
let config = BartConfig::from_file(config_path);
|
||||
let bart_model = BartModel::new(&vs.root(), &config, false);
|
||||
vs.load(weights_path)?;
|
||||
|
||||
// Define input
|
||||
let input = ["One two three four"];
|
||||
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
||||
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
|
||||
let tokenized_input = tokenized_input.
|
||||
iter().
|
||||
map(|input| input.token_ids.clone()).
|
||||
map(|mut input| {
|
||||
let tokenized_input =
|
||||
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
||||
let max_len = tokenized_input
|
||||
.iter()
|
||||
.map(|input| input.token_ids.len())
|
||||
.max()
|
||||
.unwrap();
|
||||
let tokenized_input = tokenized_input
|
||||
.iter()
|
||||
.map(|input| input.token_ids.clone())
|
||||
.map(|mut input| {
|
||||
input.extend(vec![0; max_len - input.len()]);
|
||||
input
|
||||
}).
|
||||
map(|input|
|
||||
Tensor::of_slice(&(input))).
|
||||
collect::<Vec<_>>();
|
||||
})
|
||||
.map(|input| Tensor::of_slice(&(input)))
|
||||
.collect::<Vec<_>>();
|
||||
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
|
||||
|
||||
// Forward pass
|
||||
let (output, encoder_outputs, _, _, _, _, _) = bart_model.forward_t(
|
||||
Some(&input_tensor),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
false);
|
||||
let (output, encoder_outputs, _, _, _, _, _) =
|
||||
bart_model.forward_t(Some(&input_tensor), None, None, None, None, None, false);
|
||||
|
||||
assert_eq!(output.size(), vec!(1, 6, 1024));
|
||||
assert_eq!(encoder_outputs.size(), vec!(1, 6, 1024));
|
||||
@ -58,11 +67,9 @@ fn bart_lm_model() -> failure::Fallible<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
#[cfg_attr(not(feature = "all-tests"), ignore)]
|
||||
fn bart_summarization_greedy() -> failure::Fallible<()> {
|
||||
|
||||
// Set-up masked LM model
|
||||
let summarization_config = SummarizationConfig {
|
||||
num_beams: 1,
|
||||
@ -107,7 +114,6 @@ about exoplanets like K2-18b."];
|
||||
#[test]
|
||||
#[cfg_attr(not(feature = "all-tests"), ignore)]
|
||||
fn bart_summarization_beam_search() -> failure::Fallible<()> {
|
||||
|
||||
// Set-up masked LM model
|
||||
let summarization_config = SummarizationConfig {
|
||||
num_beams: 3,
|
||||
|
300
tests/bert.rs
300
tests/bert.rs
@ -1,22 +1,27 @@
|
||||
extern crate failure;
|
||||
extern crate dirs;
|
||||
extern crate failure;
|
||||
|
||||
use tch::{Device, nn, Tensor, no_grad};
|
||||
use rust_tokenizers::{BertTokenizer, TruncationStrategy, Tokenizer, Vocab};
|
||||
use rust_bert::Config;
|
||||
use rust_bert::bert::{BertConfig, BertForMaskedLM, BertForSequenceClassification, BertForMultipleChoice, BertForTokenClassification, BertForQuestionAnswering,
|
||||
BertConfigResources, BertVocabResources, BertModelResources};
|
||||
use rust_bert::bert::{
|
||||
BertConfig, BertConfigResources, BertForMaskedLM, BertForMultipleChoice,
|
||||
BertForQuestionAnswering, BertForSequenceClassification, BertForTokenClassification,
|
||||
BertModelResources, BertVocabResources,
|
||||
};
|
||||
use rust_bert::pipelines::ner::NERModel;
|
||||
use rust_bert::resources::{Resource, RemoteResource, download_resource};
|
||||
use rust_bert::resources::{download_resource, RemoteResource, Resource};
|
||||
use rust_bert::Config;
|
||||
use rust_tokenizers::{BertTokenizer, Tokenizer, TruncationStrategy, Vocab};
|
||||
use std::collections::HashMap;
|
||||
|
||||
use tch::{nn, no_grad, Device, Tensor};
|
||||
|
||||
#[test]
|
||||
fn bert_masked_lm() -> failure::Fallible<()> {
|
||||
// Resources paths
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
|
||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT));
|
||||
let config_resource =
|
||||
Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT));
|
||||
let vocab_resource =
|
||||
Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
|
||||
let weights_resource =
|
||||
Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT));
|
||||
let config_path = download_resource(&config_resource)?;
|
||||
let vocab_path = download_resource(&vocab_resource)?;
|
||||
let weights_path = download_resource(&weights_resource)?;
|
||||
@ -30,39 +35,47 @@ fn bert_masked_lm() -> failure::Fallible<()> {
|
||||
vs.load(weights_path)?;
|
||||
|
||||
// Define input
|
||||
let input = ["Looks like one thing is missing", "It\'s like comparing oranges to apples"];
|
||||
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
||||
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
|
||||
let mut tokenized_input = tokenized_input.
|
||||
iter().
|
||||
map(|input| input.token_ids.clone()).
|
||||
map(|mut input| {
|
||||
let input = [
|
||||
"Looks like one thing is missing",
|
||||
"It\'s like comparing oranges to apples",
|
||||
];
|
||||
let tokenized_input =
|
||||
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
||||
let max_len = tokenized_input
|
||||
.iter()
|
||||
.map(|input| input.token_ids.len())
|
||||
.max()
|
||||
.unwrap();
|
||||
let mut tokenized_input = tokenized_input
|
||||
.iter()
|
||||
.map(|input| input.token_ids.clone())
|
||||
.map(|mut input| {
|
||||
input.extend(vec![0; max_len - input.len()]);
|
||||
input
|
||||
}).
|
||||
collect::<Vec<_>>();
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
// Masking the token [thing] of sentence 1 and [oranges] of sentence 2
|
||||
tokenized_input[0][4] = 103;
|
||||
tokenized_input[1][6] = 103;
|
||||
let tokenized_input = tokenized_input.
|
||||
iter().
|
||||
map(|input|
|
||||
Tensor::of_slice(&(input))).
|
||||
collect::<Vec<_>>();
|
||||
let tokenized_input = tokenized_input
|
||||
.iter()
|
||||
.map(|input| Tensor::of_slice(&(input)))
|
||||
.collect::<Vec<_>>();
|
||||
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
|
||||
|
||||
// Forward pass
|
||||
let (output, _, _) = no_grad(|| {
|
||||
bert_model
|
||||
.forward_t(Some(input_tensor),
|
||||
bert_model.forward_t(
|
||||
Some(input_tensor),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
&None,
|
||||
&None,
|
||||
false)
|
||||
false,
|
||||
)
|
||||
});
|
||||
|
||||
// Print masked tokens
|
||||
@ -80,8 +93,10 @@ fn bert_masked_lm() -> failure::Fallible<()> {
|
||||
#[test]
|
||||
fn bert_for_sequence_classification() -> failure::Fallible<()> {
|
||||
// Resources paths
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
|
||||
let config_resource =
|
||||
Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT));
|
||||
let vocab_resource =
|
||||
Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
|
||||
let config_path = download_resource(&config_resource)?;
|
||||
let vocab_path = download_resource(&vocab_resource)?;
|
||||
|
||||
@ -99,37 +114,42 @@ fn bert_for_sequence_classification() -> failure::Fallible<()> {
|
||||
config.output_hidden_states = Some(true);
|
||||
let bert_model = BertForSequenceClassification::new(&vs.root(), &config);
|
||||
|
||||
|
||||
// Define input
|
||||
let input = ["Looks like one thing is missing", "It\'s like comparing oranges to apples"];
|
||||
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
||||
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
|
||||
let tokenized_input = tokenized_input.
|
||||
iter().
|
||||
map(|input| input.token_ids.clone()).
|
||||
map(|mut input| {
|
||||
let input = [
|
||||
"Looks like one thing is missing",
|
||||
"It\'s like comparing oranges to apples",
|
||||
];
|
||||
let tokenized_input =
|
||||
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
||||
let max_len = tokenized_input
|
||||
.iter()
|
||||
.map(|input| input.token_ids.len())
|
||||
.max()
|
||||
.unwrap();
|
||||
let tokenized_input = tokenized_input
|
||||
.iter()
|
||||
.map(|input| input.token_ids.clone())
|
||||
.map(|mut input| {
|
||||
input.extend(vec![0; max_len - input.len()]);
|
||||
input
|
||||
}).
|
||||
map(|input|
|
||||
Tensor::of_slice(&(input))).
|
||||
collect::<Vec<_>>();
|
||||
})
|
||||
.map(|input| Tensor::of_slice(&(input)))
|
||||
.collect::<Vec<_>>();
|
||||
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
|
||||
|
||||
// Forward pass
|
||||
let (output, all_hidden_states, all_attentions) = no_grad(|| {
|
||||
bert_model
|
||||
.forward_t(Some(input_tensor),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
false)
|
||||
});
|
||||
let (output, all_hidden_states, all_attentions) =
|
||||
no_grad(|| bert_model.forward_t(Some(input_tensor), None, None, None, None, false));
|
||||
|
||||
assert_eq!(output.size(), &[2, 3]);
|
||||
assert_eq!(config.num_hidden_layers as usize, all_hidden_states.unwrap().len());
|
||||
assert_eq!(config.num_hidden_layers as usize, all_attentions.unwrap().len());
|
||||
assert_eq!(
|
||||
config.num_hidden_layers as usize,
|
||||
all_hidden_states.unwrap().len()
|
||||
);
|
||||
assert_eq!(
|
||||
config.num_hidden_layers as usize,
|
||||
all_attentions.unwrap().len()
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@ -137,8 +157,10 @@ fn bert_for_sequence_classification() -> failure::Fallible<()> {
|
||||
#[test]
|
||||
fn bert_for_multiple_choice() -> failure::Fallible<()> {
|
||||
// Resources paths
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
|
||||
let config_resource =
|
||||
Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT));
|
||||
let vocab_resource =
|
||||
Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
|
||||
let config_path = download_resource(&config_resource)?;
|
||||
let vocab_path = download_resource(&vocab_resource)?;
|
||||
|
||||
@ -152,34 +174,43 @@ fn bert_for_multiple_choice() -> failure::Fallible<()> {
|
||||
let bert_model = BertForMultipleChoice::new(&vs.root(), &config);
|
||||
|
||||
// Define input
|
||||
let input = ["Looks like one thing is missing", "It\'s like comparing oranges to apples"];
|
||||
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
||||
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
|
||||
let tokenized_input = tokenized_input.
|
||||
iter().
|
||||
map(|input| input.token_ids.clone()).
|
||||
map(|mut input| {
|
||||
let input = [
|
||||
"Looks like one thing is missing",
|
||||
"It\'s like comparing oranges to apples",
|
||||
];
|
||||
let tokenized_input =
|
||||
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
||||
let max_len = tokenized_input
|
||||
.iter()
|
||||
.map(|input| input.token_ids.len())
|
||||
.max()
|
||||
.unwrap();
|
||||
let tokenized_input = tokenized_input
|
||||
.iter()
|
||||
.map(|input| input.token_ids.clone())
|
||||
.map(|mut input| {
|
||||
input.extend(vec![0; max_len - input.len()]);
|
||||
input
|
||||
}).
|
||||
map(|input|
|
||||
Tensor::of_slice(&(input))).
|
||||
collect::<Vec<_>>();
|
||||
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device).unsqueeze(0);
|
||||
})
|
||||
.map(|input| Tensor::of_slice(&(input)))
|
||||
.collect::<Vec<_>>();
|
||||
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0)
|
||||
.to(device)
|
||||
.unsqueeze(0);
|
||||
|
||||
// Forward pass
|
||||
let (output, all_hidden_states, all_attentions) = no_grad(|| {
|
||||
bert_model
|
||||
.forward_t(input_tensor,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
false)
|
||||
});
|
||||
let (output, all_hidden_states, all_attentions) =
|
||||
no_grad(|| bert_model.forward_t(input_tensor, None, None, None, false));
|
||||
|
||||
assert_eq!(output.size(), &[1, 2]);
|
||||
assert_eq!(config.num_hidden_layers as usize, all_hidden_states.unwrap().len());
|
||||
assert_eq!(config.num_hidden_layers as usize, all_attentions.unwrap().len());
|
||||
assert_eq!(
|
||||
config.num_hidden_layers as usize,
|
||||
all_hidden_states.unwrap().len()
|
||||
);
|
||||
assert_eq!(
|
||||
config.num_hidden_layers as usize,
|
||||
all_attentions.unwrap().len()
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@ -187,8 +218,10 @@ fn bert_for_multiple_choice() -> failure::Fallible<()> {
|
||||
#[test]
|
||||
fn bert_for_token_classification() -> failure::Fallible<()> {
|
||||
// Resources paths
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
|
||||
let config_resource =
|
||||
Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT));
|
||||
let vocab_resource =
|
||||
Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
|
||||
let config_path = download_resource(&config_resource)?;
|
||||
let vocab_path = download_resource(&vocab_resource)?;
|
||||
|
||||
@ -207,37 +240,42 @@ fn bert_for_token_classification() -> failure::Fallible<()> {
|
||||
config.output_hidden_states = Some(true);
|
||||
let bert_model = BertForTokenClassification::new(&vs.root(), &config);
|
||||
|
||||
|
||||
// Define input
|
||||
let input = ["Looks like one thing is missing", "It\'s like comparing oranges to apples"];
|
||||
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
||||
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
|
||||
let tokenized_input = tokenized_input.
|
||||
iter().
|
||||
map(|input| input.token_ids.clone()).
|
||||
map(|mut input| {
|
||||
let input = [
|
||||
"Looks like one thing is missing",
|
||||
"It\'s like comparing oranges to apples",
|
||||
];
|
||||
let tokenized_input =
|
||||
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
||||
let max_len = tokenized_input
|
||||
.iter()
|
||||
.map(|input| input.token_ids.len())
|
||||
.max()
|
||||
.unwrap();
|
||||
let tokenized_input = tokenized_input
|
||||
.iter()
|
||||
.map(|input| input.token_ids.clone())
|
||||
.map(|mut input| {
|
||||
input.extend(vec![0; max_len - input.len()]);
|
||||
input
|
||||
}).
|
||||
map(|input|
|
||||
Tensor::of_slice(&(input))).
|
||||
collect::<Vec<_>>();
|
||||
})
|
||||
.map(|input| Tensor::of_slice(&(input)))
|
||||
.collect::<Vec<_>>();
|
||||
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
|
||||
|
||||
// Forward pass
|
||||
let (output, all_hidden_states, all_attentions) = no_grad(|| {
|
||||
bert_model
|
||||
.forward_t(Some(input_tensor),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
false)
|
||||
});
|
||||
let (output, all_hidden_states, all_attentions) =
|
||||
no_grad(|| bert_model.forward_t(Some(input_tensor), None, None, None, None, false));
|
||||
|
||||
assert_eq!(output.size(), &[2, 11, 4]);
|
||||
assert_eq!(config.num_hidden_layers as usize, all_hidden_states.unwrap().len());
|
||||
assert_eq!(config.num_hidden_layers as usize, all_attentions.unwrap().len());
|
||||
assert_eq!(
|
||||
config.num_hidden_layers as usize,
|
||||
all_hidden_states.unwrap().len()
|
||||
);
|
||||
assert_eq!(
|
||||
config.num_hidden_layers as usize,
|
||||
all_attentions.unwrap().len()
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@ -245,8 +283,10 @@ fn bert_for_token_classification() -> failure::Fallible<()> {
|
||||
#[test]
|
||||
fn bert_for_question_answering() -> failure::Fallible<()> {
|
||||
// Resources paths
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
|
||||
let config_resource =
|
||||
Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT));
|
||||
let vocab_resource =
|
||||
Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
|
||||
let config_path = download_resource(&config_resource)?;
|
||||
let vocab_path = download_resource(&vocab_resource)?;
|
||||
|
||||
@ -260,36 +300,42 @@ fn bert_for_question_answering() -> failure::Fallible<()> {
|
||||
let bert_model = BertForQuestionAnswering::new(&vs.root(), &config);
|
||||
|
||||
// Define input
|
||||
let input = ["Looks like one thing is missing", "It\'s like comparing oranges to apples"];
|
||||
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
||||
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
|
||||
let tokenized_input = tokenized_input.
|
||||
iter().
|
||||
map(|input| input.token_ids.clone()).
|
||||
map(|mut input| {
|
||||
let input = [
|
||||
"Looks like one thing is missing",
|
||||
"It\'s like comparing oranges to apples",
|
||||
];
|
||||
let tokenized_input =
|
||||
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
||||
let max_len = tokenized_input
|
||||
.iter()
|
||||
.map(|input| input.token_ids.len())
|
||||
.max()
|
||||
.unwrap();
|
||||
let tokenized_input = tokenized_input
|
||||
.iter()
|
||||
.map(|input| input.token_ids.clone())
|
||||
.map(|mut input| {
|
||||
input.extend(vec![0; max_len - input.len()]);
|
||||
input
|
||||
}).
|
||||
map(|input|
|
||||
Tensor::of_slice(&(input))).
|
||||
collect::<Vec<_>>();
|
||||
})
|
||||
.map(|input| Tensor::of_slice(&(input)))
|
||||
.collect::<Vec<_>>();
|
||||
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
|
||||
|
||||
// Forward pass
|
||||
let (start_scores, end_scores, all_hidden_states, all_attentions) = no_grad(|| {
|
||||
bert_model
|
||||
.forward_t(Some(input_tensor),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
false)
|
||||
});
|
||||
let (start_scores, end_scores, all_hidden_states, all_attentions) =
|
||||
no_grad(|| bert_model.forward_t(Some(input_tensor), None, None, None, None, false));
|
||||
|
||||
assert_eq!(start_scores.size(), &[2, 11]);
|
||||
assert_eq!(end_scores.size(), &[2, 11]);
|
||||
assert_eq!(config.num_hidden_layers as usize, all_hidden_states.unwrap().len());
|
||||
assert_eq!(config.num_hidden_layers as usize, all_attentions.unwrap().len());
|
||||
assert_eq!(
|
||||
config.num_hidden_layers as usize,
|
||||
all_hidden_states.unwrap().len()
|
||||
);
|
||||
assert_eq!(
|
||||
config.num_hidden_layers as usize,
|
||||
all_attentions.unwrap().len()
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@ -302,7 +348,7 @@ fn bert_pre_trained_ner() -> failure::Fallible<()> {
|
||||
// Define input
|
||||
let input = [
|
||||
"My name is Amy. I live in Paris.",
|
||||
"Paris is a city in France."
|
||||
"Paris is a city in France.",
|
||||
];
|
||||
|
||||
// Run model
|
||||
|
@ -1,13 +1,17 @@
|
||||
use tch::{Device, Tensor, nn, no_grad};
|
||||
use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::{Tokenizer, TruncationStrategy};
|
||||
use rust_tokenizers::bert_tokenizer::BertTokenizer;
|
||||
use rust_tokenizers::preprocessing::vocab::base_vocab::Vocab;
|
||||
use rust_bert::Config;
|
||||
use rust_bert::distilbert::{DistilBertConfig, DistilBertModelMaskedLM, DistilBertForQuestionAnswering, DistilBertForTokenClassification, DistilBertModelResources, DistilBertConfigResources, DistilBertVocabResources};
|
||||
use rust_bert::distilbert::{
|
||||
DistilBertConfig, DistilBertConfigResources, DistilBertForQuestionAnswering,
|
||||
DistilBertForTokenClassification, DistilBertModelMaskedLM, DistilBertModelResources,
|
||||
DistilBertVocabResources,
|
||||
};
|
||||
use rust_bert::pipelines::question_answering::{QaInput, QuestionAnsweringModel};
|
||||
use rust_bert::pipelines::sentiment::{SentimentModel, SentimentPolarity};
|
||||
use rust_bert::pipelines::question_answering::{QuestionAnsweringModel, QaInput};
|
||||
use rust_bert::resources::{Resource, RemoteResource, download_resource};
|
||||
use rust_bert::resources::{download_resource, RemoteResource, Resource};
|
||||
use rust_bert::Config;
|
||||
use rust_tokenizers::bert_tokenizer::BertTokenizer;
|
||||
use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::{Tokenizer, TruncationStrategy};
|
||||
use rust_tokenizers::preprocessing::vocab::base_vocab::Vocab;
|
||||
use std::collections::HashMap;
|
||||
use tch::{nn, no_grad, Device, Tensor};
|
||||
|
||||
extern crate failure;
|
||||
|
||||
@ -36,13 +40,18 @@ fn distilbert_sentiment_classifier() -> failure::Fallible<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn distilbert_masked_lm() -> failure::Fallible<()> {
|
||||
// Resources paths
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertConfigResources::DISTIL_BERT));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertVocabResources::DISTIL_BERT));
|
||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertModelResources::DISTIL_BERT));
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
DistilBertConfigResources::DISTIL_BERT,
|
||||
));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
DistilBertVocabResources::DISTIL_BERT,
|
||||
));
|
||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
DistilBertModelResources::DISTIL_BERT,
|
||||
));
|
||||
let config_path = download_resource(&config_resource)?;
|
||||
let vocab_path = download_resource(&vocab_resource)?;
|
||||
let weights_path = download_resource(&weights_resource)?;
|
||||
@ -56,29 +65,35 @@ fn distilbert_masked_lm() -> failure::Fallible<()> {
|
||||
vs.load(weights_path)?;
|
||||
|
||||
// Define input
|
||||
let input = ["Looks like one thing is missing", "It\'s like comparing oranges to apples"];
|
||||
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
||||
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
|
||||
let mut tokenized_input = tokenized_input.
|
||||
iter().
|
||||
map(|input| input.token_ids.clone()).
|
||||
map(|mut input| {
|
||||
let input = [
|
||||
"Looks like one thing is missing",
|
||||
"It\'s like comparing oranges to apples",
|
||||
];
|
||||
let tokenized_input =
|
||||
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
||||
let max_len = tokenized_input
|
||||
.iter()
|
||||
.map(|input| input.token_ids.len())
|
||||
.max()
|
||||
.unwrap();
|
||||
let mut tokenized_input = tokenized_input
|
||||
.iter()
|
||||
.map(|input| input.token_ids.clone())
|
||||
.map(|mut input| {
|
||||
input.extend(vec![0; max_len - input.len()]);
|
||||
input
|
||||
}).
|
||||
collect::<Vec<_>>();
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
// Masking the token [thing] of sentence 1 and [oranges] of sentence 2
|
||||
tokenized_input[0][4] = 103;
|
||||
tokenized_input[1][6] = 103;
|
||||
let tokenized_input = tokenized_input.
|
||||
iter().
|
||||
map(|input|
|
||||
Tensor::of_slice(&(input))).
|
||||
collect::<Vec<_>>();
|
||||
let tokenized_input = tokenized_input
|
||||
.iter()
|
||||
.map(|input| Tensor::of_slice(&(input)))
|
||||
.collect::<Vec<_>>();
|
||||
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
|
||||
|
||||
|
||||
// Forward pass
|
||||
let (output, _, _) = no_grad(|| {
|
||||
distil_bert_model
|
||||
@ -100,10 +115,13 @@ fn distilbert_masked_lm() -> failure::Fallible<()> {
|
||||
|
||||
#[test]
|
||||
fn distilbert_for_question_answering() -> failure::Fallible<()> {
|
||||
|
||||
// Resources paths
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertConfigResources::DISTIL_BERT_SQUAD));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertVocabResources::DISTIL_BERT_SQUAD));
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
DistilBertConfigResources::DISTIL_BERT_SQUAD,
|
||||
));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
DistilBertVocabResources::DISTIL_BERT_SQUAD,
|
||||
));
|
||||
let config_path = download_resource(&config_resource)?;
|
||||
let vocab_path = download_resource(&vocab_resource)?;
|
||||
|
||||
@ -117,19 +135,26 @@ fn distilbert_for_question_answering() -> failure::Fallible<()> {
|
||||
let distil_bert_model = DistilBertForQuestionAnswering::new(&vs.root(), &config);
|
||||
|
||||
// Define input
|
||||
let input = ["Looks like one thing is missing", "It\'s like comparing oranges to apples"];
|
||||
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
||||
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
|
||||
let tokenized_input = tokenized_input.
|
||||
iter().
|
||||
map(|input| input.token_ids.clone()).
|
||||
map(|mut input| {
|
||||
let input = [
|
||||
"Looks like one thing is missing",
|
||||
"It\'s like comparing oranges to apples",
|
||||
];
|
||||
let tokenized_input =
|
||||
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
||||
let max_len = tokenized_input
|
||||
.iter()
|
||||
.map(|input| input.token_ids.len())
|
||||
.max()
|
||||
.unwrap();
|
||||
let tokenized_input = tokenized_input
|
||||
.iter()
|
||||
.map(|input| input.token_ids.clone())
|
||||
.map(|mut input| {
|
||||
input.extend(vec![0; max_len - input.len()]);
|
||||
input
|
||||
}).
|
||||
map(|input|
|
||||
Tensor::of_slice(&(input))).
|
||||
collect::<Vec<_>>();
|
||||
})
|
||||
.map(|input| Tensor::of_slice(&(input)))
|
||||
.collect::<Vec<_>>();
|
||||
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
|
||||
|
||||
// Forward pass
|
||||
@ -149,10 +174,13 @@ fn distilbert_for_question_answering() -> failure::Fallible<()> {
|
||||
|
||||
#[test]
|
||||
fn distilbert_for_token_classification() -> failure::Fallible<()> {
|
||||
|
||||
// Resources paths
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertConfigResources::DISTIL_BERT));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertVocabResources::DISTIL_BERT));
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
DistilBertConfigResources::DISTIL_BERT,
|
||||
));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
DistilBertVocabResources::DISTIL_BERT,
|
||||
));
|
||||
let config_path = download_resource(&config_resource)?;
|
||||
let vocab_path = download_resource(&vocab_resource)?;
|
||||
|
||||
@ -172,19 +200,26 @@ fn distilbert_for_token_classification() -> failure::Fallible<()> {
|
||||
let distil_bert_model = DistilBertForTokenClassification::new(&vs.root(), &config);
|
||||
|
||||
// Define input
|
||||
let input = ["Looks like one thing is missing", "It\'s like comparing oranges to apples"];
|
||||
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
||||
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
|
||||
let tokenized_input = tokenized_input.
|
||||
iter().
|
||||
map(|input| input.token_ids.clone()).
|
||||
map(|mut input| {
|
||||
let input = [
|
||||
"Looks like one thing is missing",
|
||||
"It\'s like comparing oranges to apples",
|
||||
];
|
||||
let tokenized_input =
|
||||
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
||||
let max_len = tokenized_input
|
||||
.iter()
|
||||
.map(|input| input.token_ids.len())
|
||||
.max()
|
||||
.unwrap();
|
||||
let tokenized_input = tokenized_input
|
||||
.iter()
|
||||
.map(|input| input.token_ids.clone())
|
||||
.map(|mut input| {
|
||||
input.extend(vec![0; max_len - input.len()]);
|
||||
input
|
||||
}).
|
||||
map(|input|
|
||||
Tensor::of_slice(&(input))).
|
||||
collect::<Vec<_>>();
|
||||
})
|
||||
.map(|input| Tensor::of_slice(&(input)))
|
||||
.collect::<Vec<_>>();
|
||||
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
|
||||
|
||||
// Forward pass
|
||||
@ -211,7 +246,7 @@ fn distilbert_question_answering() -> failure::Fallible<()> {
|
||||
let context = String::from("Amy lives in Amsterdam");
|
||||
let qa_input = QaInput { question, context };
|
||||
|
||||
let answers = qa_model.predict(&vec!(qa_input), 1, 32);
|
||||
let answers = qa_model.predict(&vec![qa_input], 1, 32);
|
||||
|
||||
assert_eq!(answers.len(), 1 as usize);
|
||||
assert_eq!(answers[0].len(), 1 as usize);
|
||||
|
@ -1,17 +1,28 @@
|
||||
use tch::{Device, nn, Tensor};
|
||||
use rust_tokenizers::{Gpt2Tokenizer, TruncationStrategy, Tokenizer};
|
||||
use rust_bert::gpt2::{
|
||||
GPT2LMHeadModel, Gpt2Config, Gpt2ConfigResources, Gpt2MergesResources, Gpt2ModelResources,
|
||||
Gpt2VocabResources,
|
||||
};
|
||||
use rust_bert::pipelines::generation::{Cache, LMHeadModel};
|
||||
use rust_bert::resources::{download_resource, RemoteResource, Resource};
|
||||
use rust_bert::Config;
|
||||
use rust_bert::gpt2::{Gpt2Config, GPT2LMHeadModel, Gpt2ConfigResources, Gpt2VocabResources, Gpt2MergesResources, Gpt2ModelResources};
|
||||
use rust_bert::pipelines::generation::{LMHeadModel, Cache};
|
||||
use rust_bert::resources::{Resource, download_resource, RemoteResource};
|
||||
use rust_tokenizers::{Gpt2Tokenizer, Tokenizer, TruncationStrategy};
|
||||
use tch::{nn, Device, Tensor};
|
||||
|
||||
#[test]
|
||||
fn distilgpt2_lm_model() -> failure::Fallible<()> {
|
||||
// Resources paths
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::DISTIL_GPT2));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::DISTIL_GPT2));
|
||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::DISTIL_GPT2));
|
||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::DISTIL_GPT2));
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
Gpt2ConfigResources::DISTIL_GPT2,
|
||||
));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
Gpt2VocabResources::DISTIL_GPT2,
|
||||
));
|
||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
Gpt2MergesResources::DISTIL_GPT2,
|
||||
));
|
||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
Gpt2ModelResources::DISTIL_GPT2,
|
||||
));
|
||||
let config_path = download_resource(&config_resource)?;
|
||||
let vocab_path = download_resource(&vocab_resource)?;
|
||||
let merges_path = download_resource(&merges_resource)?;
|
||||
@ -20,29 +31,38 @@ fn distilgpt2_lm_model() -> failure::Fallible<()> {
|
||||
// Set-up GPT2 language model (LM head)
|
||||
let device = Device::Cpu;
|
||||
let mut vs = nn::VarStore::new(device);
|
||||
let tokenizer: Gpt2Tokenizer = Gpt2Tokenizer::from_file(vocab_path.to_str().unwrap(), merges_path.to_str().unwrap(), false);
|
||||
let tokenizer: Gpt2Tokenizer = Gpt2Tokenizer::from_file(
|
||||
vocab_path.to_str().unwrap(),
|
||||
merges_path.to_str().unwrap(),
|
||||
false,
|
||||
);
|
||||
let config = Gpt2Config::from_file(config_path);
|
||||
let gpt2_model = GPT2LMHeadModel::new(&vs.root(), &config);
|
||||
vs.load(weights_path)?;
|
||||
|
||||
// Define input
|
||||
let input = ["One two three four five six seven eight nine ten eleven"];
|
||||
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
||||
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
|
||||
let tokenized_input = tokenized_input.
|
||||
iter().
|
||||
map(|input| input.token_ids.clone()).
|
||||
map(|mut input| {
|
||||
let tokenized_input =
|
||||
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
||||
let max_len = tokenized_input
|
||||
.iter()
|
||||
.map(|input| input.token_ids.len())
|
||||
.max()
|
||||
.unwrap();
|
||||
let tokenized_input = tokenized_input
|
||||
.iter()
|
||||
.map(|input| input.token_ids.clone())
|
||||
.map(|mut input| {
|
||||
input.extend(vec![0; max_len - input.len()]);
|
||||
input
|
||||
}).
|
||||
map(|input|
|
||||
Tensor::of_slice(&(input))).
|
||||
collect::<Vec<_>>();
|
||||
})
|
||||
.map(|input| Tensor::of_slice(&(input)))
|
||||
.collect::<Vec<_>>();
|
||||
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
|
||||
|
||||
// Forward pass
|
||||
let (output, _, past, _, _) = gpt2_model.forward_t(
|
||||
let (output, _, past, _, _) = gpt2_model
|
||||
.forward_t(
|
||||
&Some(input_tensor),
|
||||
Cache::None,
|
||||
&None,
|
||||
@ -51,21 +71,28 @@ fn distilgpt2_lm_model() -> failure::Fallible<()> {
|
||||
&None,
|
||||
None,
|
||||
&None,
|
||||
false).unwrap();
|
||||
false,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let next_word_id = output.get(0).get(-1).argmax(-1, true).int64_value(&[0]);
|
||||
let next_word = tokenizer.decode(vec!(next_word_id), true, true);
|
||||
let next_word = tokenizer.decode(vec![next_word_id], true, true);
|
||||
|
||||
assert_eq!(output.size(), vec!(1, 11, 50257));
|
||||
match past {
|
||||
Cache::GPT2Cache(past) => {
|
||||
assert!(past.is_some());
|
||||
assert_eq!(past.as_ref().unwrap().len(), config.n_layer as usize);
|
||||
assert_eq!(past.as_ref().unwrap()[0].size(), vec!(2, 1, config.n_head, 11, 64));
|
||||
assert_eq!(
|
||||
past.as_ref().unwrap()[0].size(),
|
||||
vec!(2, 1, config.n_head, 11, 64)
|
||||
);
|
||||
}
|
||||
_ => panic!("Wrong cache returned for GPT2")
|
||||
_ => panic!("Wrong cache returned for GPT2"),
|
||||
}
|
||||
assert!((output.double_value(&[0, output.size()[1] - 1, next_word_id]) - (-48.7065)).abs() < 1e-4);
|
||||
assert!(
|
||||
(output.double_value(&[0, output.size()[1] - 1, next_word_id]) - (-48.7065)).abs() < 1e-4
|
||||
);
|
||||
assert_eq!(next_word_id, 14104i64);
|
||||
assert_eq!(next_word, String::from(" twelve"));
|
||||
|
||||
|
126
tests/electra.rs
126
tests/electra.rs
@ -1,15 +1,24 @@
|
||||
use rust_bert::resources::{Resource, download_resource, RemoteResource};
|
||||
use rust_bert::electra::{ElectraConfigResources, ElectraVocabResources, ElectraModelResources, ElectraConfig, ElectraForMaskedLM, ElectraDiscriminator};
|
||||
use tch::{Device, nn, Tensor, no_grad};
|
||||
use rust_tokenizers::{BertTokenizer, TruncationStrategy, Tokenizer, Vocab};
|
||||
use rust_bert::electra::{
|
||||
ElectraConfig, ElectraConfigResources, ElectraDiscriminator, ElectraForMaskedLM,
|
||||
ElectraModelResources, ElectraVocabResources,
|
||||
};
|
||||
use rust_bert::resources::{download_resource, RemoteResource, Resource};
|
||||
use rust_bert::Config;
|
||||
use rust_tokenizers::{BertTokenizer, Tokenizer, TruncationStrategy, Vocab};
|
||||
use tch::{nn, no_grad, Device, Tensor};
|
||||
|
||||
#[test]
|
||||
fn electra_masked_lm() -> failure::Fallible<()> {
|
||||
// Resources paths
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(ElectraConfigResources::BASE_GENERATOR));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(ElectraVocabResources::BASE_GENERATOR));
|
||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(ElectraModelResources::BASE_GENERATOR));
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
ElectraConfigResources::BASE_GENERATOR,
|
||||
));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
ElectraVocabResources::BASE_GENERATOR,
|
||||
));
|
||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
ElectraModelResources::BASE_GENERATOR,
|
||||
));
|
||||
let config_path = download_resource(&config_resource)?;
|
||||
let vocab_path = download_resource(&vocab_resource)?;
|
||||
let weights_path = download_resource(&weights_resource)?;
|
||||
@ -25,33 +34,31 @@ fn electra_masked_lm() -> failure::Fallible<()> {
|
||||
vs.load(weights_path)?;
|
||||
|
||||
// Define input
|
||||
let input = ["Looks like one [MASK] is missing", "It was a very nice and [MASK] day"];
|
||||
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
||||
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
|
||||
let tokenized_input = tokenized_input.
|
||||
iter().
|
||||
map(|input| input.token_ids.clone()).
|
||||
map(|mut input| {
|
||||
let input = [
|
||||
"Looks like one [MASK] is missing",
|
||||
"It was a very nice and [MASK] day",
|
||||
];
|
||||
let tokenized_input =
|
||||
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
||||
let max_len = tokenized_input
|
||||
.iter()
|
||||
.map(|input| input.token_ids.len())
|
||||
.max()
|
||||
.unwrap();
|
||||
let tokenized_input = tokenized_input
|
||||
.iter()
|
||||
.map(|input| input.token_ids.clone())
|
||||
.map(|mut input| {
|
||||
input.extend(vec![0; max_len - input.len()]);
|
||||
input
|
||||
}).
|
||||
map(|input|
|
||||
Tensor::of_slice(&(input))).
|
||||
collect::<Vec<_>>();
|
||||
})
|
||||
.map(|input| Tensor::of_slice(&(input)))
|
||||
.collect::<Vec<_>>();
|
||||
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
|
||||
|
||||
// Forward pass
|
||||
let (output,
|
||||
all_hidden_states,
|
||||
all_attentions) = no_grad(|| {
|
||||
electra_model
|
||||
.forward_t(Some(input_tensor),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
false)
|
||||
});
|
||||
let (output, all_hidden_states, all_attentions) =
|
||||
no_grad(|| electra_model.forward_t(Some(input_tensor), None, None, None, None, false));
|
||||
|
||||
// Decode output
|
||||
let index_1 = output.get(0).get(4).argmax(0, false);
|
||||
@ -60,8 +67,14 @@ fn electra_masked_lm() -> failure::Fallible<()> {
|
||||
let word_2 = tokenizer.vocab().id_to_token(&index_2.int64_value(&[]));
|
||||
|
||||
assert_eq!(output.size(), &[2, 10, config.vocab_size]);
|
||||
assert_eq!(config.num_hidden_layers as usize, all_hidden_states.unwrap().len());
|
||||
assert_eq!(config.num_hidden_layers as usize, all_attentions.unwrap().len());
|
||||
assert_eq!(
|
||||
config.num_hidden_layers as usize,
|
||||
all_hidden_states.unwrap().len()
|
||||
);
|
||||
assert_eq!(
|
||||
config.num_hidden_layers as usize,
|
||||
all_attentions.unwrap().len()
|
||||
);
|
||||
assert_eq!("thing", word_1); // Outputs "thing" : "Looks like one [thing] is missing"
|
||||
assert_eq!("sunny", word_2); // Outputs "sunny" : "It was a very nice and [sunny] day"
|
||||
Ok(())
|
||||
@ -70,9 +83,15 @@ fn electra_masked_lm() -> failure::Fallible<()> {
|
||||
#[test]
|
||||
fn electra_discriminator() -> failure::Fallible<()> {
|
||||
// Resources paths
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(ElectraConfigResources::BASE_DISCRIMINATOR));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(ElectraVocabResources::BASE_DISCRIMINATOR));
|
||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(ElectraModelResources::BASE_DISCRIMINATOR));
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
ElectraConfigResources::BASE_DISCRIMINATOR,
|
||||
));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
ElectraVocabResources::BASE_DISCRIMINATOR,
|
||||
));
|
||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
ElectraModelResources::BASE_DISCRIMINATOR,
|
||||
));
|
||||
let config_path = download_resource(&config_resource)?;
|
||||
let vocab_path = download_resource(&vocab_resource)?;
|
||||
let weights_path = download_resource(&weights_resource)?;
|
||||
@ -87,33 +106,32 @@ fn electra_discriminator() -> failure::Fallible<()> {
|
||||
|
||||
// Define input
|
||||
let input = ["One Two Three Ten Five Six Seven Eight"];
|
||||
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
||||
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
|
||||
let encoded_input = tokenized_input.
|
||||
iter().
|
||||
map(|input| input.token_ids.clone()).
|
||||
map(|mut input| {
|
||||
let tokenized_input =
|
||||
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
||||
let max_len = tokenized_input
|
||||
.iter()
|
||||
.map(|input| input.token_ids.len())
|
||||
.max()
|
||||
.unwrap();
|
||||
let encoded_input = tokenized_input
|
||||
.iter()
|
||||
.map(|input| input.token_ids.clone())
|
||||
.map(|mut input| {
|
||||
input.extend(vec![0; max_len - input.len()]);
|
||||
input
|
||||
}).
|
||||
map(|input|
|
||||
Tensor::of_slice(&(input))).
|
||||
collect::<Vec<_>>();
|
||||
})
|
||||
.map(|input| Tensor::of_slice(&(input)))
|
||||
.collect::<Vec<_>>();
|
||||
let input_tensor = Tensor::stack(encoded_input.as_slice(), 0).to(device);
|
||||
|
||||
// Forward pass
|
||||
let (output, _, _) = no_grad(|| {
|
||||
electra_model
|
||||
.forward_t(Some(input_tensor),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
false)
|
||||
});
|
||||
let (output, _, _) =
|
||||
no_grad(|| electra_model.forward_t(Some(input_tensor), None, None, None, None, false));
|
||||
|
||||
// Validate model predictions
|
||||
let expected_probabilities = vec!(0.0101, 0.0030, 0.0010, 0.0018, 0.9489, 0.0067, 0.0026, 0.0017, 0.0311, 0.0101);
|
||||
let expected_probabilities = vec![
|
||||
0.0101, 0.0030, 0.0010, 0.0018, 0.9489, 0.0067, 0.0026, 0.0017, 0.0311, 0.0101,
|
||||
];
|
||||
let probabilities = output.iter::<f64>().unwrap().collect::<Vec<f64>>();
|
||||
|
||||
assert_eq!(output.size(), &[10]);
|
||||
|
209
tests/gpt2.rs
209
tests/gpt2.rs
@ -1,17 +1,26 @@
|
||||
use tch::{Device, nn, Tensor};
|
||||
use rust_tokenizers::{Gpt2Tokenizer, TruncationStrategy, Tokenizer};
|
||||
use rust_bert::gpt2::{
|
||||
GPT2LMHeadModel, Gpt2Config, Gpt2ConfigResources, Gpt2MergesResources, Gpt2ModelResources,
|
||||
Gpt2VocabResources,
|
||||
};
|
||||
use rust_bert::pipelines::generation::{
|
||||
Cache, GPT2Generator, GenerateConfig, LMHeadModel, LanguageGenerator,
|
||||
};
|
||||
use rust_bert::resources::{download_resource, RemoteResource, Resource};
|
||||
use rust_bert::Config;
|
||||
use rust_bert::pipelines::generation::{GPT2Generator, LanguageGenerator, GenerateConfig, LMHeadModel, Cache};
|
||||
use rust_bert::gpt2::{Gpt2Config, GPT2LMHeadModel, Gpt2ConfigResources, Gpt2MergesResources, Gpt2VocabResources, Gpt2ModelResources};
|
||||
use rust_bert::resources::{RemoteResource, Resource, download_resource};
|
||||
use rust_tokenizers::{Gpt2Tokenizer, Tokenizer, TruncationStrategy};
|
||||
use tch::{nn, Device, Tensor};
|
||||
|
||||
#[test]
|
||||
fn gpt2_lm_model() -> failure::Fallible<()> {
|
||||
// Resources paths
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
|
||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
|
||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
|
||||
let config_resource =
|
||||
Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
|
||||
let vocab_resource =
|
||||
Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
|
||||
let merges_resource =
|
||||
Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
|
||||
let weights_resource =
|
||||
Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
|
||||
let config_path = download_resource(&config_resource)?;
|
||||
let vocab_path = download_resource(&vocab_resource)?;
|
||||
let merges_path = download_resource(&merges_resource)?;
|
||||
@ -20,29 +29,38 @@ fn gpt2_lm_model() -> failure::Fallible<()> {
|
||||
// Set-up GPT2 language model (LM head)
|
||||
let device = Device::Cpu;
|
||||
let mut vs = nn::VarStore::new(device);
|
||||
let tokenizer: Gpt2Tokenizer = Gpt2Tokenizer::from_file(vocab_path.to_str().unwrap(), merges_path.to_str().unwrap(), false);
|
||||
let tokenizer: Gpt2Tokenizer = Gpt2Tokenizer::from_file(
|
||||
vocab_path.to_str().unwrap(),
|
||||
merges_path.to_str().unwrap(),
|
||||
false,
|
||||
);
|
||||
let config = Gpt2Config::from_file(config_path);
|
||||
let gpt2_model = GPT2LMHeadModel::new(&vs.root(), &config);
|
||||
vs.load(weights_path)?;
|
||||
|
||||
// Define input
|
||||
let input = ["One two three four"];
|
||||
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
||||
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
|
||||
let tokenized_input = tokenized_input.
|
||||
iter().
|
||||
map(|input| input.token_ids.clone()).
|
||||
map(|mut input| {
|
||||
let tokenized_input =
|
||||
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
||||
let max_len = tokenized_input
|
||||
.iter()
|
||||
.map(|input| input.token_ids.len())
|
||||
.max()
|
||||
.unwrap();
|
||||
let tokenized_input = tokenized_input
|
||||
.iter()
|
||||
.map(|input| input.token_ids.clone())
|
||||
.map(|mut input| {
|
||||
input.extend(vec![0; max_len - input.len()]);
|
||||
input
|
||||
}).
|
||||
map(|input|
|
||||
Tensor::of_slice(&(input))).
|
||||
collect::<Vec<_>>();
|
||||
})
|
||||
.map(|input| Tensor::of_slice(&(input)))
|
||||
.collect::<Vec<_>>();
|
||||
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
|
||||
|
||||
// Forward pass
|
||||
let (output, _, past, _, _) = gpt2_model.forward_t(
|
||||
let (output, _, past, _, _) = gpt2_model
|
||||
.forward_t(
|
||||
&Some(input_tensor),
|
||||
Cache::None,
|
||||
&None,
|
||||
@ -51,35 +69,45 @@ fn gpt2_lm_model() -> failure::Fallible<()> {
|
||||
&None,
|
||||
None,
|
||||
&None,
|
||||
false).unwrap();
|
||||
false,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let next_word_id = output.get(0).get(-1).argmax(-1, true).int64_value(&[0]);
|
||||
let next_word = tokenizer.decode(vec!(next_word_id), true, true);
|
||||
let next_word = tokenizer.decode(vec![next_word_id], true, true);
|
||||
|
||||
assert_eq!(output.size(), vec!(1, 4, 50257));
|
||||
match past {
|
||||
Cache::GPT2Cache(past) => {
|
||||
assert!(past.is_some());
|
||||
assert_eq!(past.as_ref().unwrap().len(), config.n_layer as usize);
|
||||
assert_eq!(past.as_ref().unwrap()[0].size(), vec!(2, 1, config.n_head, 4, 64));
|
||||
assert_eq!(
|
||||
past.as_ref().unwrap()[0].size(),
|
||||
vec!(2, 1, config.n_head, 4, 64)
|
||||
);
|
||||
}
|
||||
_ => panic!("Wrong cache returned for GPT2")
|
||||
_ => panic!("Wrong cache returned for GPT2"),
|
||||
}
|
||||
assert!((output.double_value(&[0, output.size()[1] - 1, next_word_id]) - (-69.4948)).abs() < 1e-4);
|
||||
assert!(
|
||||
(output.double_value(&[0, output.size()[1] - 1, next_word_id]) - (-69.4948)).abs() < 1e-4
|
||||
);
|
||||
assert_eq!(next_word_id, 1936i64);
|
||||
assert_eq!(next_word, String::from(" five"));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn gpt2_generation_greedy() -> failure::Fallible<()> {
|
||||
// Resources definition
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
|
||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
|
||||
let model_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
|
||||
let config_resource =
|
||||
Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
|
||||
let vocab_resource =
|
||||
Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
|
||||
let merges_resource =
|
||||
Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
|
||||
let model_resource =
|
||||
Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
|
||||
|
||||
// Set-up masked LM model
|
||||
let generate_config = GenerateConfig {
|
||||
@ -97,7 +125,7 @@ fn gpt2_generation_greedy() -> failure::Fallible<()> {
|
||||
let model = GPT2Generator::new(generate_config)?;
|
||||
|
||||
let input_context = "The cat";
|
||||
let output = model.generate(Some(vec!(input_context)), None);
|
||||
let output = model.generate(Some(vec![input_context]), None);
|
||||
|
||||
assert_eq!(output.len(), 1);
|
||||
assert_eq!(output[0], "The cat was found in a field near the town of Keflavik, about 30 miles (48 kilometers) south-east of Moscow.\n\n\n");
|
||||
@ -108,10 +136,14 @@ fn gpt2_generation_greedy() -> failure::Fallible<()> {
|
||||
#[test]
|
||||
fn gpt2_generation_beam_search() -> failure::Fallible<()> {
|
||||
// Resources definition
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
|
||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
|
||||
let model_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
|
||||
let config_resource =
|
||||
Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
|
||||
let vocab_resource =
|
||||
Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
|
||||
let merges_resource =
|
||||
Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
|
||||
let model_resource =
|
||||
Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
|
||||
|
||||
// Set-up masked LM model
|
||||
let generate_config = GenerateConfig {
|
||||
@ -129,12 +161,21 @@ fn gpt2_generation_beam_search() -> failure::Fallible<()> {
|
||||
let model = GPT2Generator::new(generate_config)?;
|
||||
|
||||
let input_context = "The dog";
|
||||
let output = model.generate(Some(vec!(input_context)), None);
|
||||
let output = model.generate(Some(vec![input_context]), None);
|
||||
|
||||
assert_eq!(output.len(), 3);
|
||||
assert_eq!(output[0], "The dog was found in the backyard of a home in the 6200 block of South Main Street.");
|
||||
assert_eq!(output[1], "The dog was found in the backyard of a home in the 6500 block of South Main Street.");
|
||||
assert_eq!(output[2], "The dog was found in the backyard of a home in the 6200 block of South Main Street,");
|
||||
assert_eq!(
|
||||
output[0],
|
||||
"The dog was found in the backyard of a home in the 6200 block of South Main Street."
|
||||
);
|
||||
assert_eq!(
|
||||
output[1],
|
||||
"The dog was found in the backyard of a home in the 6500 block of South Main Street."
|
||||
);
|
||||
assert_eq!(
|
||||
output[2],
|
||||
"The dog was found in the backyard of a home in the 6200 block of South Main Street,"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@ -142,10 +183,14 @@ fn gpt2_generation_beam_search() -> failure::Fallible<()> {
|
||||
#[test]
|
||||
fn gpt2_generation_beam_search_multiple_prompts_without_padding() -> failure::Fallible<()> {
|
||||
// Resources definition
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
|
||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
|
||||
let model_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
|
||||
let config_resource =
|
||||
Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
|
||||
let vocab_resource =
|
||||
Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
|
||||
let merges_resource =
|
||||
Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
|
||||
let model_resource =
|
||||
Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
|
||||
|
||||
// Set-up masked LM model
|
||||
let generate_config = GenerateConfig {
|
||||
@ -164,15 +209,33 @@ fn gpt2_generation_beam_search_multiple_prompts_without_padding() -> failure::Fa
|
||||
|
||||
let input_context_1 = "The dog";
|
||||
let input_context_2 = "The cat";
|
||||
let output = model.generate(Some(vec!(input_context_1, input_context_2)), None);
|
||||
let output = model.generate(Some(vec![input_context_1, input_context_2]), None);
|
||||
|
||||
assert_eq!(output.len(), 6);
|
||||
assert_eq!(output[0], "The dog was found in the backyard of a home in the 6200 block of South Main Street.");
|
||||
assert_eq!(output[1], "The dog was found in the backyard of a home in the 6500 block of South Main Street.");
|
||||
assert_eq!(output[2], "The dog was found in the backyard of a home in the 6200 block of South Main Street,");
|
||||
assert_eq!(output[3], "The cat-and-mouse game.\n\n\"I think it\'s going to be interesting to");
|
||||
assert_eq!(output[4], "The cat-and-mouse game.\n\n\"I think it\'s going to be a very");
|
||||
assert_eq!(output[5], "The cat-and-mouse game.\n\n\"I think it\'s going to be very interesting");
|
||||
assert_eq!(
|
||||
output[0],
|
||||
"The dog was found in the backyard of a home in the 6200 block of South Main Street."
|
||||
);
|
||||
assert_eq!(
|
||||
output[1],
|
||||
"The dog was found in the backyard of a home in the 6500 block of South Main Street."
|
||||
);
|
||||
assert_eq!(
|
||||
output[2],
|
||||
"The dog was found in the backyard of a home in the 6200 block of South Main Street,"
|
||||
);
|
||||
assert_eq!(
|
||||
output[3],
|
||||
"The cat-and-mouse game.\n\n\"I think it\'s going to be interesting to"
|
||||
);
|
||||
assert_eq!(
|
||||
output[4],
|
||||
"The cat-and-mouse game.\n\n\"I think it\'s going to be a very"
|
||||
);
|
||||
assert_eq!(
|
||||
output[5],
|
||||
"The cat-and-mouse game.\n\n\"I think it\'s going to be very interesting"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@ -180,10 +243,14 @@ fn gpt2_generation_beam_search_multiple_prompts_without_padding() -> failure::Fa
|
||||
#[test]
|
||||
fn gpt2_generation_beam_search_multiple_prompts_with_padding() -> failure::Fallible<()> {
|
||||
// Resources definition
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
|
||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
|
||||
let model_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
|
||||
let config_resource =
|
||||
Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
|
||||
let vocab_resource =
|
||||
Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
|
||||
let merges_resource =
|
||||
Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
|
||||
let model_resource =
|
||||
Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
|
||||
|
||||
// Set-up masked LM model
|
||||
let generate_config = GenerateConfig {
|
||||
@ -202,15 +269,33 @@ fn gpt2_generation_beam_search_multiple_prompts_with_padding() -> failure::Falli
|
||||
|
||||
let input_context_1 = "The dog";
|
||||
let input_context_2 = "The cat was";
|
||||
let output = model.generate(Some(vec!(input_context_1, input_context_2)), None);
|
||||
let output = model.generate(Some(vec![input_context_1, input_context_2]), None);
|
||||
|
||||
assert_eq!(output.len(), 6);
|
||||
assert_eq!(output[0], "The dog was found dead on the side of the road in the middle of the night.\n");
|
||||
assert_eq!(output[1], "The dog was found dead on the side of the road in the middle of the night on Sunday");
|
||||
assert_eq!(output[2], "The dog was found dead on the side of the road in the middle of the night on Saturday");
|
||||
assert_eq!(output[3], "The cat was taken to a local hospital, where it was treated and released.\n\nPolice said");
|
||||
assert_eq!(output[4], "The cat was taken to a local hospital, where it was treated and released.\n\n\"It");
|
||||
assert_eq!(output[5], "The cat was taken to a local hospital, where it was treated and released.\n\n\"We");
|
||||
assert_eq!(
|
||||
output[0],
|
||||
"The dog was found dead on the side of the road in the middle of the night.\n"
|
||||
);
|
||||
assert_eq!(
|
||||
output[1],
|
||||
"The dog was found dead on the side of the road in the middle of the night on Sunday"
|
||||
);
|
||||
assert_eq!(
|
||||
output[2],
|
||||
"The dog was found dead on the side of the road in the middle of the night on Saturday"
|
||||
);
|
||||
assert_eq!(
|
||||
output[3],
|
||||
"The cat was taken to a local hospital, where it was treated and released.\n\nPolice said"
|
||||
);
|
||||
assert_eq!(
|
||||
output[4],
|
||||
"The cat was taken to a local hospital, where it was treated and released.\n\n\"It"
|
||||
);
|
||||
assert_eq!(
|
||||
output[5],
|
||||
"The cat was taken to a local hospital, where it was treated and released.\n\n\"We"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
@ -1,10 +1,9 @@
|
||||
use rust_bert::pipelines::translation::{TranslationConfig, Language, TranslationModel};
|
||||
use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
|
||||
use tch::Device;
|
||||
|
||||
#[test]
|
||||
#[cfg_attr(not(feature = "all-tests"), ignore)]
|
||||
fn test_translation() -> failure::Fallible<()> {
|
||||
|
||||
// Set-up translation model
|
||||
let translation_config = TranslationConfig::new(Language::EnglishToFrench, Device::Cpu);
|
||||
let model = TranslationModel::new(translation_config)?;
|
||||
@ -15,7 +14,10 @@ fn test_translation() -> failure::Fallible<()> {
|
||||
let output = model.translate(&[input_context_1, input_context_2]);
|
||||
|
||||
assert_eq!(output.len(), 2);
|
||||
assert_eq!(output[0], " Le rapide renard brun saute sur le chien paresseux");
|
||||
assert_eq!(
|
||||
output[0],
|
||||
" Le rapide renard brun saute sur le chien paresseux"
|
||||
);
|
||||
assert_eq!(output[1], " Le chien ne s'est pas réveillé.");
|
||||
|
||||
Ok(())
|
||||
|
@ -1,18 +1,31 @@
|
||||
use tch::{Device, nn, Tensor};
|
||||
use rust_tokenizers::{TruncationStrategy, Tokenizer, OpenAiGptTokenizer};
|
||||
use rust_bert::Config;
|
||||
use rust_bert::pipelines::generation::{OpenAIGenerator, LanguageGenerator, GenerateConfig, LMHeadModel, Cache};
|
||||
use rust_bert::gpt2::Gpt2Config;
|
||||
use rust_bert::openai_gpt::{OpenAIGPTLMHeadModel, OpenAiGptConfigResources, OpenAiGptVocabResources, OpenAiGptMergesResources, OpenAiGptModelResources};
|
||||
use rust_bert::resources::{RemoteResource, Resource, download_resource};
|
||||
use rust_bert::openai_gpt::{
|
||||
OpenAIGPTLMHeadModel, OpenAiGptConfigResources, OpenAiGptMergesResources,
|
||||
OpenAiGptModelResources, OpenAiGptVocabResources,
|
||||
};
|
||||
use rust_bert::pipelines::generation::{
|
||||
Cache, GenerateConfig, LMHeadModel, LanguageGenerator, OpenAIGenerator,
|
||||
};
|
||||
use rust_bert::resources::{download_resource, RemoteResource, Resource};
|
||||
use rust_bert::Config;
|
||||
use rust_tokenizers::{OpenAiGptTokenizer, Tokenizer, TruncationStrategy};
|
||||
use tch::{nn, Device, Tensor};
|
||||
|
||||
#[test]
|
||||
fn openai_gpt_lm_model() -> failure::Fallible<()> {
|
||||
// Resources paths
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptConfigResources::GPT));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptVocabResources::GPT));
|
||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptMergesResources::GPT));
|
||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptModelResources::GPT));
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
OpenAiGptConfigResources::GPT,
|
||||
));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
OpenAiGptVocabResources::GPT,
|
||||
));
|
||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
OpenAiGptMergesResources::GPT,
|
||||
));
|
||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
OpenAiGptModelResources::GPT,
|
||||
));
|
||||
let config_path = download_resource(&config_resource)?;
|
||||
let vocab_path = download_resource(&vocab_resource)?;
|
||||
let merges_path = download_resource(&merges_resource)?;
|
||||
@ -21,29 +34,38 @@ fn openai_gpt_lm_model() -> failure::Fallible<()> {
|
||||
// Set-up masked LM model
|
||||
let device = Device::Cpu;
|
||||
let mut vs = nn::VarStore::new(device);
|
||||
let tokenizer = OpenAiGptTokenizer::from_file(vocab_path.to_str().unwrap(), merges_path.to_str().unwrap(), true);
|
||||
let tokenizer = OpenAiGptTokenizer::from_file(
|
||||
vocab_path.to_str().unwrap(),
|
||||
merges_path.to_str().unwrap(),
|
||||
true,
|
||||
);
|
||||
let config = Gpt2Config::from_file(config_path);
|
||||
let openai_gpt = OpenAIGPTLMHeadModel::new(&vs.root(), &config);
|
||||
vs.load(weights_path)?;
|
||||
|
||||
// Define input
|
||||
let input = ["Wondering what the next word will"];
|
||||
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
||||
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
|
||||
let tokenized_input = tokenized_input.
|
||||
iter().
|
||||
map(|input| input.token_ids.clone()).
|
||||
map(|mut input| {
|
||||
let tokenized_input =
|
||||
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
||||
let max_len = tokenized_input
|
||||
.iter()
|
||||
.map(|input| input.token_ids.len())
|
||||
.max()
|
||||
.unwrap();
|
||||
let tokenized_input = tokenized_input
|
||||
.iter()
|
||||
.map(|input| input.token_ids.clone())
|
||||
.map(|mut input| {
|
||||
input.extend(vec![0; max_len - input.len()]);
|
||||
input
|
||||
}).
|
||||
map(|input|
|
||||
Tensor::of_slice(&(input))).
|
||||
collect::<Vec<_>>();
|
||||
})
|
||||
.map(|input| Tensor::of_slice(&(input)))
|
||||
.collect::<Vec<_>>();
|
||||
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
|
||||
|
||||
// Forward pass
|
||||
let (output, _, _, _, _) = openai_gpt.forward_t(
|
||||
let (output, _, _, _, _) = openai_gpt
|
||||
.forward_t(
|
||||
&Some(input_tensor),
|
||||
Cache::None,
|
||||
&None,
|
||||
@ -52,13 +74,17 @@ fn openai_gpt_lm_model() -> failure::Fallible<()> {
|
||||
&None,
|
||||
None,
|
||||
&None,
|
||||
false).unwrap();
|
||||
false,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let next_word_id = output.get(0).get(-1).argmax(-1, true).int64_value(&[0]);
|
||||
let next_word = tokenizer.decode(vec!(next_word_id), true, true);
|
||||
let next_word = tokenizer.decode(vec![next_word_id], true, true);
|
||||
|
||||
assert_eq!(output.size(), vec!(1, 6, 40478));
|
||||
assert!((output.double_value(&[0, output.size()[1] - 1, next_word_id]) - (9.1056)).abs() < 1e-4);
|
||||
assert!(
|
||||
(output.double_value(&[0, output.size()[1] - 1, next_word_id]) - (9.1056)).abs() < 1e-4
|
||||
);
|
||||
assert_eq!(next_word_id, 580i64);
|
||||
assert_eq!(next_word, String::from("be"));
|
||||
|
||||
@ -68,10 +94,18 @@ fn openai_gpt_lm_model() -> failure::Fallible<()> {
|
||||
#[test]
|
||||
fn openai_gpt_generation_greedy() -> failure::Fallible<()> {
|
||||
// Resources paths
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptConfigResources::GPT));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptVocabResources::GPT));
|
||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptMergesResources::GPT));
|
||||
let model_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptModelResources::GPT));
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
OpenAiGptConfigResources::GPT,
|
||||
));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
OpenAiGptVocabResources::GPT,
|
||||
));
|
||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
OpenAiGptMergesResources::GPT,
|
||||
));
|
||||
let model_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
OpenAiGptModelResources::GPT,
|
||||
));
|
||||
|
||||
// Set-up masked LM model
|
||||
let generate_config = GenerateConfig {
|
||||
@ -90,7 +124,7 @@ fn openai_gpt_generation_greedy() -> failure::Fallible<()> {
|
||||
let model = OpenAIGenerator::new(generate_config)?;
|
||||
|
||||
let input_context = "It was an intense machine dialogue. ";
|
||||
let output = model.generate(Some(vec!(input_context)), None);
|
||||
let output = model.generate(Some(vec![input_context]), None);
|
||||
|
||||
assert_eq!(output.len(), 1);
|
||||
assert_eq!(output[0], "it was an intense machine dialogue. \n \" i\'m sorry, but we have to go now! the police are on their way and they\'re going after you - or at least that\'s what my");
|
||||
@ -101,10 +135,18 @@ fn openai_gpt_generation_greedy() -> failure::Fallible<()> {
|
||||
#[test]
|
||||
fn openai_gpt_generation_beam_search() -> failure::Fallible<()> {
|
||||
// Resources paths
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptConfigResources::GPT));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptVocabResources::GPT));
|
||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptMergesResources::GPT));
|
||||
let model_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptModelResources::GPT));
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
OpenAiGptConfigResources::GPT,
|
||||
));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
OpenAiGptVocabResources::GPT,
|
||||
));
|
||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
OpenAiGptMergesResources::GPT,
|
||||
));
|
||||
let model_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
OpenAiGptModelResources::GPT,
|
||||
));
|
||||
|
||||
// Set-up masked LM model
|
||||
let generate_config = GenerateConfig {
|
||||
@ -122,12 +164,21 @@ fn openai_gpt_generation_beam_search() -> failure::Fallible<()> {
|
||||
let model = OpenAIGenerator::new(generate_config)?;
|
||||
|
||||
let input_context = "The dog is";
|
||||
let output = model.generate(Some(vec!(input_context)), None);
|
||||
let output = model.generate(Some(vec![input_context]), None);
|
||||
|
||||
assert_eq!(output.len(), 3);
|
||||
assert_eq!(output[0], "the dog isn\'t going anywhere. i\'m going to take care of him. i \'ll be right");
|
||||
assert_eq!(output[1], "the dog isn\'t going anywhere. i\'m going to take care of him. i \'ll be back");
|
||||
assert_eq!(output[2], "the dog isn\'t going anywhere. i\'m going to take care of him. \" \n \" i");
|
||||
assert_eq!(
|
||||
output[0],
|
||||
"the dog isn\'t going anywhere. i\'m going to take care of him. i \'ll be right"
|
||||
);
|
||||
assert_eq!(
|
||||
output[1],
|
||||
"the dog isn\'t going anywhere. i\'m going to take care of him. i \'ll be back"
|
||||
);
|
||||
assert_eq!(
|
||||
output[2],
|
||||
"the dog isn\'t going anywhere. i\'m going to take care of him. \" \n \" i"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@ -135,10 +186,18 @@ fn openai_gpt_generation_beam_search() -> failure::Fallible<()> {
|
||||
#[test]
|
||||
fn openai_gpt_generation_beam_search_multiple_prompts_without_padding() -> failure::Fallible<()> {
|
||||
// Resources paths
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptConfigResources::GPT));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptVocabResources::GPT));
|
||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptMergesResources::GPT));
|
||||
let model_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptModelResources::GPT));
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
OpenAiGptConfigResources::GPT,
|
||||
));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
OpenAiGptVocabResources::GPT,
|
||||
));
|
||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
OpenAiGptMergesResources::GPT,
|
||||
));
|
||||
let model_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
OpenAiGptModelResources::GPT,
|
||||
));
|
||||
|
||||
// Set-up masked LM model
|
||||
let generate_config = GenerateConfig {
|
||||
@ -157,18 +216,36 @@ fn openai_gpt_generation_beam_search_multiple_prompts_without_padding() -> failu
|
||||
|
||||
let input_context_1 = "The dog is";
|
||||
let input_context_2 = "The cat";
|
||||
let output = model.generate(Some(vec!(input_context_1, input_context_2)), None);
|
||||
let output = model.generate(Some(vec![input_context_1, input_context_2]), None);
|
||||
|
||||
assert_eq!(output.len(), 6);
|
||||
|
||||
// Unpadded sequence (generation for `The dog is`) is identical to the
|
||||
assert_eq!(output[0], "the dog isn\'t going anywhere. i\'m going to take care of him. i \'ll be right");
|
||||
assert_eq!(output[1], "the dog isn\'t going anywhere. i\'m going to take care of him. i \'ll be back");
|
||||
assert_eq!(output[2], "the dog isn\'t going anywhere. i\'m going to take care of him. \" \n \" i");
|
||||
assert_eq!(
|
||||
output[0],
|
||||
"the dog isn\'t going anywhere. i\'m going to take care of him. i \'ll be right"
|
||||
);
|
||||
assert_eq!(
|
||||
output[1],
|
||||
"the dog isn\'t going anywhere. i\'m going to take care of him. i \'ll be back"
|
||||
);
|
||||
assert_eq!(
|
||||
output[2],
|
||||
"the dog isn\'t going anywhere. i\'m going to take care of him. \" \n \" i"
|
||||
);
|
||||
|
||||
assert_eq!(output[3], "the cat. \" \n \" i don\'t know what you\'re talking about. i don\'t");
|
||||
assert_eq!(output[4], "the cat. \" \n \" i don\'t know what you\'re talking about. i\'m not");
|
||||
assert_eq!(output[5], "the cat. \" \n \" i don\'t know what you\'re talking about. i do know");
|
||||
assert_eq!(
|
||||
output[3],
|
||||
"the cat. \" \n \" i don\'t know what you\'re talking about. i don\'t"
|
||||
);
|
||||
assert_eq!(
|
||||
output[4],
|
||||
"the cat. \" \n \" i don\'t know what you\'re talking about. i\'m not"
|
||||
);
|
||||
assert_eq!(
|
||||
output[5],
|
||||
"the cat. \" \n \" i don\'t know what you\'re talking about. i do know"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@ -176,10 +253,18 @@ fn openai_gpt_generation_beam_search_multiple_prompts_without_padding() -> failu
|
||||
#[test]
|
||||
fn openai_gpt_generation_beam_search_multiple_prompts_with_padding() -> failure::Fallible<()> {
|
||||
// Resources paths
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptConfigResources::GPT));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptVocabResources::GPT));
|
||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptMergesResources::GPT));
|
||||
let model_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptModelResources::GPT));
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
OpenAiGptConfigResources::GPT,
|
||||
));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
OpenAiGptVocabResources::GPT,
|
||||
));
|
||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
OpenAiGptMergesResources::GPT,
|
||||
));
|
||||
let model_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
OpenAiGptModelResources::GPT,
|
||||
));
|
||||
|
||||
// Set-up masked LM model
|
||||
let generate_config = GenerateConfig {
|
||||
@ -198,16 +283,34 @@ fn openai_gpt_generation_beam_search_multiple_prompts_with_padding() -> failure:
|
||||
|
||||
let input_context_1 = "The dog is";
|
||||
let input_context_2 = "The cat was in";
|
||||
let output = model.generate(Some(vec!(input_context_1, input_context_2)), None);
|
||||
let output = model.generate(Some(vec![input_context_1, input_context_2]), None);
|
||||
|
||||
assert_eq!(output.len(), 6);
|
||||
// Left padding impacts the generated sentences output
|
||||
assert_eq!(output[0], "the dog is a dog. \" \n \" i don\'t know what you\'re talking about.");
|
||||
assert_eq!(output[1], "the dog is a dog. \" \n \" i don\'t know what you\'re talking about,");
|
||||
assert_eq!(output[2], "the dog is a dog. \" \n \" i don\'t know what you\'re talking about!");
|
||||
assert_eq!(output[3], "the cat was in the room with them. \n \" what\'s going on? \" i asked.");
|
||||
assert_eq!(output[4], "the cat was in the room with them. \n \" what\'s going on? \" she asked.");
|
||||
assert_eq!(output[5], "the cat was in the room with them. \n \" what\'s going on? why are you all");
|
||||
assert_eq!(
|
||||
output[0],
|
||||
"the dog is a dog. \" \n \" i don\'t know what you\'re talking about."
|
||||
);
|
||||
assert_eq!(
|
||||
output[1],
|
||||
"the dog is a dog. \" \n \" i don\'t know what you\'re talking about,"
|
||||
);
|
||||
assert_eq!(
|
||||
output[2],
|
||||
"the dog is a dog. \" \n \" i don\'t know what you\'re talking about!"
|
||||
);
|
||||
assert_eq!(
|
||||
output[3],
|
||||
"the cat was in the room with them. \n \" what\'s going on? \" i asked."
|
||||
);
|
||||
assert_eq!(
|
||||
output[4],
|
||||
"the cat was in the room with them. \n \" what\'s going on? \" she asked."
|
||||
);
|
||||
assert_eq!(
|
||||
output[5],
|
||||
"the cat was in the room with them. \n \" what\'s going on? why are you all"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
357
tests/roberta.rs
357
tests/roberta.rs
@ -1,18 +1,30 @@
|
||||
use tch::{Device, nn, Tensor, no_grad};
|
||||
use rust_tokenizers::{RobertaTokenizer, TruncationStrategy, Tokenizer, Vocab};
|
||||
use rust_bert::Config;
|
||||
use rust_bert::bert::BertConfig;
|
||||
use rust_bert::roberta::{RobertaForMaskedLM, RobertaForSequenceClassification, RobertaForMultipleChoice, RobertaForTokenClassification, RobertaForQuestionAnswering, RobertaConfigResources, RobertaVocabResources, RobertaMergesResources, RobertaModelResources};
|
||||
use rust_bert::resources::{download_resource, RemoteResource, Resource};
|
||||
use rust_bert::roberta::{
|
||||
RobertaConfigResources, RobertaForMaskedLM, RobertaForMultipleChoice,
|
||||
RobertaForQuestionAnswering, RobertaForSequenceClassification, RobertaForTokenClassification,
|
||||
RobertaMergesResources, RobertaModelResources, RobertaVocabResources,
|
||||
};
|
||||
use rust_bert::Config;
|
||||
use rust_tokenizers::{RobertaTokenizer, Tokenizer, TruncationStrategy, Vocab};
|
||||
use std::collections::HashMap;
|
||||
use rust_bert::resources::{RemoteResource, Resource, download_resource};
|
||||
use tch::{nn, no_grad, Device, Tensor};
|
||||
|
||||
#[test]
|
||||
fn roberta_masked_lm() -> failure::Fallible<()> {
|
||||
// Resources paths
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaConfigResources::ROBERTA));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaVocabResources::ROBERTA));
|
||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaMergesResources::ROBERTA));
|
||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaModelResources::ROBERTA));
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
RobertaConfigResources::ROBERTA,
|
||||
));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
RobertaVocabResources::ROBERTA,
|
||||
));
|
||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
RobertaMergesResources::ROBERTA,
|
||||
));
|
||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
RobertaModelResources::ROBERTA,
|
||||
));
|
||||
let config_path = download_resource(&config_resource)?;
|
||||
let vocab_path = download_resource(&vocab_resource)?;
|
||||
let merges_path = download_resource(&merges_resource)?;
|
||||
@ -21,45 +33,57 @@ fn roberta_masked_lm() -> failure::Fallible<()> {
|
||||
// Set-up masked LM model
|
||||
let device = Device::Cpu;
|
||||
let mut vs = nn::VarStore::new(device);
|
||||
let tokenizer: RobertaTokenizer = RobertaTokenizer::from_file(vocab_path.to_str().unwrap(), merges_path.to_str().unwrap(), true);
|
||||
let tokenizer: RobertaTokenizer = RobertaTokenizer::from_file(
|
||||
vocab_path.to_str().unwrap(),
|
||||
merges_path.to_str().unwrap(),
|
||||
true,
|
||||
);
|
||||
let config = BertConfig::from_file(config_path);
|
||||
let roberta_model = RobertaForMaskedLM::new(&vs.root(), &config);
|
||||
vs.load(weights_path)?;
|
||||
|
||||
// Define input
|
||||
let input = ["<pad> Looks like one thing is missing", "It\'s like comparing oranges to apples"];
|
||||
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
||||
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
|
||||
let mut tokenized_input = tokenized_input.
|
||||
iter().
|
||||
map(|input| input.token_ids.clone()).
|
||||
map(|mut input| {
|
||||
let input = [
|
||||
"<pad> Looks like one thing is missing",
|
||||
"It\'s like comparing oranges to apples",
|
||||
];
|
||||
let tokenized_input =
|
||||
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
||||
let max_len = tokenized_input
|
||||
.iter()
|
||||
.map(|input| input.token_ids.len())
|
||||
.max()
|
||||
.unwrap();
|
||||
let mut tokenized_input = tokenized_input
|
||||
.iter()
|
||||
.map(|input| input.token_ids.clone())
|
||||
.map(|mut input| {
|
||||
input.extend(vec![0; max_len - input.len()]);
|
||||
input
|
||||
}).
|
||||
collect::<Vec<_>>();
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
// Masking the token [thing] of sentence 1 and [oranges] of sentence 2
|
||||
tokenized_input[0][4] = 103;
|
||||
tokenized_input[1][5] = 103;
|
||||
let tokenized_input = tokenized_input.
|
||||
iter().
|
||||
map(|input|
|
||||
Tensor::of_slice(&(input))).
|
||||
collect::<Vec<_>>();
|
||||
let tokenized_input = tokenized_input
|
||||
.iter()
|
||||
.map(|input| Tensor::of_slice(&(input)))
|
||||
.collect::<Vec<_>>();
|
||||
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
|
||||
|
||||
// Forward pass
|
||||
let (output, _, _) = no_grad(|| {
|
||||
roberta_model
|
||||
.forward_t(Some(input_tensor),
|
||||
roberta_model.forward_t(
|
||||
Some(input_tensor),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
&None,
|
||||
&None,
|
||||
false)
|
||||
false,
|
||||
)
|
||||
});
|
||||
|
||||
// Print masked tokens
|
||||
@ -77,9 +101,15 @@ fn roberta_masked_lm() -> failure::Fallible<()> {
|
||||
#[test]
|
||||
fn roberta_for_sequence_classification() -> failure::Fallible<()> {
|
||||
// Resources paths
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaConfigResources::ROBERTA));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaVocabResources::ROBERTA));
|
||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaMergesResources::ROBERTA));
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
RobertaConfigResources::ROBERTA,
|
||||
));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
RobertaVocabResources::ROBERTA,
|
||||
));
|
||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
RobertaMergesResources::ROBERTA,
|
||||
));
|
||||
let config_path = download_resource(&config_resource)?;
|
||||
let vocab_path = download_resource(&vocab_resource)?;
|
||||
let merges_path = download_resource(&merges_resource)?;
|
||||
@ -87,7 +117,11 @@ fn roberta_for_sequence_classification() -> failure::Fallible<()> {
|
||||
// Set-up model
|
||||
let device = Device::Cpu;
|
||||
let vs = nn::VarStore::new(device);
|
||||
let tokenizer: RobertaTokenizer = RobertaTokenizer::from_file(vocab_path.to_str().unwrap(), merges_path.to_str().unwrap(), true);
|
||||
let tokenizer: RobertaTokenizer = RobertaTokenizer::from_file(
|
||||
vocab_path.to_str().unwrap(),
|
||||
merges_path.to_str().unwrap(),
|
||||
true,
|
||||
);
|
||||
let mut config = BertConfig::from_file(config_path);
|
||||
let mut dummy_label_mapping = HashMap::new();
|
||||
dummy_label_mapping.insert(0, String::from("Positive"));
|
||||
@ -98,37 +132,42 @@ fn roberta_for_sequence_classification() -> failure::Fallible<()> {
|
||||
config.output_hidden_states = Some(true);
|
||||
let roberta_model = RobertaForSequenceClassification::new(&vs.root(), &config);
|
||||
|
||||
|
||||
// Define input
|
||||
let input = ["Looks like one thing is missing", "It\'s like comparing oranges to apples"];
|
||||
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
||||
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
|
||||
let tokenized_input = tokenized_input.
|
||||
iter().
|
||||
map(|input| input.token_ids.clone()).
|
||||
map(|mut input| {
|
||||
let input = [
|
||||
"Looks like one thing is missing",
|
||||
"It\'s like comparing oranges to apples",
|
||||
];
|
||||
let tokenized_input =
|
||||
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
||||
let max_len = tokenized_input
|
||||
.iter()
|
||||
.map(|input| input.token_ids.len())
|
||||
.max()
|
||||
.unwrap();
|
||||
let tokenized_input = tokenized_input
|
||||
.iter()
|
||||
.map(|input| input.token_ids.clone())
|
||||
.map(|mut input| {
|
||||
input.extend(vec![0; max_len - input.len()]);
|
||||
input
|
||||
}).
|
||||
map(|input|
|
||||
Tensor::of_slice(&(input))).
|
||||
collect::<Vec<_>>();
|
||||
})
|
||||
.map(|input| Tensor::of_slice(&(input)))
|
||||
.collect::<Vec<_>>();
|
||||
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
|
||||
|
||||
// Forward pass
|
||||
let (output, all_hidden_states, all_attentions) = no_grad(|| {
|
||||
roberta_model
|
||||
.forward_t(Some(input_tensor),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
false)
|
||||
});
|
||||
let (output, all_hidden_states, all_attentions) =
|
||||
no_grad(|| roberta_model.forward_t(Some(input_tensor), None, None, None, None, false));
|
||||
|
||||
assert_eq!(output.size(), &[2, 3]);
|
||||
assert_eq!(config.num_hidden_layers as usize, all_hidden_states.unwrap().len());
|
||||
assert_eq!(config.num_hidden_layers as usize, all_attentions.unwrap().len());
|
||||
assert_eq!(
|
||||
config.num_hidden_layers as usize,
|
||||
all_hidden_states.unwrap().len()
|
||||
);
|
||||
assert_eq!(
|
||||
config.num_hidden_layers as usize,
|
||||
all_attentions.unwrap().len()
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@ -136,9 +175,15 @@ fn roberta_for_sequence_classification() -> failure::Fallible<()> {
|
||||
#[test]
|
||||
fn roberta_for_multiple_choice() -> failure::Fallible<()> {
|
||||
// Resources paths
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaConfigResources::ROBERTA));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaVocabResources::ROBERTA));
|
||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaMergesResources::ROBERTA));
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
RobertaConfigResources::ROBERTA,
|
||||
));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
RobertaVocabResources::ROBERTA,
|
||||
));
|
||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
RobertaMergesResources::ROBERTA,
|
||||
));
|
||||
let config_path = download_resource(&config_resource)?;
|
||||
let vocab_path = download_resource(&vocab_resource)?;
|
||||
let merges_path = download_resource(&merges_resource)?;
|
||||
@ -146,42 +191,54 @@ fn roberta_for_multiple_choice() -> failure::Fallible<()> {
|
||||
// Set-up model
|
||||
let device = Device::Cpu;
|
||||
let vs = nn::VarStore::new(device);
|
||||
let tokenizer: RobertaTokenizer = RobertaTokenizer::from_file(vocab_path.to_str().unwrap(), merges_path.to_str().unwrap(), true);
|
||||
let tokenizer: RobertaTokenizer = RobertaTokenizer::from_file(
|
||||
vocab_path.to_str().unwrap(),
|
||||
merges_path.to_str().unwrap(),
|
||||
true,
|
||||
);
|
||||
let mut config = BertConfig::from_file(config_path);
|
||||
config.output_attentions = Some(true);
|
||||
config.output_hidden_states = Some(true);
|
||||
let roberta_model = RobertaForMultipleChoice::new(&vs.root(), &config);
|
||||
|
||||
|
||||
// Define input
|
||||
let input = ["Looks like one thing is missing", "It\'s like comparing oranges to apples"];
|
||||
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
||||
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
|
||||
let tokenized_input = tokenized_input.
|
||||
iter().
|
||||
map(|input| input.token_ids.clone()).
|
||||
map(|mut input| {
|
||||
let input = [
|
||||
"Looks like one thing is missing",
|
||||
"It\'s like comparing oranges to apples",
|
||||
];
|
||||
let tokenized_input =
|
||||
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
||||
let max_len = tokenized_input
|
||||
.iter()
|
||||
.map(|input| input.token_ids.len())
|
||||
.max()
|
||||
.unwrap();
|
||||
let tokenized_input = tokenized_input
|
||||
.iter()
|
||||
.map(|input| input.token_ids.clone())
|
||||
.map(|mut input| {
|
||||
input.extend(vec![0; max_len - input.len()]);
|
||||
input
|
||||
}).
|
||||
map(|input|
|
||||
Tensor::of_slice(&(input))).
|
||||
collect::<Vec<_>>();
|
||||
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device).unsqueeze(0);
|
||||
})
|
||||
.map(|input| Tensor::of_slice(&(input)))
|
||||
.collect::<Vec<_>>();
|
||||
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0)
|
||||
.to(device)
|
||||
.unsqueeze(0);
|
||||
|
||||
// Forward pass
|
||||
let (output, all_hidden_states, all_attentions) = no_grad(|| {
|
||||
roberta_model
|
||||
.forward_t(input_tensor,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
false)
|
||||
});
|
||||
let (output, all_hidden_states, all_attentions) =
|
||||
no_grad(|| roberta_model.forward_t(input_tensor, None, None, None, false));
|
||||
|
||||
assert_eq!(output.size(), &[1, 2]);
|
||||
assert_eq!(config.num_hidden_layers as usize, all_hidden_states.unwrap().len());
|
||||
assert_eq!(config.num_hidden_layers as usize, all_attentions.unwrap().len());
|
||||
assert_eq!(
|
||||
config.num_hidden_layers as usize,
|
||||
all_hidden_states.unwrap().len()
|
||||
);
|
||||
assert_eq!(
|
||||
config.num_hidden_layers as usize,
|
||||
all_attentions.unwrap().len()
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@ -189,9 +246,15 @@ fn roberta_for_multiple_choice() -> failure::Fallible<()> {
|
||||
#[test]
|
||||
fn roberta_for_token_classification() -> failure::Fallible<()> {
|
||||
// Resources paths
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaConfigResources::ROBERTA));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaVocabResources::ROBERTA));
|
||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaMergesResources::ROBERTA));
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
RobertaConfigResources::ROBERTA,
|
||||
));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
RobertaVocabResources::ROBERTA,
|
||||
));
|
||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
RobertaMergesResources::ROBERTA,
|
||||
));
|
||||
let config_path = download_resource(&config_resource)?;
|
||||
let vocab_path = download_resource(&vocab_resource)?;
|
||||
let merges_path = download_resource(&merges_resource)?;
|
||||
@ -199,7 +262,11 @@ fn roberta_for_token_classification() -> failure::Fallible<()> {
|
||||
// Set-up model
|
||||
let device = Device::Cpu;
|
||||
let vs = nn::VarStore::new(device);
|
||||
let tokenizer: RobertaTokenizer = RobertaTokenizer::from_file(vocab_path.to_str().unwrap(), merges_path.to_str().unwrap(), true);
|
||||
let tokenizer: RobertaTokenizer = RobertaTokenizer::from_file(
|
||||
vocab_path.to_str().unwrap(),
|
||||
merges_path.to_str().unwrap(),
|
||||
true,
|
||||
);
|
||||
let mut config = BertConfig::from_file(config_path);
|
||||
let mut dummy_label_mapping = HashMap::new();
|
||||
dummy_label_mapping.insert(0, String::from("O"));
|
||||
@ -212,46 +279,57 @@ fn roberta_for_token_classification() -> failure::Fallible<()> {
|
||||
let roberta_model = RobertaForTokenClassification::new(&vs.root(), &config);
|
||||
|
||||
// Define input
|
||||
let input = ["Looks like one thing is missing", "It\'s like comparing oranges to apples"];
|
||||
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
||||
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
|
||||
let tokenized_input = tokenized_input.
|
||||
iter().
|
||||
map(|input| input.token_ids.clone()).
|
||||
map(|mut input| {
|
||||
let input = [
|
||||
"Looks like one thing is missing",
|
||||
"It\'s like comparing oranges to apples",
|
||||
];
|
||||
let tokenized_input =
|
||||
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
||||
let max_len = tokenized_input
|
||||
.iter()
|
||||
.map(|input| input.token_ids.len())
|
||||
.max()
|
||||
.unwrap();
|
||||
let tokenized_input = tokenized_input
|
||||
.iter()
|
||||
.map(|input| input.token_ids.clone())
|
||||
.map(|mut input| {
|
||||
input.extend(vec![0; max_len - input.len()]);
|
||||
input
|
||||
}).
|
||||
map(|input|
|
||||
Tensor::of_slice(&(input))).
|
||||
collect::<Vec<_>>();
|
||||
})
|
||||
.map(|input| Tensor::of_slice(&(input)))
|
||||
.collect::<Vec<_>>();
|
||||
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
|
||||
|
||||
// Forward pass
|
||||
let (output, all_hidden_states, all_attentions) = no_grad(|| {
|
||||
roberta_model
|
||||
.forward_t(Some(input_tensor),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
false)
|
||||
});
|
||||
let (output, all_hidden_states, all_attentions) =
|
||||
no_grad(|| roberta_model.forward_t(Some(input_tensor), None, None, None, None, false));
|
||||
|
||||
assert_eq!(output.size(), &[2, 9, 4]);
|
||||
assert_eq!(config.num_hidden_layers as usize, all_hidden_states.unwrap().len());
|
||||
assert_eq!(config.num_hidden_layers as usize, all_attentions.unwrap().len());
|
||||
assert_eq!(
|
||||
config.num_hidden_layers as usize,
|
||||
all_hidden_states.unwrap().len()
|
||||
);
|
||||
assert_eq!(
|
||||
config.num_hidden_layers as usize,
|
||||
all_attentions.unwrap().len()
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn roberta_for_question_answering() -> failure::Fallible<()> {
|
||||
// Resources paths
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaConfigResources::ROBERTA));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaVocabResources::ROBERTA));
|
||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaMergesResources::ROBERTA));
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
RobertaConfigResources::ROBERTA,
|
||||
));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
RobertaVocabResources::ROBERTA,
|
||||
));
|
||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
RobertaMergesResources::ROBERTA,
|
||||
));
|
||||
let config_path = download_resource(&config_resource)?;
|
||||
let vocab_path = download_resource(&vocab_resource)?;
|
||||
let merges_path = download_resource(&merges_resource)?;
|
||||
@ -259,7 +337,11 @@ fn roberta_for_question_answering() -> failure::Fallible<()> {
|
||||
// Set-up model
|
||||
let device = Device::Cpu;
|
||||
let vs = nn::VarStore::new(device);
|
||||
let tokenizer: RobertaTokenizer = RobertaTokenizer::from_file(vocab_path.to_str().unwrap(), merges_path.to_str().unwrap(), true);
|
||||
let tokenizer: RobertaTokenizer = RobertaTokenizer::from_file(
|
||||
vocab_path.to_str().unwrap(),
|
||||
merges_path.to_str().unwrap(),
|
||||
true,
|
||||
);
|
||||
let mut config = BertConfig::from_file(config_path);
|
||||
let mut dummy_label_mapping = HashMap::new();
|
||||
dummy_label_mapping.insert(0, String::from("Positive"));
|
||||
@ -270,38 +352,43 @@ fn roberta_for_question_answering() -> failure::Fallible<()> {
|
||||
config.output_hidden_states = Some(true);
|
||||
let roberta_model = RobertaForQuestionAnswering::new(&vs.root(), &config);
|
||||
|
||||
|
||||
// Define input
|
||||
let input = ["Looks like one thing is missing", "It\'s like comparing oranges to apples"];
|
||||
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
||||
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
|
||||
let tokenized_input = tokenized_input.
|
||||
iter().
|
||||
map(|input| input.token_ids.clone()).
|
||||
map(|mut input| {
|
||||
let input = [
|
||||
"Looks like one thing is missing",
|
||||
"It\'s like comparing oranges to apples",
|
||||
];
|
||||
let tokenized_input =
|
||||
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
||||
let max_len = tokenized_input
|
||||
.iter()
|
||||
.map(|input| input.token_ids.len())
|
||||
.max()
|
||||
.unwrap();
|
||||
let tokenized_input = tokenized_input
|
||||
.iter()
|
||||
.map(|input| input.token_ids.clone())
|
||||
.map(|mut input| {
|
||||
input.extend(vec![0; max_len - input.len()]);
|
||||
input
|
||||
}).
|
||||
map(|input|
|
||||
Tensor::of_slice(&(input))).
|
||||
collect::<Vec<_>>();
|
||||
})
|
||||
.map(|input| Tensor::of_slice(&(input)))
|
||||
.collect::<Vec<_>>();
|
||||
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
|
||||
|
||||
// Forward pass
|
||||
let (start_scores, end_scores, all_hidden_states, all_attentions) = no_grad(|| {
|
||||
roberta_model
|
||||
.forward_t(Some(input_tensor),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
false)
|
||||
});
|
||||
let (start_scores, end_scores, all_hidden_states, all_attentions) =
|
||||
no_grad(|| roberta_model.forward_t(Some(input_tensor), None, None, None, None, false));
|
||||
|
||||
assert_eq!(start_scores.size(), &[2, 9]);
|
||||
assert_eq!(end_scores.size(), &[2, 9]);
|
||||
assert_eq!(config.num_hidden_layers as usize, all_hidden_states.unwrap().len());
|
||||
assert_eq!(config.num_hidden_layers as usize, all_attentions.unwrap().len());
|
||||
assert_eq!(
|
||||
config.num_hidden_layers as usize,
|
||||
all_hidden_states.unwrap().len()
|
||||
);
|
||||
assert_eq!(
|
||||
config.num_hidden_layers as usize,
|
||||
all_attentions.unwrap().len()
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
Loading…
Reference in New Issue
Block a user