Code formatted using rustfmt

Guillaume B 2020-06-23 16:54:46 +02:00
parent 0624a5368c
commit 47e36c4e8c
86 changed files with 9498 additions and 5324 deletions
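For context, rustfmt's default style (which this commit applies) breaks long method chains and wide call expressions one element per line. A minimal before/after illustration, not taken from this commit:

// Before rustfmt, a chain crammed onto one line:
// let squares = (1..=5).map(|n| n * n).filter(|n| n % 2 == 1).collect::<Vec<_>>();
fn main() {
    // After rustfmt: the chain is broken one call per line once it exceeds the width limit.
    let squares = (1..=5)
        .map(|n| n * n)
        .filter(|n| n % 2 == 1)
        .collect::<Vec<_>>();
    println!("{:?}", squares); // [1, 9, 25]
}

The formatting itself is typically produced by running cargo fmt at the crate root.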

View File

@@ -13,58 +13,67 @@
extern crate failure;

use rust_bert::albert::{
    AlbertConfig, AlbertConfigResources, AlbertForMaskedLM, AlbertModelResources,
    AlbertVocabResources,
};
use rust_bert::resources::{download_resource, RemoteResource, Resource};
use rust_bert::Config;
use rust_tokenizers::{AlbertTokenizer, Tokenizer, TruncationStrategy, Vocab};
use tch::{nn, no_grad, Device, Tensor};

fn main() -> failure::Fallible<()> {
    // Resources paths
    let config_resource = Resource::Remote(RemoteResource::from_pretrained(
        AlbertConfigResources::ALBERT_BASE_V2,
    ));
    let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
        AlbertVocabResources::ALBERT_BASE_V2,
    ));
    let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
        AlbertModelResources::ALBERT_BASE_V2,
    ));
    let config_path = download_resource(&config_resource)?;
    let vocab_path = download_resource(&vocab_resource)?;
    let weights_path = download_resource(&weights_resource)?;

    // Set-up masked LM model
    let device = Device::Cpu;
    let mut vs = nn::VarStore::new(device);
    let tokenizer: AlbertTokenizer =
        AlbertTokenizer::from_file(vocab_path.to_str().unwrap(), true, false);
    let config = AlbertConfig::from_file(config_path);
    let albert_model = AlbertForMaskedLM::new(&vs.root(), &config);
    vs.load(weights_path)?;

    // Define input
    let input = [
        "Looks like one [MASK] is missing",
        "It was a very nice and [MASK] day",
    ];
    let tokenized_input =
        tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
    let max_len = tokenized_input
        .iter()
        .map(|input| input.token_ids.len())
        .max()
        .unwrap();
    let tokenized_input = tokenized_input
        .iter()
        .map(|input| input.token_ids.clone())
        .map(|mut input| {
            input.extend(vec![0; max_len - input.len()]);
            input
        })
        .map(|input| Tensor::of_slice(&(input)))
        .collect::<Vec<_>>();
    let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);

    // Forward pass
    let (output, _, _) =
        no_grad(|| albert_model.forward_t(Some(input_tensor), None, None, None, None, false));

    println!("{:?}", output.double_value(&[0, 0, 0]));

    // Print masked tokens
    let index_1 = output.get(0).get(4).argmax(0, false);
    let index_2 = output.get(1).get(7).argmax(0, false);
    let word_1 = tokenizer.vocab().id_to_token(&index_1.int64_value(&[]));
@@ -74,4 +83,4 @@ fn main() -> failure::Fallible<()> {
    println!("{} - {}", &index_2.int64_value(&[]), word_2); // Outputs "_enjoyable" : "It was a very nice and [enjoyable] day"
    Ok(())
}
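Every example in this commit pads tokenized inputs to a common length before stacking them into a batch tensor. Below is a minimal, dependency-free sketch of that padding idiom; the function name and token ids are illustrative, not part of rust-bert:

// Pad each id sequence with zeros up to the length of the longest one,
// mirroring the `input.extend(vec![0; max_len - input.len()])` step above.
fn pad_sequences(mut sequences: Vec<Vec<i64>>) -> Vec<Vec<i64>> {
    let max_len = sequences.iter().map(|s| s.len()).max().unwrap_or(0);
    for seq in &mut sequences {
        seq.extend(vec![0; max_len - seq.len()]);
    }
    sequences
}

fn main() {
    let padded = pad_sequences(vec![vec![101, 2003, 102], vec![101, 102]]);
    assert_eq!(padded, vec![vec![101, 2003, 102], vec![101, 102, 0]]);
}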

View File

@@ -12,33 +12,43 @@
extern crate failure;

use rust_bert::bart::{
    BartConfig, BartConfigResources, BartMergesResources, BartModel, BartModelResources,
    BartVocabResources,
};
use rust_bert::resources::{download_resource, RemoteResource, Resource};
use rust_bert::Config;
use rust_tokenizers::{RobertaTokenizer, Tokenizer, TruncationStrategy};
use tch::{nn, no_grad, Device, Tensor};

fn main() -> failure::Fallible<()> {
    // Resources paths
    let config_resource =
        Resource::Remote(RemoteResource::from_pretrained(BartConfigResources::BART));
    let vocab_resource =
        Resource::Remote(RemoteResource::from_pretrained(BartVocabResources::BART));
    let merges_resource =
        Resource::Remote(RemoteResource::from_pretrained(BartMergesResources::BART));
    let weights_resource =
        Resource::Remote(RemoteResource::from_pretrained(BartModelResources::BART));
    let config_path = download_resource(&config_resource)?;
    let vocab_path = download_resource(&vocab_resource)?;
    let merges_path = download_resource(&merges_resource)?;
    let weights_path = download_resource(&weights_resource)?;

    // Set-up masked LM model
    let device = Device::cuda_if_available();
    let mut vs = nn::VarStore::new(device);
    let tokenizer = RobertaTokenizer::from_file(
        vocab_path.to_str().unwrap(),
        merges_path.to_str().unwrap(),
        false,
    );
    let config = BartConfig::from_file(config_path);
    let bart_model = BartModel::new(&vs.root(), &config, false);
    vs.load(weights_path)?;

    // Define input
    let input = ["In findings published Tuesday in Cornell University's arXiv by a team of scientists \
from the University of Montreal and a separate report published Wednesday in Nature Astronomy by a team \
from University College London (UCL), the presence of water vapour was confirmed in the atmosphere of K2-18b, \
@@ -61,36 +71,32 @@ on K2-18b lasts 33 Earth days. According to The Guardian, astronomers were optim
telescope scheduled for launch in 2021 and the European Space Agency's 2028 ARIEL program, could reveal more \
about exoplanets like K2-18b."];
    // Credits: WikiNews, CC BY 2.5 license (https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b)
    let tokenized_input =
        tokenizer.encode_list(input.to_vec(), 1024, &TruncationStrategy::LongestFirst, 0);
    let max_len = tokenized_input
        .iter()
        .map(|input| input.token_ids.len())
        .max()
        .unwrap();
    let tokenized_input = tokenized_input
        .iter()
        .map(|input| input.token_ids.clone())
        .map(|mut input| {
            input.extend(vec![0; max_len - input.len()]);
            input
        })
        .map(|input| Tensor::of_slice(&(input)))
        .collect::<Vec<_>>();
    let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);

    // Forward pass
    let (decoder_output, encoder_output, _, _, _, _, _) =
        no_grad(|| bart_model.forward_t(Some(&input_tensor), None, None, None, None, None, false));

    // Print masked tokens
    println!("{:?}", encoder_output);
    println!("{:?}", decoder_output);

    Ok(())
}

View File

@@ -12,23 +12,27 @@
extern crate failure;

use rust_bert::bert::{
    BertConfig, BertConfigResources, BertForMaskedLM, BertModelResources, BertVocabResources,
};
use rust_bert::resources::{download_resource, RemoteResource, Resource};
use rust_bert::Config;
use rust_tokenizers::{BertTokenizer, Tokenizer, TruncationStrategy, Vocab};
use tch::{nn, no_grad, Device, Tensor};

fn main() -> failure::Fallible<()> {
    // Resources paths
    let config_resource =
        Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT));
    let vocab_resource =
        Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
    let weights_resource =
        Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT));
    let config_path = download_resource(&config_resource)?;
    let vocab_path = download_resource(&vocab_resource)?;
    let weights_path = download_resource(&weights_resource)?;

    // Set-up masked LM model
    let device = Device::Cpu;
    let mut vs = nn::VarStore::new(device);
    let tokenizer: BertTokenizer = BertTokenizer::from_file(vocab_path.to_str().unwrap(), true);
@@ -36,43 +40,51 @@ fn main() -> failure::Fallible<()> {
    let bert_model = BertForMaskedLM::new(&vs.root(), &config);
    vs.load(weights_path)?;

    // Define input
    let input = [
        "Looks like one [MASK] is missing",
        "It was a very nice and [MASK] day",
    ];
    let tokenized_input =
        tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
    let max_len = tokenized_input
        .iter()
        .map(|input| input.token_ids.len())
        .max()
        .unwrap();
    let tokenized_input = tokenized_input
        .iter()
        .map(|input| input.token_ids.clone())
        .map(|mut input| {
            input.extend(vec![0; max_len - input.len()]);
            input
        })
        .map(|input| Tensor::of_slice(&(input)))
        .collect::<Vec<_>>();
    let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);

    // Forward pass
    let (output, _, _) = no_grad(|| {
        bert_model.forward_t(
            Some(input_tensor),
            None,
            None,
            None,
            None,
            &None,
            &None,
            false,
        )
    });

    // Print masked tokens
    let index_1 = output.get(0).get(4).argmax(0, false);
    let index_2 = output.get(1).get(7).argmax(0, false);
    let word_1 = tokenizer.vocab().id_to_token(&index_1.int64_value(&[]));
    let word_2 = tokenizer.vocab().id_to_token(&index_2.int64_value(&[]));

    println!("{}", word_1); // Outputs "person" : "Looks like one [person] is missing"
    println!("{}", word_2); // Outputs "pear" : "It was a very nice and [pleasant] day"

    Ok(())
}
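The masked-token lookup above boils down to an argmax over the vocabulary dimension at each masked position. A small self-contained sketch of that selection step, with a plain slice standing in for a tensor row (names and values are illustrative):

// Return the index of the highest score, as `argmax(0, false)` does for a
// single vocabulary-sized row of logits.
fn argmax(scores: &[f32]) -> usize {
    scores
        .iter()
        .enumerate()
        .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
        .map(|(index, _)| index)
        .unwrap()
}

fn main() {
    let logits = [0.1_f32, 2.7, -1.0, 0.4];
    assert_eq!(argmax(&logits), 1); // token id 1 scores highest
}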

View File

@@ -11,25 +11,33 @@
// limitations under the License.

extern crate failure;

use rust_bert::distilbert::{
    DistilBertConfig, DistilBertConfigResources, DistilBertModelMaskedLM, DistilBertModelResources,
    DistilBertVocabResources,
};
use rust_bert::resources::{download_resource, RemoteResource, Resource};
use rust_bert::Config;
use rust_tokenizers::bert_tokenizer::BertTokenizer;
use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::{Tokenizer, TruncationStrategy};
use rust_tokenizers::preprocessing::vocab::base_vocab::Vocab;
use tch::{nn, no_grad, Device, Tensor};

fn main() -> failure::Fallible<()> {
    // Resources paths
    let config_resource = Resource::Remote(RemoteResource::from_pretrained(
        DistilBertConfigResources::DISTIL_BERT,
    ));
    let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
        DistilBertVocabResources::DISTIL_BERT,
    ));
    let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
        DistilBertModelResources::DISTIL_BERT,
    ));
    let config_path = download_resource(&config_resource)?;
    let vocab_path = download_resource(&vocab_resource)?;
    let weights_path = download_resource(&weights_resource)?;

    // Set-up masked LM model
    let device = Device::Cpu;
    let mut vs = nn::VarStore::new(device);
    let tokenizer: BertTokenizer = BertTokenizer::from_file(vocab_path.to_str().unwrap(), true);
@@ -37,45 +45,51 @@ fn main() -> failure::Fallible<()> {
    let distil_bert_model = DistilBertModelMaskedLM::new(&vs.root(), &config);
    vs.load(weights_path)?;

    // Define input
    let input = [
        "Looks like one thing is missing",
        "It\'s like comparing oranges to apples",
    ];
    let tokenized_input =
        tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
    let max_len = tokenized_input
        .iter()
        .map(|input| input.token_ids.len())
        .max()
        .unwrap();
    let mut tokenized_input = tokenized_input
        .iter()
        .map(|input| input.token_ids.clone())
        .map(|mut input| {
            input.extend(vec![0; max_len - input.len()]);
            input
        })
        .collect::<Vec<_>>();

    // Masking the token [thing] of sentence 1 and [oranges] of sentence 2
    tokenized_input[0][4] = 103;
    tokenized_input[1][6] = 103;
    let tokenized_input = tokenized_input
        .iter()
        .map(|input| Tensor::of_slice(&(input)))
        .collect::<Vec<_>>();
    let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);

    // Forward pass
    let (output, _, _) = no_grad(|| {
        distil_bert_model
            .forward_t(Some(input_tensor), None, None, false)
            .unwrap()
    });

    // Print masked tokens
    let index_1 = output.get(0).get(4).argmax(0, false);
    let index_2 = output.get(1).get(6).argmax(0, false);
    let word_1 = tokenizer.vocab().id_to_token(&index_1.int64_value(&[]));
    let word_2 = tokenizer.vocab().id_to_token(&index_2.int64_value(&[]));

    println!("{}", word_1); // Outputs "person" : "Looks like one [person] is missing"
    println!("{}", word_2); // Outputs "pear" : "It\'s like comparing [pear] to apples"

    Ok(())
}
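Unlike the other examples, the DistilBERT example masks tokens after tokenization by overwriting ids with 103, the [MASK] id in the standard BERT WordPiece vocabulary. A sketch of that step; the constant comes from the example above, while the surrounding ids and helper are illustrative:

// The [MASK] token id in the standard BERT WordPiece vocabulary.
const MASK_TOKEN_ID: i64 = 103;

// Hide the token at a given position so the model must predict it.
fn mask_position(ids: &mut [i64], position: usize) {
    ids[position] = MASK_TOKEN_ID;
}

fn main() {
    // Illustrative token ids for a short tokenized sentence.
    let mut ids = vec![101, 7993, 2066, 2028, 2518, 102];
    mask_position(&mut ids, 4);
    assert_eq!(ids[4], MASK_TOKEN_ID);
}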

View File

@@ -1,26 +1,44 @@
extern crate failure;

use rust_bert::albert::{AlbertConfigResources, AlbertModelResources, AlbertVocabResources};
use rust_bert::bart::{
    BartConfigResources, BartMergesResources, BartModelResources, BartVocabResources,
};
use rust_bert::bert::{BertConfigResources, BertModelResources, BertVocabResources};
use rust_bert::distilbert::{
    DistilBertConfigResources, DistilBertModelResources, DistilBertVocabResources,
};
use rust_bert::electra::{ElectraConfigResources, ElectraModelResources, ElectraVocabResources};
use rust_bert::gpt2::{
    Gpt2ConfigResources, Gpt2MergesResources, Gpt2ModelResources, Gpt2VocabResources,
};
use rust_bert::openai_gpt::{
    OpenAiGptConfigResources, OpenAiGptMergesResources, OpenAiGptModelResources,
    OpenAiGptVocabResources,
};
use rust_bert::resources::{download_resource, RemoteResource, Resource};
use rust_bert::roberta::{
    RobertaConfigResources, RobertaMergesResources, RobertaModelResources, RobertaVocabResources,
};

/// This example downloads and caches all dependencies used in model tests. This allows for safe
/// multi threaded testing (two test using the same resource would otherwise download the file to
/// the same location).

fn download_distil_gpt2() -> failure::Fallible<()> {
    // Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
    let config_resource = Resource::Remote(RemoteResource::from_pretrained(
        Gpt2ConfigResources::DISTIL_GPT2,
    ));
    let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
        Gpt2VocabResources::DISTIL_GPT2,
    ));
    let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
        Gpt2MergesResources::DISTIL_GPT2,
    ));
    let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
        Gpt2ModelResources::DISTIL_GPT2,
    ));
    let _ = download_resource(&config_resource)?;
    let _ = download_resource(&vocab_resource)?;
    let _ = download_resource(&merges_resource)?;
@@ -29,10 +47,16 @@ fn download_distil_gpt2() -> failure::Fallible<()> {
}

fn download_distilbert_sst2() -> failure::Fallible<()> {
    // Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
    let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
        DistilBertModelResources::DISTIL_BERT_SST2,
    ));
    let config_resource = Resource::Remote(RemoteResource::from_pretrained(
        DistilBertConfigResources::DISTIL_BERT_SST2,
    ));
    let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
        DistilBertVocabResources::DISTIL_BERT_SST2,
    ));
    let _ = download_resource(&config_resource)?;
    let _ = download_resource(&vocab_resource)?;
    let _ = download_resource(&weights_resource)?;
@@ -40,10 +64,16 @@ fn download_distilbert_sst2() -> failure::Fallible<()> {
}

fn download_distilbert_qa() -> failure::Fallible<()> {
    // Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
    let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
        DistilBertModelResources::DISTIL_BERT_SQUAD,
    ));
    let config_resource = Resource::Remote(RemoteResource::from_pretrained(
        DistilBertConfigResources::DISTIL_BERT_SQUAD,
    ));
    let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
        DistilBertVocabResources::DISTIL_BERT_SQUAD,
    ));
    let _ = download_resource(&config_resource)?;
    let _ = download_resource(&vocab_resource)?;
    let _ = download_resource(&weights_resource)?;
@@ -51,10 +81,16 @@ fn download_distilbert_qa() -> failure::Fallible<()> {
}

fn download_distilbert() -> failure::Fallible<()> {
    // Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
    let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
        DistilBertModelResources::DISTIL_BERT,
    ));
    let config_resource = Resource::Remote(RemoteResource::from_pretrained(
        DistilBertConfigResources::DISTIL_BERT,
    ));
    let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
        DistilBertVocabResources::DISTIL_BERT,
    ));
    let _ = download_resource(&config_resource)?;
    let _ = download_resource(&vocab_resource)?;
    let _ = download_resource(&weights_resource)?;
@@ -62,11 +98,15 @@ fn download_distilbert() -> failure::Fallible<()> {
}

fn download_gpt2() -> failure::Fallible<()> {
    // Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2. Modified with conversion to C-array format.
    let config_resource =
        Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
    let vocab_resource =
        Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
    let merges_resource =
        Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
    let weights_resource =
        Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
    let _ = download_resource(&config_resource)?;
    let _ = download_resource(&vocab_resource)?;
    let _ = download_resource(&merges_resource)?;
@@ -75,11 +115,19 @@ fn download_gpt2() -> failure::Fallible<()> {
}

fn download_gpt() -> failure::Fallible<()> {
    // Shared under MIT license by the OpenAI team at https://github.com/openai/finetune-transformer-lm. Modified with conversion to C-array format.
    let config_resource = Resource::Remote(RemoteResource::from_pretrained(
        OpenAiGptConfigResources::GPT,
    ));
    let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
        OpenAiGptVocabResources::GPT,
    ));
    let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
        OpenAiGptMergesResources::GPT,
    ));
    let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
        OpenAiGptModelResources::GPT,
    ));
    let _ = download_resource(&config_resource)?;
    let _ = download_resource(&vocab_resource)?;
    let _ = download_resource(&merges_resource)?;
@@ -88,11 +136,19 @@ fn download_gpt() -> failure::Fallible<()> {
}

fn download_roberta() -> failure::Fallible<()> {
    // Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
    let config_resource = Resource::Remote(RemoteResource::from_pretrained(
        RobertaConfigResources::ROBERTA,
    ));
    let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
        RobertaVocabResources::ROBERTA,
    ));
    let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
        RobertaMergesResources::ROBERTA,
    ));
    let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
        RobertaModelResources::ROBERTA,
    ));
    let _ = download_resource(&config_resource)?;
    let _ = download_resource(&vocab_resource)?;
    let _ = download_resource(&merges_resource)?;
@@ -101,10 +157,13 @@ fn download_roberta() -> failure::Fallible<()> {
}

fn download_bert() -> failure::Fallible<()> {
    // Shared under Apache 2.0 license by the Google team at https://github.com/google-research/bert. Modified with conversion to C-array format.
    let config_resource =
        Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT));
    let vocab_resource =
        Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
    let weights_resource =
        Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT));
    let _ = download_resource(&config_resource)?;
    let _ = download_resource(&vocab_resource)?;
    let _ = download_resource(&weights_resource)?;
@@ -112,10 +171,16 @@ fn download_bert() -> failure::Fallible<()> {
}

fn download_bert_ner() -> failure::Fallible<()> {
    // Shared under MIT license by the MDZ Digital Library team at the Bavarian State Library at https://github.com/dbmdz/berts. Modified with conversion to C-array format.
    let config_resource = Resource::Remote(RemoteResource::from_pretrained(
        BertConfigResources::BERT_NER,
    ));
    let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
        BertVocabResources::BERT_NER,
    ));
    let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
        BertModelResources::BERT_NER,
    ));
    let _ = download_resource(&config_resource)?;
    let _ = download_resource(&vocab_resource)?;
    let _ = download_resource(&weights_resource)?;
@@ -123,11 +188,15 @@ fn download_bert_ner() -> failure::Fallible<()> {
}

fn download_bart() -> failure::Fallible<()> {
    // Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
    let config_resource =
        Resource::Remote(RemoteResource::from_pretrained(BartConfigResources::BART));
    let vocab_resource =
        Resource::Remote(RemoteResource::from_pretrained(BartVocabResources::BART));
    let merges_resource =
        Resource::Remote(RemoteResource::from_pretrained(BartMergesResources::BART));
    let weights_resource =
        Resource::Remote(RemoteResource::from_pretrained(BartModelResources::BART));
    let _ = download_resource(&config_resource)?;
    let _ = download_resource(&vocab_resource)?;
    let _ = download_resource(&merges_resource)?;
@@ -136,11 +205,19 @@ fn download_bart() -> failure::Fallible<()> {
}

fn download_bart_cnn() -> failure::Fallible<()> {
    // Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
    let config_resource = Resource::Remote(RemoteResource::from_pretrained(
        BartConfigResources::BART_CNN,
    ));
    let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
        BartVocabResources::BART_CNN,
    ));
    let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
        BartMergesResources::BART_CNN,
    ));
    let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
        BartModelResources::BART_CNN,
    ));
    let _ = download_resource(&config_resource)?;
    let _ = download_resource(&vocab_resource)?;
    let _ = download_resource(&merges_resource)?;
@@ -149,10 +226,16 @@ fn download_bart_cnn() -> failure::Fallible<()> {
}

fn download_electra_generator() -> failure::Fallible<()> {
    // Shared under Apache 2.0 license by the Google team at https://github.com/google-research/electra. Modified with conversion to C-array format.
    let config_resource = Resource::Remote(RemoteResource::from_pretrained(
        ElectraConfigResources::BASE_GENERATOR,
    ));
    let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
        ElectraVocabResources::BASE_GENERATOR,
    ));
    let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
        ElectraModelResources::BASE_GENERATOR,
    ));
    let _ = download_resource(&config_resource)?;
    let _ = download_resource(&vocab_resource)?;
    let _ = download_resource(&weights_resource)?;
@@ -160,10 +243,16 @@ fn download_electra_generator() -> failure::Fallible<()> {
}

fn download_electra_discriminator() -> failure::Fallible<()> {
    // Shared under Apache 2.0 license by the Google team at https://github.com/google-research/electra. Modified with conversion to C-array format.
    let config_resource = Resource::Remote(RemoteResource::from_pretrained(
        ElectraConfigResources::BASE_DISCRIMINATOR,
    ));
    let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
        ElectraVocabResources::BASE_DISCRIMINATOR,
    ));
    let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
        ElectraModelResources::BASE_DISCRIMINATOR,
    ));
    let _ = download_resource(&config_resource)?;
    let _ = download_resource(&vocab_resource)?;
    let _ = download_resource(&weights_resource)?;
@@ -171,10 +260,16 @@ fn download_electra_discriminator() -> failure::Fallible<()> {
}

fn download_albert_base_v2() -> failure::Fallible<()> {
    // Shared under Apache 2.0 license by the Google team at https://github.com/google-research/ALBERT. Modified with conversion to C-array format.
    let config_resource = Resource::Remote(RemoteResource::from_pretrained(
        AlbertConfigResources::ALBERT_BASE_V2,
    ));
    let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
        AlbertVocabResources::ALBERT_BASE_V2,
    ));
    let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
        AlbertModelResources::ALBERT_BASE_V2,
    ));
    let _ = download_resource(&config_resource)?;
    let _ = download_resource(&vocab_resource)?;
    let _ = download_resource(&weights_resource)?;
@@ -198,4 +293,4 @@ fn main() -> failure::Fallible<()> {
    let _ = download_albert_base_v2();
    Ok(())
}
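The doc comment above explains why resources are fetched ahead of time: concurrent tests sharing a resource would otherwise race on the same download location. A dependency-free sketch of that prefetch idea; the function, keys, and cache type are illustrative, not rust-bert API:

use std::collections::HashSet;

// Resolve every resource once, sequentially, before any parallel work starts,
// so no two workers ever fetch the same key concurrently.
fn prefetch(keys: &[&str], cache: &mut HashSet<String>) {
    for key in keys {
        // A real implementation would download to a cache directory here;
        // this sketch only records which keys have been resolved.
        cache.insert((*key).to_string());
    }
}

fn main() {
    let mut cache = HashSet::new();
    prefetch(&["bert/config", "bert/vocab", "bert/model"], &mut cache);
    assert_eq!(cache.len(), 3);
}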

View File

@@ -12,23 +12,31 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use rust_bert::electra::{
    ElectraConfig, ElectraConfigResources, ElectraDiscriminator, ElectraModelResources,
    ElectraVocabResources,
};
use rust_bert::resources::{download_resource, RemoteResource, Resource};
use rust_bert::Config;
use rust_tokenizers::{BertTokenizer, Tokenizer, TruncationStrategy};
use tch::{nn, no_grad, Device, Tensor};

fn main() -> failure::Fallible<()> {
    // Resources paths
    let config_resource = Resource::Remote(RemoteResource::from_pretrained(
        ElectraConfigResources::BASE_DISCRIMINATOR,
    ));
    let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
        ElectraVocabResources::BASE_DISCRIMINATOR,
    ));
    let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
        ElectraModelResources::BASE_DISCRIMINATOR,
    ));
    let config_path = download_resource(&config_resource)?;
    let vocab_path = download_resource(&vocab_resource)?;
    let weights_path = download_resource(&weights_resource)?;

    // Set-up masked LM model
    let device = Device::Cpu;
    let mut vs = nn::VarStore::new(device);
    let tokenizer: BertTokenizer = BertTokenizer::from_file(vocab_path.to_str().unwrap(), true);
@@ -36,45 +44,45 @@ fn main() -> failure::Fallible<()> {
    let electra_model = ElectraDiscriminator::new(&vs.root(), &config);
    vs.load(weights_path)?;

    // Define input
    let input = ["One Two Three Ten Five Six Seven Eight"];
    let tokenized_input =
        tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
    let max_len = tokenized_input
        .iter()
        .map(|input| input.token_ids.len())
        .max()
        .unwrap();
    let encoded_input = tokenized_input
        .iter()
        .map(|input| input.token_ids.clone())
        .map(|mut input| {
            input.extend(vec![0; max_len - input.len()]);
            input
        })
        .map(|input| Tensor::of_slice(&(input)))
        .collect::<Vec<_>>();
    let input_tensor = Tensor::stack(encoded_input.as_slice(), 0).to(device);

    // Forward pass
    let (output, _, _) =
        no_grad(|| electra_model.forward_t(Some(input_tensor), None, None, None, None, false));

    // Print model predictions
    for (position, token) in tokenized_input[0].token_ids.iter().enumerate() {
        let probability = output.double_value(&[position as i64]);
        let generated = if probability > 0.5 {
            "generated"
        } else {
            "original"
        };
        println!(
            "{:?}: {} ({:.1}%)",
            tokenizer.decode([*token].to_vec(), false, false),
            generated,
            100f64 * probability
        )
    }

    Ok(())
}
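The discriminator emits one probability per token position, and the example labels a token "generated" once that probability crosses 0.5. A tiny sketch of that decision and the percentage formatting; the threshold mirrors the example above, while the values are illustrative:

// Above 0.5 the discriminator considers the token replaced by a generator.
fn label(probability: f64) -> &'static str {
    if probability > 0.5 {
        "generated"
    } else {
        "original"
    }
}

fn main() {
    for p in [0.02_f64, 0.97] {
        println!("{} ({:.1}%)", label(p), 100.0 * p);
    }
}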

View File

@@ -12,23 +12,31 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use rust_bert::electra::{
    ElectraConfig, ElectraConfigResources, ElectraForMaskedLM, ElectraModelResources,
    ElectraVocabResources,
};
use rust_bert::resources::{download_resource, RemoteResource, Resource};
use rust_bert::Config;
use rust_tokenizers::{BertTokenizer, Tokenizer, TruncationStrategy, Vocab};
use tch::{nn, no_grad, Device, Tensor};

fn main() -> failure::Fallible<()> {
    // Resources paths
    let config_resource = Resource::Remote(RemoteResource::from_pretrained(
        ElectraConfigResources::BASE_GENERATOR,
    ));
    let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
        ElectraVocabResources::BASE_GENERATOR,
    ));
    let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
        ElectraModelResources::BASE_GENERATOR,
    ));
    let config_path = download_resource(&config_resource)?;
    let vocab_path = download_resource(&vocab_resource)?;
    let weights_path = download_resource(&weights_resource)?;

    // Set-up masked LM model
    let device = Device::Cpu;
    let mut vs = nn::VarStore::new(device);
    let tokenizer: BertTokenizer = BertTokenizer::from_file(vocab_path.to_str().unwrap(), true);
@@ -36,41 +44,41 @@ fn main() -> failure::Fallible<()> {
    let electra_model = ElectraForMaskedLM::new(&vs.root(), &config);
    vs.load(weights_path)?;

    // Define input
    let input = [
        "Looks like one [MASK] is missing",
        "It was a very nice and [MASK] day",
    ];
    let tokenized_input =
        tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
    let max_len = tokenized_input
        .iter()
        .map(|input| input.token_ids.len())
        .max()
        .unwrap();
    let tokenized_input = tokenized_input
        .iter()
        .map(|input| input.token_ids.clone())
        .map(|mut input| {
            input.extend(vec![0; max_len - input.len()]);
            input
        })
        .map(|input| Tensor::of_slice(&(input)))
        .collect::<Vec<_>>();
    let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);

    // Forward pass
    let (output, _, _) =
        no_grad(|| electra_model.forward_t(Some(input_tensor), None, None, None, None, false));

    // Print masked tokens
    let index_1 = output.get(0).get(4).argmax(0, false);
    let index_2 = output.get(1).get(7).argmax(0, false);
    let word_1 = tokenizer.vocab().id_to_token(&index_1.int64_value(&[]));
    let word_2 = tokenizer.vocab().id_to_token(&index_2.int64_value(&[]));

    println!("{}", word_1); // Outputs "thing" : "Looks like one [thing] is missing"
    println!("{}", word_2); // Outputs "sunny" : "It was a very nice and [sunny] day"

    Ok(())
}

View File

@@ -12,12 +12,10 @@
extern crate failure;

use rust_bert::pipelines::generation::{GPT2Generator, GenerateConfig, LanguageGenerator};

fn main() -> failure::Fallible<()> {
    // Set-up masked LM model
    let generate_config = GenerateConfig {
        max_length: 30,
        do_sample: true,
@@ -30,10 +28,10 @@ fn main() -> failure::Fallible<()> {
    let input_context = "The dog";
    let second_input_context = "The cat was";
    let output = model.generate(Some(vec![input_context, second_input_context]), None);

    for sentence in output {
        println!("{:?}", sentence);
    }
    Ok(())
}
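GenerateConfig's do_sample: true switches generation from greedy argmax to sampling from the model's next-token distribution. As a rough illustration of the distribution involved, here is a temperature-scaled softmax in plain Rust; the function and values are illustrative, not rust-bert API:

// Convert raw logits into a probability distribution. Temperatures below 1.0
// sharpen the distribution; above 1.0 they flatten it (more diverse samples).
fn softmax(logits: &[f64], temperature: f64) -> Vec<f64> {
    // Subtracting the max before exponentiating keeps the computation stable.
    let max = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = logits
        .iter()
        .map(|l| ((l - max) / temperature).exp())
        .collect();
    let sum: f64 = exps.iter().sum();
    exps.iter().map(|e| e / sum).collect()
}

fn main() {
    let probs = softmax(&[2.0, 1.0, 0.1], 1.0);
    assert!((probs.iter().sum::<f64>() - 1.0).abs() < 1e-9);
    println!("{:?}", probs);
}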

View File

@ -12,65 +12,82 @@
extern crate failure; extern crate failure;
use tch::{Device, nn, Tensor}; use rust_bert::gpt2::{
use rust_tokenizers::{TruncationStrategy, Tokenizer, Gpt2Tokenizer}; GPT2LMHeadModel, Gpt2Config, Gpt2ConfigResources, Gpt2MergesResources, Gpt2ModelResources,
use rust_bert::gpt2::{Gpt2Config, GPT2LMHeadModel, Gpt2ConfigResources, Gpt2VocabResources, Gpt2MergesResources, Gpt2ModelResources}; Gpt2VocabResources,
use rust_bert::pipelines::generation::{LMHeadModel, Cache}; };
use rust_bert::resources::{Resource, download_resource, RemoteResource}; use rust_bert::pipelines::generation::{Cache, LMHeadModel};
use rust_bert::resources::{download_resource, RemoteResource, Resource};
use rust_bert::Config; use rust_bert::Config;
use rust_tokenizers::{Gpt2Tokenizer, Tokenizer, TruncationStrategy};
use tch::{nn, Device, Tensor};
fn main() -> failure::Fallible<()> { fn main() -> failure::Fallible<()> {
// Resources set-up // Resources set-up
let config_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2)); let config_resource =
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2)); Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2)); let vocab_resource =
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2)); Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
let merges_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
let weights_resource =
Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
let config_path = download_resource(&config_resource)?; let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?; let vocab_path = download_resource(&vocab_resource)?;
let merges_path = download_resource(&merges_resource)?; let merges_path = download_resource(&merges_resource)?;
let weights_path = download_resource(&weights_resource)?; let weights_path = download_resource(&weights_resource)?;
// Set-up masked LM model // Set-up masked LM model
let device = Device::Cpu; let device = Device::Cpu;
let mut vs = nn::VarStore::new(device); let mut vs = nn::VarStore::new(device);
let tokenizer: Gpt2Tokenizer = Gpt2Tokenizer::from_file(vocab_path.to_str().unwrap(), merges_path.to_str().unwrap(), false); let tokenizer: Gpt2Tokenizer = Gpt2Tokenizer::from_file(
vocab_path.to_str().unwrap(),
merges_path.to_str().unwrap(),
false,
);
let config = Gpt2Config::from_file(config_path); let config = Gpt2Config::from_file(config_path);
let gpt2_model = GPT2LMHeadModel::new(&vs.root(), &config); let gpt2_model = GPT2LMHeadModel::new(&vs.root(), &config);
vs.load(weights_path)?; vs.load(weights_path)?;
// Define input // Define input
let input = ["One two three four five six seven eight nine ten eleven"]; let input = ["One two three four five six seven eight nine ten eleven"];
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0); let tokenized_input =
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap(); tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let tokenized_input = tokenized_input. let max_len = tokenized_input
iter(). .iter()
map(|input| input.token_ids.clone()). .map(|input| input.token_ids.len())
map(|mut input| { .max()
.unwrap();
let tokenized_input = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]); input.extend(vec![0; max_len - input.len()]);
input input
}). })
map(|input| .map(|input| Tensor::of_slice(&(input)))
Tensor::of_slice(&(input))). .collect::<Vec<_>>();
collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device); let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass // Forward pass
let (output, _, _, _, _) = gpt2_model.forward_t( let (output, _, _, _, _) = gpt2_model
&Some(input_tensor), .forward_t(
Cache::None, &Some(input_tensor),
&None, Cache::None,
&None, &None,
&None, &None,
&None, &None,
None, &None,
&None, None,
false).unwrap(); &None,
false,
)
.unwrap();
let next_word_id = output.get(0).get(-1).argmax(-1, true).int64_value(&[0]); let next_word_id = output.get(0).get(-1).argmax(-1, true).int64_value(&[0]);
let next_word = tokenizer.decode(vec!(next_word_id), true, true); let next_word = tokenizer.decode(vec![next_word_id], true, true);
println!("Provided input: {}", input[0]); println!("Provided input: {}", input[0]);
println!("Next word: {}", next_word); println!("Next word: {}", next_word);
Ok(()) Ok(())
} }
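The forward pass above predicts a single next token. A sketch of repeating it as greedy decoding, reusing the example's bindings; Cache::None is kept for simplicity, although a real loop would reuse the past key/value cache for speed:

    let mut generated = input_tensor.copy();
    for _ in 0..10 {
        let (output, _, _, _, _) = gpt2_model
            .forward_t(&Some(generated.copy()), Cache::None, &None, &None, &None, &None, None, &None, false)
            .unwrap();
        // most likely token after the last position, appended to the sequence
        let next_id = output.get(0).get(-1).argmax(-1, true);
        generated = Tensor::cat(&[generated, next_id.unsqueeze(0)], 1);
    }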
View File
@ -15,20 +15,20 @@ extern crate failure;
use rust_bert::pipelines::ner::NERModel; use rust_bert::pipelines::ner::NERModel;
fn main() -> failure::Fallible<()> { fn main() -> failure::Fallible<()> {
// Set-up model // Set-up model
let ner_model = NERModel::new(Default::default())?; let ner_model = NERModel::new(Default::default())?;
// Define input // Define input
let input = [ let input = [
"My name is Amélie. I live in Москва.", "My name is Amélie. I live in Москва.",
"Chongqing is a city in China." "Chongqing is a city in China.",
]; ];
// Run model // Run model
let output = ner_model.predict(&input); let output = ner_model.predict(&input);
for entity in output { for entity in output {
println!("{:?}", entity); println!("{:?}", entity);
} }
Ok(()) Ok(())
} }
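Each item printed above is an Entity. A sketch of consuming its fields directly, assuming the struct exposes word, label and score members:

    for entity in output {
        // e.g. "Amélie [I-PER] (0.99)"
        println!("{} [{}] ({:.2})", entity.word, entity.label, entity.score);
    }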
View File
@ -12,66 +12,87 @@
extern crate failure; extern crate failure;
use tch::{Device, nn, Tensor};
use rust_tokenizers::{TruncationStrategy, Tokenizer, OpenAiGptTokenizer};
use rust_bert::gpt2::Gpt2Config; use rust_bert::gpt2::Gpt2Config;
use rust_bert::openai_gpt::{OpenAIGPTLMHeadModel, OpenAiGptConfigResources, OpenAiGptVocabResources, OpenAiGptMergesResources, OpenAiGptModelResources}; use rust_bert::openai_gpt::{
use rust_bert::pipelines::generation::{LMHeadModel, Cache}; OpenAIGPTLMHeadModel, OpenAiGptConfigResources, OpenAiGptMergesResources,
use rust_bert::resources::{Resource, download_resource, RemoteResource}; OpenAiGptModelResources, OpenAiGptVocabResources,
};
use rust_bert::pipelines::generation::{Cache, LMHeadModel};
use rust_bert::resources::{download_resource, RemoteResource, Resource};
use rust_bert::Config; use rust_bert::Config;
use rust_tokenizers::{OpenAiGptTokenizer, Tokenizer, TruncationStrategy};
use tch::{nn, Device, Tensor};
fn main() -> failure::Fallible<()> { fn main() -> failure::Fallible<()> {
// Resources paths // Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptConfigResources::GPT)); let config_resource = Resource::Remote(RemoteResource::from_pretrained(
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptVocabResources::GPT)); OpenAiGptConfigResources::GPT,
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptMergesResources::GPT)); ));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptModelResources::GPT)); let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
OpenAiGptVocabResources::GPT,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
OpenAiGptMergesResources::GPT,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
OpenAiGptModelResources::GPT,
));
let config_path = download_resource(&config_resource)?; let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?; let vocab_path = download_resource(&vocab_resource)?;
let merges_path = download_resource(&merges_resource)?; let merges_path = download_resource(&merges_resource)?;
let weights_path = download_resource(&weights_resource)?; let weights_path = download_resource(&weights_resource)?;
// Set-up masked LM model // Set-up masked LM model
let device = Device::Cpu; let device = Device::Cpu;
let mut vs = nn::VarStore::new(device); let mut vs = nn::VarStore::new(device);
let tokenizer = OpenAiGptTokenizer::from_file(vocab_path.to_str().unwrap(), merges_path.to_str().unwrap(), true); let tokenizer = OpenAiGptTokenizer::from_file(
vocab_path.to_str().unwrap(),
merges_path.to_str().unwrap(),
true,
);
let config = Gpt2Config::from_file(config_path); let config = Gpt2Config::from_file(config_path);
let openai_gpt = OpenAIGPTLMHeadModel::new(&vs.root(), &config); let openai_gpt = OpenAIGPTLMHeadModel::new(&vs.root(), &config);
vs.load(weights_path)?; vs.load(weights_path)?;
// Define input // Define input
let input = ["Wondering what the next word will"]; let input = ["Wondering what the next word will"];
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0); let tokenized_input =
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap(); tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let tokenized_input = tokenized_input. let max_len = tokenized_input
iter(). .iter()
map(|input| input.token_ids.clone()). .map(|input| input.token_ids.len())
map(|mut input| { .max()
.unwrap();
let tokenized_input = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]); input.extend(vec![0; max_len - input.len()]);
input input
}). })
map(|input| .map(|input| Tensor::of_slice(&(input)))
Tensor::of_slice(&(input))). .collect::<Vec<_>>();
collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device); let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass // Forward pass
let (output, _, _, _, _) = openai_gpt.forward_t( let (output, _, _, _, _) = openai_gpt
&Some(input_tensor), .forward_t(
Cache::None, &Some(input_tensor),
&None, Cache::None,
&None, &None,
&None, &None,
&None, &None,
None, &None,
&None, None,
false).unwrap(); &None,
false,
)
.unwrap();
let next_word_id = output.get(0).get(-1).argmax(-1, true).int64_value(&[0]); let next_word_id = output.get(0).get(-1).argmax(-1, true).int64_value(&[0]);
let next_word = tokenizer.decode(vec!(next_word_id), true, true); let next_word = tokenizer.decode(vec![next_word_id], true, true);
println!("Provided input: {}", input[0]); println!("Provided input: {}", input[0]);
println!("Next word: {}", next_word); println!("Next word: {}", next_word);
Ok(()) Ok(())
} }
View File
@ -12,23 +12,28 @@
extern crate failure; extern crate failure;
use rust_bert::pipelines::question_answering::{QuestionAnsweringModel, QaInput}; use rust_bert::pipelines::question_answering::{QaInput, QuestionAnsweringModel};
fn main() -> failure::Fallible<()> { fn main() -> failure::Fallible<()> {
// Set-up Question Answering model // Set-up Question Answering model
let qa_model = QuestionAnsweringModel::new(Default::default())?; let qa_model = QuestionAnsweringModel::new(Default::default())?;
// Define input // Define input
let question_1 = String::from("Where does Amy live ?"); let question_1 = String::from("Where does Amy live ?");
let context_1 = String::from("Amy lives in Amsterdam"); let context_1 = String::from("Amy lives in Amsterdam");
let question_2 = String::from("Where does Eric live"); let question_2 = String::from("Where does Eric live");
let context_2 = String::from("While Amy lives in Amsterdam, Eric is in The Hague."); let context_2 = String::from("While Amy lives in Amsterdam, Eric is in The Hague.");
let qa_input_1 = QaInput { question: question_1, context: context_1 }; let qa_input_1 = QaInput {
let qa_input_2 = QaInput { question: question_2, context: context_2 }; question: question_1,
context: context_1,
};
let qa_input_2 = QaInput {
question: question_2,
context: context_2,
};
// Get answer // Get answer
let answers = qa_model.predict(&vec!(qa_input_1, qa_input_2), 1, 32); let answers = qa_model.predict(&vec![qa_input_1, qa_input_2], 1, 32);
println!("{:?}", answers); println!("{:?}", answers);
Ok(()) Ok(())
} }
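predict takes the inputs, the number of answers to keep per question and a batch size. A sketch of unpacking the result, assuming it yields one list of Answer per question and that Answer exposes answer, score, start and end fields (the offsets being positions within the context):

    for (i, question_answers) in answers.iter().enumerate() {
        for answer in question_answers {
            println!(
                "Q{}: {} (score {:.3}, span {}..{})",
                i + 1, answer.answer, answer.score, answer.start, answer.end
            );
        }
    }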
View File
@ -12,77 +12,99 @@
extern crate failure; extern crate failure;
use tch::{Device, nn, Tensor, no_grad};
use rust_tokenizers::{TruncationStrategy, Tokenizer, Vocab, RobertaTokenizer};
use rust_bert::Config;
use rust_bert::bert::BertConfig; use rust_bert::bert::BertConfig;
use rust_bert::roberta::{RobertaForMaskedLM, RobertaVocabResources, RobertaConfigResources, RobertaMergesResources, RobertaModelResources}; use rust_bert::resources::{download_resource, RemoteResource, Resource};
use rust_bert::resources::{Resource, download_resource, RemoteResource}; use rust_bert::roberta::{
RobertaConfigResources, RobertaForMaskedLM, RobertaMergesResources, RobertaModelResources,
RobertaVocabResources,
};
use rust_bert::Config;
use rust_tokenizers::{RobertaTokenizer, Tokenizer, TruncationStrategy, Vocab};
use tch::{nn, no_grad, Device, Tensor};
fn main() -> failure::Fallible<()> { fn main() -> failure::Fallible<()> {
// Resources paths // Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaConfigResources::ROBERTA)); let config_resource = Resource::Remote(RemoteResource::from_pretrained(
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaVocabResources::ROBERTA)); RobertaConfigResources::ROBERTA,
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaMergesResources::ROBERTA)); ));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaModelResources::ROBERTA)); let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaVocabResources::ROBERTA,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaMergesResources::ROBERTA,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaModelResources::ROBERTA,
));
let config_path = download_resource(&config_resource)?; let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?; let vocab_path = download_resource(&vocab_resource)?;
let merges_path = download_resource(&merges_resource)?; let merges_path = download_resource(&merges_resource)?;
let weights_path = download_resource(&weights_resource)?; let weights_path = download_resource(&weights_resource)?;
// Set-up masked LM model // Set-up masked LM model
let device = Device::Cpu; let device = Device::Cpu;
let mut vs = nn::VarStore::new(device); let mut vs = nn::VarStore::new(device);
let tokenizer: RobertaTokenizer = RobertaTokenizer::from_file(vocab_path.to_str().unwrap(), merges_path.to_str().unwrap(), true); let tokenizer: RobertaTokenizer = RobertaTokenizer::from_file(
vocab_path.to_str().unwrap(),
merges_path.to_str().unwrap(),
true,
);
let config = BertConfig::from_file(config_path); let config = BertConfig::from_file(config_path);
let bert_model = RobertaForMaskedLM::new(&vs.root(), &config); let bert_model = RobertaForMaskedLM::new(&vs.root(), &config);
vs.load(weights_path)?; vs.load(weights_path)?;
// Define input // Define input
let input = ["<pad> Looks like one thing is missing", "It\'s like comparing oranges to apples"]; let input = [
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0); "<pad> Looks like one thing is missing",
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap(); "It\'s like comparing oranges to apples",
let mut tokenized_input = tokenized_input. ];
iter(). let tokenized_input =
map(|input| input.token_ids.clone()). tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
map(|mut input| { let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let mut tokenized_input = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]); input.extend(vec![0; max_len - input.len()]);
input input
}). })
collect::<Vec<_>>(); .collect::<Vec<_>>();
// Masking the token [thing] of sentence 1 and [oranges] of sentence 2 // Masking the token [thing] of sentence 1 and [oranges] of sentence 2
tokenized_input[0][4] = 103; tokenized_input[0][4] = 103;
tokenized_input[1][5] = 103; tokenized_input[1][5] = 103;
let tokenized_input = tokenized_input. let tokenized_input = tokenized_input
iter(). .iter()
map(|input| .map(|input| Tensor::of_slice(&(input)))
Tensor::of_slice(&(input))). .collect::<Vec<_>>();
collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device); let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass // Forward pass
let (output, _, _) = no_grad(|| { let (output, _, _) = no_grad(|| {
bert_model bert_model.forward_t(
.forward_t(Some(input_tensor), Some(input_tensor),
None, None,
None, None,
None, None,
None, None,
&None, &None,
&None, &None,
false) false,
)
}); });
// Print masked tokens // Print masked tokens
let index_1 = output.get(0).get(4).argmax(0, false); let index_1 = output.get(0).get(4).argmax(0, false);
let index_2 = output.get(1).get(5).argmax(0, false); let index_2 = output.get(1).get(5).argmax(0, false);
let word_1 = tokenizer.vocab().id_to_token(&index_1.int64_value(&[])); let word_1 = tokenizer.vocab().id_to_token(&index_1.int64_value(&[]));
let word_2 = tokenizer.vocab().id_to_token(&index_2.int64_value(&[])); let word_2 = tokenizer.vocab().id_to_token(&index_2.int64_value(&[]));
println!("{}", word_1); // Outputs "some" : "Looks like [some] thing is missing" println!("{}", word_1); // Outputs "some" : "Looks like [some] thing is missing"
println!("{}", word_2);// Outputs "apple" : "It\'s like comparing [apple] to apples" println!("{}", word_2); // Outputs "apple" : "It\'s like comparing [apple] to apples"
Ok(()) Ok(())
} }
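The hardcoded 103 above is the [MASK] id of the BERT uncased vocabulary; RoBERTa vocabularies generally map <mask> to a different id. A sketch of deriving the id instead of hardcoding it, again assuming a token_to_id lookup on the Vocab trait, substituted for the two assignments above:

    let mask_id = tokenizer.vocab().token_to_id("<mask>");
    tokenized_input[0][4] = mask_id;
    tokenized_input[1][5] = mask_id;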
View File
@ -14,23 +14,22 @@ extern crate failure;
use rust_bert::pipelines::sentiment::SentimentModel; use rust_bert::pipelines::sentiment::SentimentModel;
fn main() -> failure::Fallible<()> { fn main() -> failure::Fallible<()> {
// Set-up classifier // Set-up classifier
let sentiment_classifier = SentimentModel::new(Default::default())?; let sentiment_classifier = SentimentModel::new(Default::default())?;
// Define input // Define input
let input = [ let input = [
"Probably my all-time favorite movie, a story of selflessness, sacrifice and dedication to a noble cause, but it's not preachy or boring.", "Probably my all-time favorite movie, a story of selflessness, sacrifice and dedication to a noble cause, but it's not preachy or boring.",
"This film tried to be too many things all at once: stinging political satire, Hollywood blockbuster, sappy romantic comedy, family values promo...", "This film tried to be too many things all at once: stinging political satire, Hollywood blockbuster, sappy romantic comedy, family values promo...",
"If you like original gut wrenching laughter you will like this movie. If you are young or old then you will love this movie, hell even my mom liked it.", "If you like original gut wrenching laughter you will like this movie. If you are young or old then you will love this movie, hell even my mom liked it.",
]; ];
// Run model // Run model
let output = sentiment_classifier.predict(&input); let output = sentiment_classifier.predict(&input);
for sentiment in output { for sentiment in output {
println!("{:?}", sentiment); println!("{:?}", sentiment);
} }
Ok(()) Ok(())
} }
View File
@ -15,21 +15,21 @@ extern crate failure;
use rust_bert::pipelines::sequence_classification::SequenceClassificationModel; use rust_bert::pipelines::sequence_classification::SequenceClassificationModel;
fn main() -> failure::Fallible<()> { fn main() -> failure::Fallible<()> {
// Set-up model // Set-up model
let sequence_classification_model = SequenceClassificationModel::new(Default::default())?; let sequence_classification_model = SequenceClassificationModel::new(Default::default())?;
// Define input // Define input
let input = [ let input = [
"Probably my all-time favorite movie, a story of selflessness, sacrifice and dedication to a noble cause, but it's not preachy or boring.", "Probably my all-time favorite movie, a story of selflessness, sacrifice and dedication to a noble cause, but it's not preachy or boring.",
"This film tried to be too many things all at once: stinging political satire, Hollywood blockbuster, sappy romantic comedy, family values promo...", "This film tried to be too many things all at once: stinging political satire, Hollywood blockbuster, sappy romantic comedy, family values promo...",
"If you like original gut wrenching laughter you will like this movie. If you are young or old then you will love this movie, hell even my mom liked it.", "If you like original gut wrenching laughter you will like this movie. If you are young or old then you will love this movie, hell even my mom liked it.",
]; ];
// Run model // Run model
let output = sequence_classification_model.predict(&input); let output = sequence_classification_model.predict(&input);
for label in output { for label in output {
println!("{:?}", label); println!("{:?}", label);
} }
Ok(()) Ok(())
} }
View File
@ -15,21 +15,21 @@ extern crate failure;
use rust_bert::pipelines::sequence_classification::SequenceClassificationModel; use rust_bert::pipelines::sequence_classification::SequenceClassificationModel;
fn main() -> failure::Fallible<()> { fn main() -> failure::Fallible<()> {
// Set-up model // Set-up model
let sequence_classification_model = SequenceClassificationModel::new(Default::default())?; let sequence_classification_model = SequenceClassificationModel::new(Default::default())?;
// Define input // Define input
let input = [ let input = [
"Probably my all-time favorite movie, a story of selflessness, sacrifice and dedication to a noble cause, but it's not preachy or boring.", "Probably my all-time favorite movie, a story of selflessness, sacrifice and dedication to a noble cause, but it's not preachy or boring.",
"This is a neutral sentence.", "This is a neutral sentence.",
"If you like original gut wrenching laughter you will like this movie. If you are young or old then you will love this movie, hell even my mom liked it.", "If you like original gut wrenching laughter you will like this movie. If you are young or old then you will love this movie, hell even my mom liked it.",
]; ];
// Run model // Run model
let output = sequence_classification_model.predict_multilabel(&input, 0.05); let output = sequence_classification_model.predict_multilabel(&input, 0.05);
for label in output { for label in output {
println!("{:?}", label); println!("{:?}", label);
} }
Ok(()) Ok(())
} }
View File
@ -12,24 +12,23 @@
extern crate failure; extern crate failure;
use std::path::PathBuf; use rust_bert::pipelines::question_answering::{squad_processor, QuestionAnsweringModel};
use std::env; use std::env;
use rust_bert::pipelines::question_answering::{QuestionAnsweringModel, squad_processor}; use std::path::PathBuf;
fn main() -> failure::Fallible<()> { fn main() -> failure::Fallible<()> {
// Set-up Question Answering model // Set-up Question Answering model
let qa_model = QuestionAnsweringModel::new(Default::default())?; let qa_model = QuestionAnsweringModel::new(Default::default())?;
// Define input // Define input
let mut squad_path = PathBuf::from(env::var("squad_dataset") let mut squad_path = PathBuf::from(env::var("squad_dataset")
.expect("Please set the \"squad_dataset\" environment variable pointing to the SQuAD dataset folder")); .expect("Please set the \"squad_dataset\" environment variable pointing to the SQuAD dataset folder"));
squad_path.push("dev-v2.0.json"); squad_path.push("dev-v2.0.json");
let qa_inputs = squad_processor(squad_path); let qa_inputs = squad_processor(squad_path);
// Get answer // Get answer
let answers = qa_model.predict(&qa_inputs, 1, 64); let answers = qa_model.predict(&qa_inputs, 1, 64);
println!("Sample answer: {:?}", answers.first().unwrap()); println!("Sample answer: {:?}", answers.first().unwrap());
println!("{}", answers.len()); println!("{}", answers.len());
Ok(()) Ok(())
} }
View File
@ -10,35 +10,42 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
extern crate failure;
extern crate dirs; extern crate dirs;
extern crate failure;
use std::path::PathBuf; use rust_bert::pipelines::sentiment::{ss2_processor, SentimentModel};
use rust_bert::pipelines::sentiment::{SentimentModel, ss2_processor};
use std::env; use std::env;
use std::path::PathBuf;
fn main() -> failure::Fallible<()> { fn main() -> failure::Fallible<()> {
// Set-up classifier // Set-up classifier
let sentiment_classifier = SentimentModel::new(Default::default())?; let sentiment_classifier = SentimentModel::new(Default::default())?;
// Define input // Define input
let mut sst2_path = PathBuf::from(env::var("SST2_PATH") let mut sst2_path = PathBuf::from(env::var("SST2_PATH")
.expect("Please set the \"squad_dataset\" environment variable pointing to the SQuAD dataset folder")); .expect("Please set the \"squad_dataset\" environment variable pointing to the SQuAD dataset folder"));
sst2_path.push("train.tsv"); sst2_path.push("train.tsv");
let inputs = ss2_processor(sst2_path).unwrap(); let inputs = ss2_processor(sst2_path).unwrap();
// Run model // Run model
let batch_size = 64; let batch_size = 64;
let mut output = vec!(); let mut output = vec![];
for batch in inputs.chunks(batch_size) { for batch in inputs.chunks(batch_size) {
output.push(sentiment_classifier.predict(batch.iter().map(|v| v.as_str()).collect::<Vec<&str>>().as_slice())); output.push(
sentiment_classifier.predict(
batch
.iter()
.map(|v| v.as_str())
.collect::<Vec<&str>>()
.as_slice(),
),
);
} }
let mut flat_outputs = vec!(); let mut flat_outputs = vec![];
for batch_output in output.iter_mut() { for batch_output in output.iter_mut() {
flat_outputs.append(batch_output); flat_outputs.append(batch_output);
} }
println!("{:?}", flat_outputs.len()); println!("{:?}", flat_outputs.len());
Ok(()) Ok(())
} }
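The two accumulation loops above can be collapsed into a single iterator chain; a behavior-preserving sketch built from the same predict call:

    let flat_outputs: Vec<_> = inputs
        .chunks(batch_size)
        .flat_map(|batch| {
            sentiment_classifier
                .predict(batch.iter().map(|v| v.as_str()).collect::<Vec<&str>>().as_slice())
        })
        .collect();
    println!("{:?}", flat_outputs.len());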
View File
@ -14,7 +14,6 @@ extern crate failure;
use rust_bert::pipelines::summarization::SummarizationModel; use rust_bert::pipelines::summarization::SummarizationModel;
fn main() -> failure::Fallible<()> { fn main() -> failure::Fallible<()> {
let summarization_model = SummarizationModel::new(Default::default())?; let summarization_model = SummarizationModel::new(Default::default())?;
@ -40,11 +39,11 @@ on K2-18b lasts 33 Earth days. According to The Guardian, astronomers were optim
telescope scheduled for launch in 2021 and the European Space Agency's 2028 ARIEL program, could reveal more \ telescope scheduled for launch in 2021 and the European Space Agency's 2028 ARIEL program, could reveal more \
about exoplanets like K2-18b."]; about exoplanets like K2-18b."];
// Credits: WikiNews, CC BY 2.5 license (https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b) // Credits: WikiNews, CC BY 2.5 license (https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b)
let _output = summarization_model.summarize(&input); let _output = summarization_model.summarize(&input);
for sentence in _output { for sentence in _output {
println!("{:?}", sentence); println!("{:?}", sentence);
}; }
Ok(()) Ok(())
} }
View File
@ -10,28 +10,36 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use rust_bert::pipelines::token_classification::{TokenClassificationModel, TokenClassificationConfig, LabelAggregationOption}; use rust_bert::bert::{BertConfigResources, BertModelResources, BertVocabResources};
use rust_bert::resources::{Resource, RemoteResource};
use rust_bert::bert::{BertModelResources, BertVocabResources, BertConfigResources};
use rust_bert::pipelines::common::ModelType; use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::token_classification::{
LabelAggregationOption, TokenClassificationConfig, TokenClassificationModel,
};
use rust_bert::resources::{RemoteResource, Resource};
fn main() -> failure::Fallible<()> { fn main() -> failure::Fallible<()> {
// Load a configuration
// Load a configuration let config = TokenClassificationConfig::new(
let config = TokenClassificationConfig::new(ModelType::Bert, ModelType::Bert,
Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT_NER)), Resource::Remote(RemoteResource::from_pretrained(
Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT_NER)), BertModelResources::BERT_NER,
Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT_NER)), )),
None, //merges resource only relevant with ModelType::Roberta Resource::Remote(RemoteResource::from_pretrained(
false, //lowercase BertConfigResources::BERT_NER,
LabelAggregationOption::Mode, )),
Resource::Remote(RemoteResource::from_pretrained(
BertVocabResources::BERT_NER,
)),
None, //merges resource only relevant with ModelType::Roberta
false, //lowercase
LabelAggregationOption::Mode,
); );
// Create the model // Create the model
let token_classification_model = TokenClassificationModel::new(config)?; let token_classification_model = TokenClassificationModel::new(config)?;
let input = [ let input = [
"My name is Amélie. I live in Москва.", "My name is Amélie. I live in Москва.",
"Chongqing is a city in China." "Chongqing is a city in China.",
]; ];
let token_outputs = token_classification_model.predict(&input, true, false); //ignore_first_label = true (only returns the NER parts, ignoring first label O) let token_outputs = token_classification_model.predict(&input, true, false); //ignore_first_label = true (only returns the NER parts, ignoring first label O)
@ -40,4 +48,4 @@ fn main() -> failure::Fallible<()> {
} }
Ok(()) Ok(())
} }
View File
@ -13,12 +13,12 @@
extern crate failure; extern crate failure;
use rust_bert::pipelines::translation::{TranslationConfig, TranslationModel, Language}; use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
use tch::Device; use tch::Device;
fn main() -> failure::Fallible<()> { fn main() -> failure::Fallible<()> {
let translation_config =
let translation_config = TranslationConfig::new(Language::EnglishToGerman, Device::cuda_if_available()); TranslationConfig::new(Language::EnglishToGerman, Device::cuda_if_available());
let model = TranslationModel::new(translation_config)?; let model = TranslationModel::new(translation_config)?;
let input_context_1 = "The quick brown fox jumps over the lazy dog"; let input_context_1 = "The quick brown fox jumps over the lazy dog";
@ -30,4 +30,4 @@ fn main() -> failure::Fallible<()> {
println!("{}", sentence); println!("{}", sentence);
} }
Ok(()) Ok(())
} }
rustfmt.toml Normal file
View File
@ -0,0 +1 @@
format_code_in_doc_comments = true
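This one-line file is what drives the doc-comment rewraps throughout the commit: format_code_in_doc_comments makes rustfmt also format Rust snippets inside /// blocks. It is an unstable option, so it only takes effect on a nightly toolchain; annotated for clarity (only the bare option line above is part of the commit):

    # format Rust snippets embedded in /// doc comments (unstable rustfmt option)
    format_code_in_doc_comments = true
    # unstable options require nightly rustfmt, e.g. invoked as: cargo +nightly fmt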
View File
@ -11,16 +11,15 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use std::collections::HashMap;
use crate::Config;
use serde::{Deserialize, Serialize};
use crate::albert::embeddings::AlbertEmbeddings; use crate::albert::embeddings::AlbertEmbeddings;
use crate::albert::encoder::AlbertTransformer; use crate::albert::encoder::AlbertTransformer;
use tch::{nn, Tensor, Kind}; use crate::common::activations::{_gelu, _gelu_new, _mish, _relu, _tanh};
use crate::common::activations::{_tanh, _gelu_new, _gelu, _relu, _mish};
use tch::nn::Module;
use crate::common::dropout::Dropout; use crate::common::dropout::Dropout;
use crate::Config;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use tch::nn::Module;
use tch::{nn, Kind, Tensor};
/// # ALBERT Pretrained model weight files /// # ALBERT Pretrained model weight files
pub struct AlbertModelResources; pub struct AlbertModelResources;
@ -33,20 +32,28 @@ pub struct AlbertVocabResources;
impl AlbertModelResources { impl AlbertModelResources {
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/ALBERT. Modified with conversion to C-array format. /// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/ALBERT. Modified with conversion to C-array format.
pub const ALBERT_BASE_V2: (&'static str, &'static str) = ("albert-base-v2/model.ot", "https://cdn.huggingface.co/albert-base-v2/rust_model.ot"); pub const ALBERT_BASE_V2: (&'static str, &'static str) = (
"albert-base-v2/model.ot",
"https://cdn.huggingface.co/albert-base-v2/rust_model.ot",
);
} }
impl AlbertConfigResources { impl AlbertConfigResources {
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/ALBERT. Modified with conversion to C-array format. /// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/ALBERT. Modified with conversion to C-array format.
pub const ALBERT_BASE_V2: (&'static str, &'static str) = ("albert-base-v2/config.json", "https://cdn.huggingface.co/albert-base-v2-config.json"); pub const ALBERT_BASE_V2: (&'static str, &'static str) = (
"albert-base-v2/config.json",
"https://cdn.huggingface.co/albert-base-v2-config.json",
);
} }
impl AlbertVocabResources { impl AlbertVocabResources {
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/ALBERT. Modified with conversion to C-array format. /// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/ALBERT. Modified with conversion to C-array format.
pub const ALBERT_BASE_V2: (&'static str, &'static str) = ("albert-base-v2/spiece.model", "https://cdn.huggingface.co/albert-base-v2-spiece.model"); pub const ALBERT_BASE_V2: (&'static str, &'static str) = (
"albert-base-v2/spiece.model",
"https://cdn.huggingface.co/albert-base-v2-spiece.model",
);
} }
#[allow(non_camel_case_types)] #[allow(non_camel_case_types)]
#[derive(Clone, Debug, Serialize, Deserialize)] #[derive(Clone, Debug, Serialize, Deserialize)]
/// # Activation function used in the attention layer and masked language model head /// # Activation function used in the attention layer and masked language model head
@ -61,7 +68,6 @@ pub enum Activation {
mish, mish,
} }
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
/// # ALBERT model configuration /// # ALBERT model configuration
/// Defines the ALBERT model architecture (e.g. number of layers, hidden layer size, label mapping...) /// Defines the ALBERT model architecture (e.g. number of layers, hidden layer size, label mapping...)
@ -123,10 +129,10 @@ impl AlbertModel {
/// # Example /// # Example
/// ///
/// ```no_run /// ```no_run
/// use tch::{nn, Device}; /// use rust_bert::albert::{AlbertConfig, AlbertModel};
/// use rust_bert::Config; /// use rust_bert::Config;
/// use std::path::Path; /// use std::path::Path;
/// use rust_bert::albert::{AlbertConfig, AlbertModel}; /// use tch::{nn, Device};
/// ///
/// let config_path = Path::new("path/to/config.json"); /// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu; /// let device = Device::Cpu;
@ -134,14 +140,23 @@ impl AlbertModel {
/// let config = AlbertConfig::from_file(config_path); /// let config = AlbertConfig::from_file(config_path);
/// let albert: AlbertModel = AlbertModel::new(&(&p.root() / "albert"), &config); /// let albert: AlbertModel = AlbertModel::new(&(&p.root() / "albert"), &config);
/// ``` /// ```
///
pub fn new(p: &nn::Path, config: &AlbertConfig) -> AlbertModel { pub fn new(p: &nn::Path, config: &AlbertConfig) -> AlbertModel {
let embeddings = AlbertEmbeddings::new(&(p / "embeddings"), config); let embeddings = AlbertEmbeddings::new(&(p / "embeddings"), config);
let encoder = AlbertTransformer::new(&(p / "encoder"), config); let encoder = AlbertTransformer::new(&(p / "encoder"), config);
let pooler = nn::linear(&(p / "pooler"), config.hidden_size, config.hidden_size, Default::default()); let pooler = nn::linear(
&(p / "pooler"),
config.hidden_size,
config.hidden_size,
Default::default(),
);
let pooler_activation = Box::new(_tanh); let pooler_activation = Box::new(_tanh);
AlbertModel { embeddings, encoder, pooler, pooler_activation } AlbertModel {
embeddings,
encoder,
pooler,
pooler_activation,
}
} }
/// Forward pass through the model /// Forward pass through the model
@ -165,75 +180,103 @@ impl AlbertModel {
/// # Example /// # Example
/// ///
/// ```no_run /// ```no_run
///# use tch::{nn, Device, Tensor, no_grad}; /// # use tch::{nn, Device, Tensor, no_grad};
///# use rust_bert::Config; /// # use rust_bert::Config;
///# use std::path::Path; /// # use std::path::Path;
///# use tch::kind::Kind::Int64; /// # use tch::kind::Kind::Int64;
/// use rust_bert::albert::{AlbertConfig, AlbertModel}; /// use rust_bert::albert::{AlbertConfig, AlbertModel};
///# let config_path = Path::new("path/to/config.json"); /// # let config_path = Path::new("path/to/config.json");
///# let device = Device::Cpu; /// # let device = Device::Cpu;
///# let vs = nn::VarStore::new(device); /// # let vs = nn::VarStore::new(device);
///# let config = AlbertConfig::from_file(config_path); /// # let config = AlbertConfig::from_file(config_path);
///# let albert_model: AlbertModel = AlbertModel::new(&vs.root(), &config); /// # let albert_model: AlbertModel = AlbertModel::new(&vs.root(), &config);
/// let (batch_size, sequence_length) = (64, 128); /// let (batch_size, sequence_length) = (64, 128);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device)); /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device)); /// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device)); /// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true); /// let position_ids = Tensor::arange(sequence_length, (Int64, device))
/// /// .expand(&[batch_size, sequence_length], true);
/// let (output, pooled_output, all_hidden_states, all_attentions) = no_grad(|| {
/// albert_model
/// .forward_t(Some(input_tensor),
/// Some(mask),
/// Some(token_type_ids),
/// Some(position_ids),
/// None,
/// false).unwrap()
/// });
/// ///
/// let (output, pooled_output, all_hidden_states, all_attentions) = no_grad(|| {
/// albert_model
/// .forward_t(
/// Some(input_tensor),
/// Some(mask),
/// Some(token_type_ids),
/// Some(position_ids),
/// None,
/// false,
/// )
/// .unwrap()
/// });
/// ``` /// ```
/// pub fn forward_t(
pub fn forward_t(&self, &self,
input_ids: Option<Tensor>, input_ids: Option<Tensor>,
mask: Option<Tensor>, mask: Option<Tensor>,
token_type_ids: Option<Tensor>, token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>, position_ids: Option<Tensor>,
input_embeds: Option<Tensor>, input_embeds: Option<Tensor>,
train: bool) train: bool,
-> Result<(Tensor, Tensor, Option<Vec<Tensor>>, Option<Vec<Vec<Tensor>>>), &'static str> { ) -> Result<
(
Tensor,
Tensor,
Option<Vec<Tensor>>,
Option<Vec<Vec<Tensor>>>,
),
&'static str,
> {
let (input_shape, device) = match &input_ids { let (input_shape, device) = match &input_ids {
Some(input_value) => match &input_embeds { Some(input_value) => match &input_embeds {
Some(_) => { return Err("Only one of input ids or input embeddings may be set"); } Some(_) => {
None => (input_value.size(), input_value.device()) return Err("Only one of input ids or input embeddings may be set");
} }
None => (input_value.size(), input_value.device()),
},
None => match &input_embeds { None => match &input_embeds {
Some(embeds) => (vec!(embeds.size()[0], embeds.size()[1]), embeds.device()), Some(embeds) => (vec![embeds.size()[0], embeds.size()[1]], embeds.device()),
None => { return Err("At least one of input ids or input embeddings must be set"); } None => {
} return Err("At least one of input ids or input embeddings must be set");
}
},
}; };
let mask = match mask { let mask = match mask {
Some(value) => value, Some(value) => value,
None => Tensor::ones(&input_shape, (Kind::Int64, device)) None => Tensor::ones(&input_shape, (Kind::Int64, device)),
}; };
let extended_attention_mask = mask.unsqueeze(1).unsqueeze(2); let extended_attention_mask = mask.unsqueeze(1).unsqueeze(2);
let extended_attention_mask: Tensor = (extended_attention_mask.ones_like() - extended_attention_mask) * -10000.0; let extended_attention_mask: Tensor =
(extended_attention_mask.ones_like() - extended_attention_mask) * -10000.0;
let embedding_output = match self.embeddings.forward_t(input_ids, token_type_ids, position_ids, input_embeds, train) { let embedding_output = match self.embeddings.forward_t(
input_ids,
token_type_ids,
position_ids,
input_embeds,
train,
) {
Ok(value) => value, Ok(value) => value,
Err(e) => { return Err(e); } Err(e) => {
return Err(e);
}
}; };
let (hidden_state, all_hidden_states, all_attentions) = let (hidden_state, all_hidden_states, all_attentions) =
self.encoder.forward_t(&embedding_output, self.encoder
Some(extended_attention_mask), .forward_t(&embedding_output, Some(extended_attention_mask), train);
train);
let pooled_output = self.pooler.forward(&hidden_state.select(1, 0)); let pooled_output = self.pooler.forward(&hidden_state.select(1, 0));
let pooled_output = (self.pooler_activation)(&pooled_output); let pooled_output = (self.pooler_activation)(&pooled_output);
Ok((hidden_state, pooled_output, all_hidden_states, all_attentions)) Ok((
hidden_state,
pooled_output,
all_hidden_states,
all_attentions,
))
} }
} }
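The mask handling in forward_t turns a 0/1 padding mask into an additive bias: kept positions contribute 0, padded positions -10000, so the following softmax gives them essentially no attention weight. A standalone numeric sketch of the same trick, using the Tensor and Kind imports already in this file:

    let mask = Tensor::of_slice(&[1i64, 1, 0]).unsqueeze(0);         // [1, 3], last position is padding
    let bias = (mask.ones_like() - &mask) * -10000.0;                // [0, 0, -10000]
    let scores = Tensor::of_slice(&[2.0f32, 1.0, 3.0]).unsqueeze(0); // raw attention scores
    let attn = (scores + bias.to_kind(Kind::Float)).softmax(-1, Kind::Float);
    // attn ≈ [0.73, 0.27, 0.00]: the padded position is effectively masked out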
@ -248,21 +291,43 @@ impl AlbertMLMHead {
pub fn new(p: &nn::Path, config: &AlbertConfig) -> AlbertMLMHead { pub fn new(p: &nn::Path, config: &AlbertConfig) -> AlbertMLMHead {
let layer_norm_eps = match config.layer_norm_eps { let layer_norm_eps = match config.layer_norm_eps {
Some(value) => value, Some(value) => value,
None => 1e-12 None => 1e-12,
}; };
let layer_norm_config = nn::LayerNormConfig { eps: layer_norm_eps, ..Default::default() }; let layer_norm_config = nn::LayerNormConfig {
let layer_norm = nn::layer_norm(&(p / "LayerNorm"), vec![config.embedding_size], layer_norm_config); eps: layer_norm_eps,
let dense = nn::linear(&(p / "dense"), config.hidden_size, config.embedding_size, Default::default()); ..Default::default()
let decoder = nn::linear(&(p / "decoder"), config.embedding_size, config.vocab_size, Default::default()); };
let layer_norm = nn::layer_norm(
&(p / "LayerNorm"),
vec![config.embedding_size],
layer_norm_config,
);
let dense = nn::linear(
&(p / "dense"),
config.hidden_size,
config.embedding_size,
Default::default(),
);
let decoder = nn::linear(
&(p / "decoder"),
config.embedding_size,
config.vocab_size,
Default::default(),
);
let activation = Box::new(match &config.hidden_act { let activation = Box::new(match &config.hidden_act {
Activation::gelu_new => _gelu_new, Activation::gelu_new => _gelu_new,
Activation::gelu => _gelu, Activation::gelu => _gelu,
Activation::relu => _relu, Activation::relu => _relu,
Activation::mish => _mish Activation::mish => _mish,
}); });
AlbertMLMHead { layer_norm, dense, decoder, activation } AlbertMLMHead {
layer_norm,
dense,
decoder,
activation,
}
} }
pub fn forward(&self, hidden_states: &Tensor) -> Tensor { pub fn forward(&self, hidden_states: &Tensor) -> Tensor {
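The hunk cuts away the body of this forward. Judging from the fields wired up in new(), the standard ALBERT head order is assumed below; a sketch of the presumed body, not the verbatim code:

    let output = (self.activation)(&hidden_states.apply(&self.dense)); // hidden_size -> embedding_size
    output.apply(&self.layer_norm).apply(&self.decoder)                // LayerNorm, then embedding_size -> vocab_size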
@ -292,10 +357,10 @@ impl AlbertForMaskedLM {
/// # Example /// # Example
/// ///
/// ```no_run /// ```no_run
/// use tch::{nn, Device}; /// use rust_bert::albert::{AlbertConfig, AlbertForMaskedLM};
/// use rust_bert::Config; /// use rust_bert::Config;
/// use std::path::Path; /// use std::path::Path;
/// use rust_bert::albert::{AlbertConfig, AlbertForMaskedLM}; /// use tch::{nn, Device};
/// ///
/// let config_path = Path::new("path/to/config.json"); /// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu; /// let device = Device::Cpu;
@ -303,12 +368,14 @@ impl AlbertForMaskedLM {
/// let config = AlbertConfig::from_file(config_path); /// let config = AlbertConfig::from_file(config_path);
/// let albert: AlbertForMaskedLM = AlbertForMaskedLM::new(&p.root(), &config); /// let albert: AlbertForMaskedLM = AlbertForMaskedLM::new(&p.root(), &config);
/// ``` /// ```
///
pub fn new(p: &nn::Path, config: &AlbertConfig) -> AlbertForMaskedLM { pub fn new(p: &nn::Path, config: &AlbertConfig) -> AlbertForMaskedLM {
let albert = AlbertModel::new(&(p / "albert"), config); let albert = AlbertModel::new(&(p / "albert"), config);
let predictions = AlbertMLMHead::new(&(p / "predictions"), config); let predictions = AlbertMLMHead::new(&(p / "predictions"), config);
AlbertForMaskedLM { albert, predictions } AlbertForMaskedLM {
albert,
predictions,
}
} }
/// Forward pass through the model /// Forward pass through the model
@ -331,42 +398,54 @@ impl AlbertForMaskedLM {
/// # Example /// # Example
/// ///
/// ```no_run /// ```no_run
///# use tch::{nn, Device, Tensor, no_grad}; /// # use tch::{nn, Device, Tensor, no_grad};
///# use rust_bert::Config; /// # use rust_bert::Config;
///# use std::path::Path; /// # use std::path::Path;
///# use tch::kind::Kind::Int64; /// # use tch::kind::Kind::Int64;
/// use rust_bert::albert::{AlbertConfig, AlbertForMaskedLM}; /// use rust_bert::albert::{AlbertConfig, AlbertForMaskedLM};
///# let config_path = Path::new("path/to/config.json"); /// # let config_path = Path::new("path/to/config.json");
///# let device = Device::Cpu; /// # let device = Device::Cpu;
///# let vs = nn::VarStore::new(device); /// # let vs = nn::VarStore::new(device);
///# let config = AlbertConfig::from_file(config_path); /// # let config = AlbertConfig::from_file(config_path);
///# let albert_model: AlbertForMaskedLM = AlbertForMaskedLM::new(&vs.root(), &config); /// # let albert_model: AlbertForMaskedLM = AlbertForMaskedLM::new(&vs.root(), &config);
/// let (batch_size, sequence_length) = (64, 128); /// let (batch_size, sequence_length) = (64, 128);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device)); /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device)); /// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device)); /// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true); /// let position_ids = Tensor::arange(sequence_length, (Int64, device))
/// /// .expand(&[batch_size, sequence_length], true);
/// let (output, all_hidden_states, all_attentions) = no_grad(|| {
/// albert_model
/// .forward_t(Some(input_tensor),
/// Some(mask),
/// Some(token_type_ids),
/// Some(position_ids),
/// None,
/// false)
/// });
/// ///
/// let (output, all_hidden_states, all_attentions) = no_grad(|| {
/// albert_model.forward_t(
/// Some(input_tensor),
/// Some(mask),
/// Some(token_type_ids),
/// Some(position_ids),
/// None,
/// false,
/// )
/// });
/// ``` /// ```
/// pub fn forward_t(
pub fn forward_t(&self, &self,
input_ids: Option<Tensor>, input_ids: Option<Tensor>,
mask: Option<Tensor>, mask: Option<Tensor>,
token_type_ids: Option<Tensor>, token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>, position_ids: Option<Tensor>,
input_embeds: Option<Tensor>, input_embeds: Option<Tensor>,
train: bool) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Vec<Tensor>>>) { train: bool,
let (hidden_state, _, all_hidden_states, all_attentions) = self.albert.forward_t(input_ids, mask, token_type_ids, position_ids, input_embeds, train).unwrap(); ) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Vec<Tensor>>>) {
let (hidden_state, _, all_hidden_states, all_attentions) = self
.albert
.forward_t(
input_ids,
mask,
token_type_ids,
position_ids,
input_embeds,
train,
)
.unwrap();
let prediction_scores = self.predictions.forward(&hidden_state); let prediction_scores = self.predictions.forward(&hidden_state);
(prediction_scores, all_hidden_states, all_attentions) (prediction_scores, all_hidden_states, all_attentions)
} }
@ -395,29 +474,42 @@ impl AlbertForSequenceClassification {
/// # Example /// # Example
/// ///
/// ```no_run /// ```no_run
/// use tch::{nn, Device}; /// use rust_bert::albert::{AlbertConfig, AlbertForSequenceClassification};
/// use rust_bert::Config; /// use rust_bert::Config;
/// use std::path::Path; /// use std::path::Path;
/// use rust_bert::albert::{AlbertConfig, AlbertForSequenceClassification}; /// use tch::{nn, Device};
/// ///
/// let config_path = Path::new("path/to/config.json"); /// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu; /// let device = Device::Cpu;
/// let p = nn::VarStore::new(device); /// let p = nn::VarStore::new(device);
/// let config = AlbertConfig::from_file(config_path); /// let config = AlbertConfig::from_file(config_path);
/// let albert: AlbertForSequenceClassification = AlbertForSequenceClassification::new(&p.root(), &config); /// let albert: AlbertForSequenceClassification =
/// AlbertForSequenceClassification::new(&p.root(), &config);
/// ``` /// ```
///
pub fn new(p: &nn::Path, config: &AlbertConfig) -> AlbertForSequenceClassification { pub fn new(p: &nn::Path, config: &AlbertConfig) -> AlbertForSequenceClassification {
let albert = AlbertModel::new(&(p / "albert"), config); let albert = AlbertModel::new(&(p / "albert"), config);
let classifier_dropout_prob = match config.classifier_dropout_prob { let classifier_dropout_prob = match config.classifier_dropout_prob {
Some(value) => value, Some(value) => value,
None => 0.1 None => 0.1,
}; };
let dropout = Dropout::new(classifier_dropout_prob); let dropout = Dropout::new(classifier_dropout_prob);
let num_labels = config.id2label.as_ref().expect("num_labels not provided in configuration").len() as i64; let num_labels = config
let classifier = nn::linear(&(p / "classifier"), config.hidden_size, num_labels, Default::default()); .id2label
.as_ref()
.expect("num_labels not provided in configuration")
.len() as i64;
let classifier = nn::linear(
&(p / "classifier"),
config.hidden_size,
num_labels,
Default::default(),
);
AlbertForSequenceClassification { albert, dropout, classifier } AlbertForSequenceClassification {
albert,
dropout,
classifier,
}
} }
/// Forward pass through the model /// Forward pass through the model
@ -440,16 +532,16 @@ impl AlbertForSequenceClassification {
/// # Example /// # Example
/// ///
/// ```no_run /// ```no_run
///# use tch::{nn, Device, Tensor, no_grad}; /// # use tch::{nn, Device, Tensor, no_grad};
///# use rust_bert::Config; /// # use rust_bert::Config;
///# use std::path::Path; /// # use std::path::Path;
///# use tch::kind::Kind::Int64; /// # use tch::kind::Kind::Int64;
/// use rust_bert::albert::{AlbertConfig, AlbertForSequenceClassification}; /// use rust_bert::albert::{AlbertConfig, AlbertForSequenceClassification};
///# let config_path = Path::new("path/to/config.json"); /// # let config_path = Path::new("path/to/config.json");
///# let device = Device::Cpu; /// # let device = Device::Cpu;
///# let vs = nn::VarStore::new(device); /// # let vs = nn::VarStore::new(device);
///# let config = AlbertConfig::from_file(config_path); /// # let config = AlbertConfig::from_file(config_path);
///# let albert_model: AlbertForSequenceClassification = AlbertForSequenceClassification::new(&vs.root(), &config); /// # let albert_model: AlbertForSequenceClassification = AlbertForSequenceClassification::new(&vs.root(), &config);
/// let (batch_size, sequence_length) = (64, 128); /// let (batch_size, sequence_length) = (64, 128);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device)); /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device)); /// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
@ -465,18 +557,30 @@ impl AlbertForSequenceClassification {
/// None, /// None,
/// false) /// false)
/// }); /// });
///
/// ``` /// ```
/// pub fn forward_t(
pub fn forward_t(&self, &self,
input_ids: Option<Tensor>, input_ids: Option<Tensor>,
mask: Option<Tensor>, mask: Option<Tensor>,
token_type_ids: Option<Tensor>, token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>, position_ids: Option<Tensor>,
input_embeds: Option<Tensor>, input_embeds: Option<Tensor>,
train: bool) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Vec<Tensor>>>) { train: bool,
let (_, pooled_output, all_hidden_states, all_attentions) = self.albert.forward_t(input_ids, mask, token_type_ids, position_ids, input_embeds, train).unwrap(); ) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Vec<Tensor>>>) {
let logits = pooled_output.apply_t(&self.dropout, train).apply(&self.classifier); let (_, pooled_output, all_hidden_states, all_attentions) = self
.albert
.forward_t(
input_ids,
mask,
token_type_ids,
position_ids,
input_embeds,
train,
)
.unwrap();
let logits = pooled_output
.apply_t(&self.dropout, train)
.apply(&self.classifier);
(logits, all_hidden_states, all_attentions) (logits, all_hidden_states, all_attentions)
} }
} }
@ -505,25 +609,38 @@ impl AlbertForTokenClassification {
/// # Example /// # Example
/// ///
/// ```no_run /// ```no_run
/// use tch::{nn, Device}; /// use rust_bert::albert::{AlbertConfig, AlbertForTokenClassification};
/// use rust_bert::Config; /// use rust_bert::Config;
/// use std::path::Path; /// use std::path::Path;
/// use rust_bert::albert::{AlbertConfig, AlbertForTokenClassification}; /// use tch::{nn, Device};
/// ///
/// let config_path = Path::new("path/to/config.json"); /// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu; /// let device = Device::Cpu;
/// let p = nn::VarStore::new(device); /// let p = nn::VarStore::new(device);
/// let config = AlbertConfig::from_file(config_path); /// let config = AlbertConfig::from_file(config_path);
/// let albert: AlbertForTokenClassification = AlbertForTokenClassification::new(&p.root(), &config); /// let albert: AlbertForTokenClassification =
/// AlbertForTokenClassification::new(&p.root(), &config);
/// ``` /// ```
///
pub fn new(p: &nn::Path, config: &AlbertConfig) -> AlbertForTokenClassification { pub fn new(p: &nn::Path, config: &AlbertConfig) -> AlbertForTokenClassification {
let albert = AlbertModel::new(&(p / "albert"), config); let albert = AlbertModel::new(&(p / "albert"), config);
let dropout = Dropout::new(config.hidden_dropout_prob); let dropout = Dropout::new(config.hidden_dropout_prob);
let num_labels = config.id2label.as_ref().expect("num_labels not provided in configuration").len() as i64; let num_labels = config
let classifier = nn::linear(&(p / "classifier"), config.hidden_size, num_labels, Default::default()); .id2label
.as_ref()
.expect("num_labels not provided in configuration")
.len() as i64;
let classifier = nn::linear(
&(p / "classifier"),
config.hidden_size,
num_labels,
Default::default(),
);
AlbertForTokenClassification { albert, dropout, classifier } AlbertForTokenClassification {
albert,
dropout,
classifier,
}
} }
/// Forward pass through the model /// Forward pass through the model
@ -546,16 +663,16 @@ impl AlbertForTokenClassification {
/// # Example /// # Example
/// ///
/// ```no_run /// ```no_run
///# use tch::{nn, Device, Tensor, no_grad}; /// # use tch::{nn, Device, Tensor, no_grad};
///# use rust_bert::Config; /// # use rust_bert::Config;
///# use std::path::Path; /// # use std::path::Path;
///# use tch::kind::Kind::Int64; /// # use tch::kind::Kind::Int64;
/// use rust_bert::albert::{AlbertConfig, AlbertForTokenClassification}; /// use rust_bert::albert::{AlbertConfig, AlbertForTokenClassification};
///# let config_path = Path::new("path/to/config.json"); /// # let config_path = Path::new("path/to/config.json");
///# let device = Device::Cpu; /// # let device = Device::Cpu;
///# let vs = nn::VarStore::new(device); /// # let vs = nn::VarStore::new(device);
///# let config = AlbertConfig::from_file(config_path); /// # let config = AlbertConfig::from_file(config_path);
///# let albert_model: AlbertForTokenClassification = AlbertForTokenClassification::new(&vs.root(), &config); /// # let albert_model: AlbertForTokenClassification = AlbertForTokenClassification::new(&vs.root(), &config);
/// let (batch_size, sequence_length) = (64, 128); /// let (batch_size, sequence_length) = (64, 128);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device)); /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device)); /// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
@ -571,18 +688,30 @@ impl AlbertForTokenClassification {
/// None, /// None,
/// false) /// false)
/// }); /// });
///
/// ``` /// ```
/// pub fn forward_t(
pub fn forward_t(&self, &self,
input_ids: Option<Tensor>, input_ids: Option<Tensor>,
mask: Option<Tensor>, mask: Option<Tensor>,
token_type_ids: Option<Tensor>, token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>, position_ids: Option<Tensor>,
input_embeds: Option<Tensor>, input_embeds: Option<Tensor>,
train: bool) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Vec<Tensor>>>) { train: bool,
let (sequence_output, _, all_hidden_states, all_attentions) = self.albert.forward_t(input_ids, mask, token_type_ids, position_ids, input_embeds, train).unwrap(); ) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Vec<Tensor>>>) {
let logits = sequence_output.apply_t(&self.dropout, train).apply(&self.classifier); let (sequence_output, _, all_hidden_states, all_attentions) = self
.albert
.forward_t(
input_ids,
mask,
token_type_ids,
position_ids,
input_embeds,
train,
)
.unwrap();
let logits = sequence_output
.apply_t(&self.dropout, train)
.apply(&self.classifier);
(logits, all_hidden_states, all_attentions) (logits, all_hidden_states, all_attentions)
} }
} }
@ -610,10 +739,10 @@ impl AlbertForQuestionAnswering {
/// # Example /// # Example
/// ///
/// ```no_run /// ```no_run
/// use tch::{nn, Device}; /// use rust_bert::albert::{AlbertConfig, AlbertForQuestionAnswering};
/// use rust_bert::Config; /// use rust_bert::Config;
/// use std::path::Path; /// use std::path::Path;
/// use rust_bert::albert::{AlbertConfig, AlbertForQuestionAnswering}; /// use tch::{nn, Device};
/// ///
/// let config_path = Path::new("path/to/config.json"); /// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu; /// let device = Device::Cpu;
@ -621,11 +750,15 @@ impl AlbertForQuestionAnswering {
/// let config = AlbertConfig::from_file(config_path); /// let config = AlbertConfig::from_file(config_path);
/// let albert: AlbertForQuestionAnswering = AlbertForQuestionAnswering::new(&p.root(), &config); /// let albert: AlbertForQuestionAnswering = AlbertForQuestionAnswering::new(&p.root(), &config);
/// ``` /// ```
///
pub fn new(p: &nn::Path, config: &AlbertConfig) -> AlbertForQuestionAnswering { pub fn new(p: &nn::Path, config: &AlbertConfig) -> AlbertForQuestionAnswering {
let albert = AlbertModel::new(&(p / "albert"), config); let albert = AlbertModel::new(&(p / "albert"), config);
let num_labels = 2; let num_labels = 2;
let qa_outputs = nn::linear(&(p / "qa_outputs"), config.hidden_size, num_labels, Default::default()); let qa_outputs = nn::linear(
&(p / "qa_outputs"),
config.hidden_size,
num_labels,
Default::default(),
);
AlbertForQuestionAnswering { albert, qa_outputs } AlbertForQuestionAnswering { albert, qa_outputs }
} }
@ -651,16 +784,16 @@ impl AlbertForQuestionAnswering {
/// # Example /// # Example
/// ///
/// ```no_run /// ```no_run
///# use tch::{nn, Device, Tensor, no_grad}; /// # use tch::{nn, Device, Tensor, no_grad};
///# use rust_bert::Config; /// # use rust_bert::Config;
///# use std::path::Path; /// # use std::path::Path;
///# use tch::kind::Kind::Int64; /// # use tch::kind::Kind::Int64;
/// use rust_bert::albert::{AlbertConfig, AlbertForQuestionAnswering}; /// use rust_bert::albert::{AlbertConfig, AlbertForQuestionAnswering};
///# let config_path = Path::new("path/to/config.json"); /// # let config_path = Path::new("path/to/config.json");
///# let device = Device::Cpu; /// # let device = Device::Cpu;
///# let vs = nn::VarStore::new(device); /// # let vs = nn::VarStore::new(device);
///# let config = AlbertConfig::from_file(config_path); /// # let config = AlbertConfig::from_file(config_path);
///# let albert_model: AlbertForQuestionAnswering = AlbertForQuestionAnswering::new(&vs.root(), &config); /// # let albert_model: AlbertForQuestionAnswering = AlbertForQuestionAnswering::new(&vs.root(), &config);
/// let (batch_size, sequence_length) = (64, 128); /// let (batch_size, sequence_length) = (64, 128);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device)); /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device)); /// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
@ -676,17 +809,32 @@ impl AlbertForQuestionAnswering {
/// None, /// None,
/// false) /// false)
/// }); /// });
///
/// ``` /// ```
/// pub fn forward_t(
pub fn forward_t(&self, &self,
input_ids: Option<Tensor>, input_ids: Option<Tensor>,
mask: Option<Tensor>, mask: Option<Tensor>,
token_type_ids: Option<Tensor>, token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>, position_ids: Option<Tensor>,
input_embeds: Option<Tensor>, input_embeds: Option<Tensor>,
train: bool) -> (Tensor, Tensor, Option<Vec<Tensor>>, Option<Vec<Vec<Tensor>>>) { train: bool,
let (sequence_output, _, all_hidden_states, all_attentions) = self.albert.forward_t(input_ids, mask, token_type_ids, position_ids, input_embeds, train).unwrap(); ) -> (
Tensor,
Tensor,
Option<Vec<Tensor>>,
Option<Vec<Vec<Tensor>>>,
) {
let (sequence_output, _, all_hidden_states, all_attentions) = self
.albert
.forward_t(
input_ids,
mask,
token_type_ids,
position_ids,
input_embeds,
train,
)
.unwrap();
let logits = sequence_output.apply(&self.qa_outputs).split(1, -1); let logits = sequence_output.apply(&self.qa_outputs).split(1, -1);
let (start_logits, end_logits) = (&logits[0], &logits[1]); let (start_logits, end_logits) = (&logits[0], &logits[1]);
let start_logits = start_logits.squeeze1(-1); let start_logits = start_logits.squeeze1(-1);
@ -721,10 +869,10 @@ impl AlbertForMultipleChoice {
/// # Example /// # Example
/// ///
/// ```no_run /// ```no_run
/// use tch::{nn, Device}; /// use rust_bert::albert::{AlbertConfig, AlbertForMultipleChoice};
/// use rust_bert::Config; /// use rust_bert::Config;
/// use std::path::Path; /// use std::path::Path;
/// use rust_bert::albert::{AlbertConfig, AlbertForMultipleChoice}; /// use tch::{nn, Device};
/// ///
/// let config_path = Path::new("path/to/config.json"); /// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu; /// let device = Device::Cpu;
@ -732,14 +880,22 @@ impl AlbertForMultipleChoice {
/// let config = AlbertConfig::from_file(config_path); /// let config = AlbertConfig::from_file(config_path);
/// let albert: AlbertForMultipleChoice = AlbertForMultipleChoice::new(&p.root(), &config); /// let albert: AlbertForMultipleChoice = AlbertForMultipleChoice::new(&p.root(), &config);
/// ``` /// ```
///
pub fn new(p: &nn::Path, config: &AlbertConfig) -> AlbertForMultipleChoice { pub fn new(p: &nn::Path, config: &AlbertConfig) -> AlbertForMultipleChoice {
let albert = AlbertModel::new(&(p / "albert"), config); let albert = AlbertModel::new(&(p / "albert"), config);
let dropout = Dropout::new(config.hidden_dropout_prob); let dropout = Dropout::new(config.hidden_dropout_prob);
let num_labels = 1; let num_labels = 1;
let classifier = nn::linear(&(p / "classifier"), config.hidden_size, num_labels, Default::default()); let classifier = nn::linear(
&(p / "classifier"),
config.hidden_size,
num_labels,
Default::default(),
);
AlbertForMultipleChoice { albert, dropout, classifier } AlbertForMultipleChoice {
albert,
dropout,
classifier,
}
} }
/// Forward pass through the model /// Forward pass through the model
@ -762,16 +918,16 @@ impl AlbertForMultipleChoice {
/// # Example /// # Example
/// ///
/// ```no_run /// ```no_run
///# use tch::{nn, Device, Tensor, no_grad}; /// # use tch::{nn, Device, Tensor, no_grad};
///# use rust_bert::Config; /// # use rust_bert::Config;
///# use std::path::Path; /// # use std::path::Path;
///# use tch::kind::Kind::Int64; /// # use tch::kind::Kind::Int64;
/// use rust_bert::albert::{AlbertConfig, AlbertForMultipleChoice}; /// use rust_bert::albert::{AlbertConfig, AlbertForMultipleChoice};
///# let config_path = Path::new("path/to/config.json"); /// # let config_path = Path::new("path/to/config.json");
///# let device = Device::Cpu; /// # let device = Device::Cpu;
///# let vs = nn::VarStore::new(device); /// # let vs = nn::VarStore::new(device);
///# let config = AlbertConfig::from_file(config_path); /// # let config = AlbertConfig::from_file(config_path);
///# let albert_model: AlbertForMultipleChoice = AlbertForMultipleChoice::new(&vs.root(), &config); /// # let albert_model: AlbertForMultipleChoice = AlbertForMultipleChoice::new(&vs.root(), &config);
/// let (batch_size, sequence_length) = (64, 128); /// let (batch_size, sequence_length) = (64, 128);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device)); /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device)); /// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
@ -787,44 +943,68 @@ impl AlbertForMultipleChoice {
/// None, /// None,
/// false).unwrap() /// false).unwrap()
/// }); /// });
///
/// ``` /// ```
/// pub fn forward_t(
pub fn forward_t(&self, &self,
input_ids: Option<Tensor>, input_ids: Option<Tensor>,
mask: Option<Tensor>, mask: Option<Tensor>,
token_type_ids: Option<Tensor>, token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>, position_ids: Option<Tensor>,
input_embeds: Option<Tensor>, input_embeds: Option<Tensor>,
train: bool) -> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Vec<Tensor>>>), &'static str> { train: bool,
) -> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Vec<Tensor>>>), &'static str> {
let (input_ids, input_embeds, num_choices) = match &input_ids { let (input_ids, input_embeds, num_choices) = match &input_ids {
Some(input_value) => match &input_embeds { Some(input_value) => match &input_embeds {
Some(_) => { return Err("Only one of input ids or input embeddings may be set"); } Some(_) => {
None => (Some(input_value.view((-1, *input_value.size().last().unwrap()))), None, input_value.size()[1]) return Err("Only one of input ids or input embeddings may be set");
} }
None => (
Some(input_value.view((-1, *input_value.size().last().unwrap()))),
None,
input_value.size()[1],
),
},
None => match &input_embeds { None => match &input_embeds {
Some(embeds) => (None, Some(embeds.view((-1, embeds.size()[1], embeds.size()[2]))), embeds.size()[1]), Some(embeds) => (
None => { return Err("At least one of input ids or input embeddings must be set"); } None,
} Some(embeds.view((-1, embeds.size()[1], embeds.size()[2]))),
embeds.size()[1],
),
None => {
return Err("At least one of input ids or input embeddings must be set");
}
},
}; };
let mask = match mask { let mask = match mask {
Some(value) => Some(value.view((-1, *value.size().last().unwrap()))), Some(value) => Some(value.view((-1, *value.size().last().unwrap()))),
None => None None => None,
}; };
let token_type_ids = match token_type_ids { let token_type_ids = match token_type_ids {
Some(value) => Some(value.view((-1, *value.size().last().unwrap()))), Some(value) => Some(value.view((-1, *value.size().last().unwrap()))),
None => None None => None,
}; };
let position_ids = match position_ids { let position_ids = match position_ids {
Some(value) => Some(value.view((-1, *value.size().last().unwrap()))), Some(value) => Some(value.view((-1, *value.size().last().unwrap()))),
None => None None => None,
}; };
let (_, pooled_output, all_hidden_states, all_attentions) = self
let (_, pooled_output, all_hidden_states, all_attentions) = self.albert.forward_t(input_ids, mask, token_type_ids, position_ids, input_embeds, train).unwrap(); .albert
let logits = pooled_output.apply_t(&self.dropout, train).apply(&self.classifier).view((-1, num_choices)); .forward_t(
input_ids,
mask,
token_type_ids,
position_ids,
input_embeds,
train,
)
.unwrap();
let logits = pooled_output
.apply_t(&self.dropout, train)
.apply(&self.classifier)
.view((-1, num_choices));
Ok((logits, all_hidden_states, all_attentions)) Ok((logits, all_hidden_states, all_attentions))
} }
} }

View File

@ -11,10 +11,10 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use crate::common::dropout::Dropout;
use tch::{nn, Tensor};
use crate::albert::AlbertConfig; use crate::albert::AlbertConfig;
use crate::common::dropout::Dropout;
use tch::kind::Kind::Float; use tch::kind::Kind::Float;
use tch::{nn, Tensor};
#[derive(Debug)] #[derive(Debug)]
pub struct AlbertSelfAttention { pub struct AlbertSelfAttention {
@ -32,24 +32,55 @@ pub struct AlbertSelfAttention {
impl AlbertSelfAttention { impl AlbertSelfAttention {
pub fn new(p: nn::Path, config: &AlbertConfig) -> AlbertSelfAttention { pub fn new(p: nn::Path, config: &AlbertConfig) -> AlbertSelfAttention {
assert_eq!(config.hidden_size % config.num_attention_heads, 0, "Hidden size not a multiple of the number of attention heads"); assert_eq!(
config.hidden_size % config.num_attention_heads,
0,
"Hidden size not a multiple of the number of attention heads"
);
let query = nn::linear(&p / "query", config.hidden_size, config.hidden_size, Default::default()); let query = nn::linear(
let key = nn::linear(&p / "key", config.hidden_size, config.hidden_size, Default::default()); &p / "query",
let value = nn::linear(&p / "value", config.hidden_size, config.hidden_size, Default::default()); config.hidden_size,
let dense = nn::linear(&p / "dense", config.hidden_size, config.hidden_size, Default::default()); config.hidden_size,
Default::default(),
);
let key = nn::linear(
&p / "key",
config.hidden_size,
config.hidden_size,
Default::default(),
);
let value = nn::linear(
&p / "value",
config.hidden_size,
config.hidden_size,
Default::default(),
);
let dense = nn::linear(
&p / "dense",
config.hidden_size,
config.hidden_size,
Default::default(),
);
let dropout = Dropout::new(config.attention_probs_dropout_prob); let dropout = Dropout::new(config.attention_probs_dropout_prob);
let attention_head_size = config.hidden_size / config.num_attention_heads; let attention_head_size = config.hidden_size / config.num_attention_heads;
let output_attentions = match config.output_attentions { let output_attentions = match config.output_attentions {
Some(value) => value, Some(value) => value,
None => false None => false,
}; };
let layer_norm_eps = match config.layer_norm_eps { let layer_norm_eps = match config.layer_norm_eps {
Some(value) => value, Some(value) => value,
None => 1e-12 None => 1e-12,
}; };
let layer_norm_config = nn::LayerNormConfig { eps: layer_norm_eps, ..Default::default() }; let layer_norm_config = nn::LayerNormConfig {
let layer_norm = nn::layer_norm(&p / "LayerNorm", vec![config.hidden_size], layer_norm_config); eps: layer_norm_eps,
..Default::default()
};
let layer_norm = nn::layer_norm(
&p / "LayerNorm",
vec![config.hidden_size],
layer_norm_config,
);
AlbertSelfAttention { AlbertSelfAttention {
num_attention_heads: config.num_attention_heads, num_attention_heads: config.num_attention_heads,
@ -66,19 +97,23 @@ impl AlbertSelfAttention {
} }
fn split_heads(&self, x: Tensor, bs: i64, dim_per_head: i64) -> Tensor { fn split_heads(&self, x: Tensor, bs: i64, dim_per_head: i64) -> Tensor {
x.view((bs, -1, self.num_attention_heads, dim_per_head)).transpose(1, 2) x.view((bs, -1, self.num_attention_heads, dim_per_head))
.transpose(1, 2)
} }
pub fn forward_t(&self, pub fn forward_t(
input_ids: &Tensor, &self,
mask: &Option<Tensor>, input_ids: &Tensor,
train: bool) -> (Tensor, Option<Tensor>) { mask: &Option<Tensor>,
train: bool,
) -> (Tensor, Option<Tensor>) {
let bs = *input_ids.size().first().unwrap(); let bs = *input_ids.size().first().unwrap();
let key_layer = self.split_heads(input_ids.apply(&self.key), bs, self.attention_head_size); let key_layer = self.split_heads(input_ids.apply(&self.key), bs, self.attention_head_size);
let value_layer = self.split_heads(input_ids.apply(&self.value), bs, self.attention_head_size); let value_layer =
let query_layer = self.split_heads(input_ids.apply(&self.query), bs, self.attention_head_size); self.split_heads(input_ids.apply(&self.value), bs, self.attention_head_size);
let query_layer =
self.split_heads(input_ids.apply(&self.query), bs, self.attention_head_size);
let query_layer: Tensor = query_layer / (self.attention_head_size as f64).sqrt(); let query_layer: Tensor = query_layer / (self.attention_head_size as f64).sqrt();
@ -91,9 +126,11 @@ impl AlbertSelfAttention {
let weights = scores.softmax(-1, Float).apply_t(&self.dropout, train); let weights = scores.softmax(-1, Float).apply_t(&self.dropout, train);
let context = weights.matmul(&value_layer).transpose(1, 2).contiguous(); let context = weights.matmul(&value_layer).transpose(1, 2).contiguous();
let w = self.dense.ws let w = self.dense.ws.transpose(0, 1).view((
.transpose(0, 1) self.num_attention_heads,
.view((self.num_attention_heads, self.attention_head_size, self.hidden_size)); self.attention_head_size,
self.hidden_size,
));
let context: Tensor = Tensor::einsum("bfnd,ndh->bfh", &[context, w]) + &self.dense.bs; let context: Tensor = Tensor::einsum("bfnd,ndh->bfh", &[context, w]) + &self.dense.bs;
let context = (input_ids + context.apply_t(&self.dropout, train)).apply(&self.layer_norm); let context = (input_ids + context.apply_t(&self.dropout, train)).apply(&self.layer_norm);
@ -104,4 +141,4 @@ impl AlbertSelfAttention {
(context, Some(weights)) (context, Some(weights))
} }
} }
} }

View File

@ -11,10 +11,10 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use tch::{nn, Tensor, Kind};
use crate::common::dropout::Dropout;
use crate::albert::AlbertConfig; use crate::albert::AlbertConfig;
use tch::nn::{EmbeddingConfig, embedding}; use crate::common::dropout::Dropout;
use tch::nn::{embedding, EmbeddingConfig};
use tch::{nn, Kind, Tensor};
/// # Embeddings implementation for Albert model /// # Embeddings implementation for Albert model
#[derive(Debug)] #[derive(Debug)]
@ -34,49 +34,77 @@ impl AlbertEmbeddings {
..Default::default() ..Default::default()
}; };
let word_embeddings: nn::Embedding = embedding(p / "word_embeddings", let word_embeddings: nn::Embedding = embedding(
config.vocab_size, p / "word_embeddings",
config.embedding_size, config.vocab_size,
embedding_config); config.embedding_size,
embedding_config,
);
let position_embeddings: nn::Embedding = embedding(p / "position_embeddings", let position_embeddings: nn::Embedding = embedding(
config.max_position_embeddings, p / "position_embeddings",
config.embedding_size, config.max_position_embeddings,
Default::default()); config.embedding_size,
Default::default(),
);
let token_type_embeddings: nn::Embedding = embedding(p / "token_type_embeddings", let token_type_embeddings: nn::Embedding = embedding(
config.type_vocab_size, p / "token_type_embeddings",
config.embedding_size, config.type_vocab_size,
Default::default()); config.embedding_size,
Default::default(),
);
let layer_norm_eps = match config.layer_norm_eps { let layer_norm_eps = match config.layer_norm_eps {
Some(value) => value, Some(value) => value,
None => 1e-12 None => 1e-12,
}; };
let layer_norm_config = nn::LayerNormConfig { eps: layer_norm_eps, ..Default::default() }; let layer_norm_config = nn::LayerNormConfig {
let layer_norm: nn::LayerNorm = nn::layer_norm(p / "LayerNorm", vec![config.embedding_size], layer_norm_config); eps: layer_norm_eps,
..Default::default()
};
let layer_norm: nn::LayerNorm = nn::layer_norm(
p / "LayerNorm",
vec![config.embedding_size],
layer_norm_config,
);
let dropout: Dropout = Dropout::new(config.hidden_dropout_prob); let dropout: Dropout = Dropout::new(config.hidden_dropout_prob);
AlbertEmbeddings { word_embeddings, position_embeddings, token_type_embeddings, layer_norm, dropout} AlbertEmbeddings {
word_embeddings,
position_embeddings,
token_type_embeddings,
layer_norm,
dropout,
}
} }
pub fn forward_t(&self, pub fn forward_t(
input_ids: Option<Tensor>, &self,
token_type_ids: Option<Tensor>, input_ids: Option<Tensor>,
position_ids: Option<Tensor>, token_type_ids: Option<Tensor>,
input_embeds: Option<Tensor>, position_ids: Option<Tensor>,
train: bool) -> Result<Tensor, &'static str> { input_embeds: Option<Tensor>,
train: bool,
) -> Result<Tensor, &'static str> {
let (input_embeddings, input_shape) = match input_ids { let (input_embeddings, input_shape) = match input_ids {
Some(input_value) => match input_embeds { Some(input_value) => match input_embeds {
Some(_) => { return Err("Only one of input ids or input embeddings may be set"); } Some(_) => {
None => (input_value.apply_t(&self.word_embeddings, train), input_value.size()) return Err("Only one of input ids or input embeddings may be set");
} }
None => (
input_value.apply_t(&self.word_embeddings, train),
input_value.size(),
),
},
None => match input_embeds { None => match input_embeds {
Some(embeds) => { Some(embeds) => {
let size = vec!(embeds.size()[0], embeds.size()[1]); let size = vec![embeds.size()[0], embeds.size()[1]];
(embeds, size) (embeds, size)
}, }
None => { return Err("Only one of input ids or input embeddings may be set"); } None => {
} return Err("Only one of input ids or input embeddings may be set");
}
},
}; };
let seq_length = input_embeddings.as_ref().size()[1].to_owned(); let seq_length = input_embeddings.as_ref().size()[1].to_owned();
@ -84,19 +112,22 @@ impl AlbertEmbeddings {
let position_ids = match position_ids { let position_ids = match position_ids {
Some(value) => value, Some(value) => value,
None => Tensor::arange(seq_length, (Kind::Int64, input_embeddings.device())) None => Tensor::arange(seq_length, (Kind::Int64, input_embeddings.device()))
.unsqueeze(0). .unsqueeze(0)
expand(&input_shape, true) .expand(&input_shape, true),
}; };
let token_type_ids = match token_type_ids { let token_type_ids = match token_type_ids {
Some(value) => value, Some(value) => value,
None => Tensor::zeros(&input_shape, (Kind::Int64, input_embeddings.device())) None => Tensor::zeros(&input_shape, (Kind::Int64, input_embeddings.device())),
}; };
let position_embeddings = position_ids.apply(&self.position_embeddings); let position_embeddings = position_ids.apply(&self.position_embeddings);
let token_type_embeddings = token_type_ids.apply(&self.token_type_embeddings); let token_type_embeddings = token_type_ids.apply(&self.token_type_embeddings);
let input_embeddings: Tensor = input_embeddings + position_embeddings + token_type_embeddings; let input_embeddings: Tensor =
Ok(input_embeddings.apply(&self.layer_norm).apply_t(&self.dropout, train)) input_embeddings + position_embeddings + token_type_embeddings;
Ok(input_embeddings
.apply(&self.layer_norm)
.apply_t(&self.dropout, train))
} }
} }

View File

@ -11,12 +11,12 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use crate::albert::attention::AlbertSelfAttention;
use tch::{nn, Tensor};
use crate::albert::AlbertConfig;
use crate::albert::albert::Activation; use crate::albert::albert::Activation;
use crate::common::activations::{_gelu_new, _gelu, _relu, _mish}; use crate::albert::attention::AlbertSelfAttention;
use crate::albert::AlbertConfig;
use crate::common::activations::{_gelu, _gelu_new, _mish, _relu};
use std::borrow::BorrowMut; use std::borrow::BorrowMut;
use tch::{nn, Tensor};
pub struct AlbertLayer { pub struct AlbertLayer {
attention: AlbertSelfAttention, attention: AlbertSelfAttention,
@ -32,29 +32,55 @@ impl AlbertLayer {
let layer_norm_eps = match config.layer_norm_eps { let layer_norm_eps = match config.layer_norm_eps {
Some(value) => value, Some(value) => value,
None => 1e-12 None => 1e-12,
}; };
let layer_norm_config = nn::LayerNormConfig { eps: layer_norm_eps, ..Default::default() }; let layer_norm_config = nn::LayerNormConfig {
let full_layer_layer_norm = nn::layer_norm(&(p / "full_layer_layer_norm"), vec![config.hidden_size], layer_norm_config); eps: layer_norm_eps,
..Default::default()
};
let full_layer_layer_norm = nn::layer_norm(
&(p / "full_layer_layer_norm"),
vec![config.hidden_size],
layer_norm_config,
);
let ffn = nn::linear(&(p / "ffn"), config.hidden_size, config.intermediate_size, Default::default()); let ffn = nn::linear(
let ffn_output = nn::linear(&(p / "ffn_output"), config.intermediate_size, config.hidden_size, Default::default()); &(p / "ffn"),
config.hidden_size,
config.intermediate_size,
Default::default(),
);
let ffn_output = nn::linear(
&(p / "ffn_output"),
config.intermediate_size,
config.hidden_size,
Default::default(),
);
let activation = Box::new(match &config.hidden_act { let activation = Box::new(match &config.hidden_act {
Activation::gelu_new => _gelu_new, Activation::gelu_new => _gelu_new,
Activation::gelu => _gelu, Activation::gelu => _gelu,
Activation::relu => _relu, Activation::relu => _relu,
Activation::mish => _mish Activation::mish => _mish,
}); });
AlbertLayer { attention, full_layer_layer_norm, ffn, ffn_output, activation } AlbertLayer {
attention,
full_layer_layer_norm,
ffn,
ffn_output,
activation,
}
} }
pub fn forward_t(&self, pub fn forward_t(
hidden_states: &Tensor, &self,
mask: &Option<Tensor>, hidden_states: &Tensor,
train: bool) -> (Tensor, Option<Tensor>) { mask: &Option<Tensor>,
let (attention_output, attention_weights) = self.attention.forward_t(hidden_states, mask, train); train: bool,
) -> (Tensor, Option<Tensor>) {
let (attention_output, attention_weights) =
self.attention.forward_t(hidden_states, mask, train);
let ffn_output = attention_output.apply(&self.ffn); let ffn_output = attention_output.apply(&self.ffn);
let ffn_output: Tensor = (self.activation)(&ffn_output); let ffn_output: Tensor = (self.activation)(&ffn_output);
let ffn_output = ffn_output.apply(&self.ffn_output); let ffn_output = ffn_output.apply(&self.ffn_output);
@ -76,29 +102,42 @@ impl AlbertLayerGroup {
let output_attentions = match config.output_attentions { let output_attentions = match config.output_attentions {
Some(value) => value, Some(value) => value,
None => false None => false,
}; };
let output_hidden_states = match config.output_hidden_states { let output_hidden_states = match config.output_hidden_states {
Some(value) => value, Some(value) => value,
None => false None => false,
}; };
let mut layers: Vec<AlbertLayer> = vec!(); let mut layers: Vec<AlbertLayer> = vec![];
for layer_index in 0..config.inner_group_num { for layer_index in 0..config.inner_group_num {
layers.push(AlbertLayer::new(&(p / layer_index), config)); layers.push(AlbertLayer::new(&(p / layer_index), config));
}; }
AlbertLayerGroup { output_hidden_states, output_attentions, layers } AlbertLayerGroup {
output_hidden_states,
output_attentions,
layers,
}
} }
pub fn forward_t(&self, pub fn forward_t(
hidden_states: &Tensor, &self,
mask: &Option<Tensor>, hidden_states: &Tensor,
train: bool) mask: &Option<Tensor>,
-> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) { train: bool,
let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states { Some(vec!()) } else { None }; ) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
let mut all_attentions: Option<Vec<Tensor>> = if self.output_attentions { Some(vec!()) } else { None }; let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states {
Some(vec![])
} else {
None
};
let mut all_attentions: Option<Vec<Tensor>> = if self.output_attentions {
Some(vec![])
} else {
None
};
let mut hidden_state = hidden_states.copy(); let mut hidden_state = hidden_states.copy();
let mut attention_weights: Option<Tensor>; let mut attention_weights: Option<Tensor>;
@ -117,9 +156,9 @@ impl AlbertLayerGroup {
attentions.push(attention_weights.as_ref().unwrap().copy()); attentions.push(attention_weights.as_ref().unwrap().copy());
}; };
} }
None => break None => break,
}; };
}; }
(hidden_state, all_hidden_states, all_attentions) (hidden_state, all_hidden_states, all_attentions)
} }
@ -140,20 +179,25 @@ impl AlbertTransformer {
let output_attentions = match config.output_attentions { let output_attentions = match config.output_attentions {
Some(value) => value, Some(value) => value,
None => false None => false,
}; };
let output_hidden_states = match config.output_hidden_states { let output_hidden_states = match config.output_hidden_states {
Some(value) => value, Some(value) => value,
None => false None => false,
}; };
let embedding_hidden_mapping_in = nn::linear(&(p / "embedding_hidden_mapping_in"), config.embedding_size, config.hidden_size, Default::default()); let embedding_hidden_mapping_in = nn::linear(
&(p / "embedding_hidden_mapping_in"),
config.embedding_size,
config.hidden_size,
Default::default(),
);
let mut layers: Vec<AlbertLayerGroup> = vec!(); let mut layers: Vec<AlbertLayerGroup> = vec![];
for layer_index in 0..config.inner_group_num { for layer_index in 0..config.inner_group_num {
layers.push(AlbertLayerGroup::new(&(p_layers / layer_index), config)); layers.push(AlbertLayerGroup::new(&(p_layers / layer_index), config));
}; }
AlbertTransformer { AlbertTransformer {
output_hidden_states, output_hidden_states,
@ -165,16 +209,24 @@ impl AlbertTransformer {
} }
} }
pub fn forward_t(&self, pub fn forward_t(
hidden_states: &Tensor, &self,
mask: Option<Tensor>, hidden_states: &Tensor,
train: bool) mask: Option<Tensor>,
-> (Tensor, Option<Vec<Tensor>>, Option<Vec<Vec<Tensor>>>) { train: bool,
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Vec<Tensor>>>) {
let mut hidden_state = hidden_states.apply(&self.embedding_hidden_mapping_in); let mut hidden_state = hidden_states.apply(&self.embedding_hidden_mapping_in);
let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states { Some(vec!()) } else { None }; let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states {
let mut all_attentions: Option<Vec<Vec<Tensor>>> = if self.output_attentions { Some(vec!()) } else { None }; Some(vec![])
} else {
None
};
let mut all_attentions: Option<Vec<Vec<Tensor>>> = if self.output_attentions {
Some(vec![])
} else {
None
};
for i in 0..self.num_hidden_layers { for i in 0..self.num_hidden_layers {
let group_idx = i / (self.num_hidden_layers / self.num_hidden_groups); let group_idx = i / (self.num_hidden_layers / self.num_hidden_groups);
@ -190,9 +242,8 @@ impl AlbertTransformer {
if let Some(attentions) = all_attentions.borrow_mut() { if let Some(attentions) = all_attentions.borrow_mut() {
attentions.push(attention_weights.unwrap()); attentions.push(attention_weights.unwrap());
}; };
}; }
(hidden_state, all_hidden_states, all_attentions) (hidden_state, all_hidden_states, all_attentions)
} }
} }

View File

@ -20,37 +20,46 @@
//! Pretrained models are available and can be downloaded using RemoteResources. //! Pretrained models are available and can be downloaded using RemoteResources.
//! //!
//! ```no_run //! ```no_run
//!# fn main() -> failure::Fallible<()> { //! # fn main() -> failure::Fallible<()> {
//!# //! #
//! use rust_tokenizers::AlbertTokenizer; //! use rust_tokenizers::AlbertTokenizer;
//! use tch::{nn, Device}; //! use tch::{nn, Device};
//!# use std::path::PathBuf; //! # use std::path::PathBuf;
//! use rust_bert::albert::{AlbertForMaskedLM, AlbertConfig}; //! use rust_bert::albert::{AlbertConfig, AlbertForMaskedLM};
//! use rust_bert::resources::{download_resource, LocalResource, Resource};
//! use rust_bert::Config; //! use rust_bert::Config;
//! use rust_bert::resources::{Resource, download_resource, LocalResource};
//! //!
//! let config_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/config.json")}); //! let config_resource = Resource::Local(LocalResource {
//! let vocab_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/vocab.txt")}); //! local_path: PathBuf::from("path/to/config.json"),
//! let weights_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/model.ot")}); //! });
//! let vocab_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/vocab.txt"),
//! });
//! let weights_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/model.ot"),
//! });
//! let config_path = download_resource(&config_resource)?; //! let config_path = download_resource(&config_resource)?;
//! let vocab_path = download_resource(&vocab_resource)?; //! let vocab_path = download_resource(&vocab_resource)?;
//! let weights_path = download_resource(&weights_resource)?; //! let weights_path = download_resource(&weights_resource)?;
//! let device = Device::cuda_if_available(); //! let device = Device::cuda_if_available();
//! let mut vs = nn::VarStore::new(device); //! let mut vs = nn::VarStore::new(device);
//! let tokenizer: AlbertTokenizer = AlbertTokenizer::from_file(vocab_path.to_str().unwrap(), true, true); //! let tokenizer: AlbertTokenizer =
//! AlbertTokenizer::from_file(vocab_path.to_str().unwrap(), true, true);
//! let config = AlbertConfig::from_file(config_path); //! let config = AlbertConfig::from_file(config_path);
//! let bert_model = AlbertForMaskedLM::new(&vs.root(), &config); //! let bert_model = AlbertForMaskedLM::new(&vs.root(), &config);
//! vs.load(weights_path)?; //! vs.load(weights_path)?;
//! //!
//!# Ok(()) //! # Ok(())
//!# } //! # }
//! ``` //! ```
mod albert;
mod encoder;
mod attention; mod attention;
mod embeddings; mod embeddings;
mod albert; mod encoder;
pub use albert::{AlbertConfig, AlbertModelResources, AlbertConfigResources, AlbertVocabResources, AlbertModel, AlbertForMaskedLM, AlbertForSequenceClassification, AlbertForTokenClassification, AlbertForQuestionAnswering, AlbertForMultipleChoice}; pub use albert::{
AlbertConfig, AlbertConfigResources, AlbertForMaskedLM, AlbertForMultipleChoice,
AlbertForQuestionAnswering, AlbertForSequenceClassification, AlbertForTokenClassification,
AlbertModel, AlbertModelResources, AlbertVocabResources,
};

View File

@ -12,8 +12,8 @@
// limitations under the License. // limitations under the License.
use crate::common::dropout::Dropout; use crate::common::dropout::Dropout;
use tch::{nn, Tensor};
use tch::kind::Kind::Float; use tch::kind::Kind::Float;
use tch::{nn, Tensor};
#[derive(Debug)] #[derive(Debug)]
/// # Cache for BART attention layers /// # Cache for BART attention layers
@ -31,7 +31,7 @@ impl Clone for LayerState {
fn clone(&self) -> Self { fn clone(&self) -> Self {
let prev_key_padding_mask = match &self.prev_key_padding_mask { let prev_key_padding_mask = match &self.prev_key_padding_mask {
Some(key_padding_mask) => Some(key_padding_mask.copy()), Some(key_padding_mask) => Some(key_padding_mask.copy()),
None => None None => None,
}; };
LayerState { LayerState {
prev_key: self.prev_key.copy(), prev_key: self.prev_key.copy(),
@ -46,12 +46,16 @@ impl LayerState {
self.prev_key = self.prev_key.index_select(0, new_indices); self.prev_key = self.prev_key.index_select(0, new_indices);
self.prev_value = self.prev_value.index_select(0, new_indices); self.prev_value = self.prev_value.index_select(0, new_indices);
if self.prev_key_padding_mask.is_some() { if self.prev_key_padding_mask.is_some() {
self.prev_key_padding_mask = Some(self.prev_key_padding_mask.as_ref().unwrap().index_select(0, new_indices)); self.prev_key_padding_mask = Some(
self.prev_key_padding_mask
.as_ref()
.unwrap()
.index_select(0, new_indices),
);
} }
} }
} }
#[derive(Debug)] #[derive(Debug)]
pub struct SelfAttention { pub struct SelfAttention {
num_heads: i64, num_heads: i64,
@ -68,8 +72,15 @@ pub struct SelfAttention {
} }
impl SelfAttention { impl SelfAttention {
pub fn new(p: nn::Path, embed_dim: i64, num_heads: i64, dropout: f64, pub fn new(
encoder_decoder_attention: bool, store_cache: bool, output_attentions: bool) -> SelfAttention { p: nn::Path,
embed_dim: i64,
num_heads: i64,
dropout: f64,
encoder_decoder_attention: bool,
store_cache: bool,
output_attentions: bool,
) -> SelfAttention {
let k_proj = nn::linear(&p / "k_proj", embed_dim, embed_dim, Default::default()); let k_proj = nn::linear(&p / "k_proj", embed_dim, embed_dim, Default::default());
let v_proj = nn::linear(&p / "v_proj", embed_dim, embed_dim, Default::default()); let v_proj = nn::linear(&p / "v_proj", embed_dim, embed_dim, Default::default());
let q_proj = nn::linear(&p / "q_proj", embed_dim, embed_dim, Default::default()); let q_proj = nn::linear(&p / "q_proj", embed_dim, embed_dim, Default::default());
@ -95,58 +106,90 @@ impl SelfAttention {
} }
fn flatten(&self, x: Tensor, dim_0: i64, bs: i64) -> Tensor { fn flatten(&self, x: Tensor, dim_0: i64, bs: i64) -> Tensor {
x.contiguous().view((dim_0, bs * self.num_heads, self.head_dim)).transpose(0, 1) x.contiguous()
.view((dim_0, bs * self.num_heads, self.head_dim))
.transpose(0, 1)
} }
pub fn forward_t(&self, query: &Tensor, pub fn forward_t(
key: Option<&Tensor>, &self,
key_padding_mask: Option<&Tensor>, query: &Tensor,
attention_mask: Option<&Tensor>, key: Option<&Tensor>,
mut layer_state: Option<LayerState>, key_padding_mask: Option<&Tensor>,
train: bool) -> (Tensor, Option<Tensor>, Option<LayerState>) { attention_mask: Option<&Tensor>,
mut layer_state: Option<LayerState>,
train: bool,
) -> (Tensor, Option<Tensor>, Option<LayerState>) {
let query_size = query.size(); let query_size = query.size();
let (target_sequence_length, bs) = (query_size[0], query_size[1]); let (target_sequence_length, bs) = (query_size[0], query_size[1]);
let q: Tensor = self.flatten(query.as_ref().apply(&self.q_proj) * self.scaling, target_sequence_length, bs); let q: Tensor = self.flatten(
query.as_ref().apply(&self.q_proj) * self.scaling,
target_sequence_length,
bs,
);
let key = match &layer_state { let key = match &layer_state {
Some(_) => { if self.encoder_decoder_attention { None } else { key } } Some(_) => {
None => key if self.encoder_decoder_attention {
None
} else {
key
}
}
None => key,
}; };
let (k, v) = if self.encoder_decoder_attention { let (k, v) = if self.encoder_decoder_attention {
match key { match key {
Some(key) => { Some(key) => (
(Some(self.flatten(key.apply(&self.k_proj), -1, bs)), Some(self.flatten(key.apply(&self.k_proj), -1, bs)),
Some(self.flatten(key.apply(&self.v_proj), -1, bs)) Some(self.flatten(key.apply(&self.v_proj), -1, bs)),
) ),
} None => (None, None),
None => (None, None)
} }
} else { } else {
(Some(self.flatten(query.apply(&self.k_proj), -1, bs)), (
Some(self.flatten(query.apply(&self.v_proj), -1, bs)) Some(self.flatten(query.apply(&self.k_proj), -1, bs)),
Some(self.flatten(query.apply(&self.v_proj), -1, bs)),
) )
}; };
let (k, v, key_padding_mask) = self.use_saved_state(&layer_state, k, v, key_padding_mask, bs); let (k, v, key_padding_mask) =
self.use_saved_state(&layer_state, k, v, key_padding_mask, bs);
let source_sequence_length = k.size()[1]; let source_sequence_length = k.size()[1];
let attention_weights = q.bmm(&k.transpose(1, 2)); let attention_weights = q.bmm(&k.transpose(1, 2));
let attention_weights = match attention_mask { let attention_weights = match attention_mask {
Some(mask) => { Some(mask) => {
let attention_weights = attention_weights.view((bs, self.num_heads, target_sequence_length, source_sequence_length)) + mask; let attention_weights = attention_weights.view((
attention_weights.view((bs * self.num_heads, target_sequence_length, source_sequence_length)) bs,
self.num_heads,
target_sequence_length,
source_sequence_length,
)) + mask;
attention_weights.view((
bs * self.num_heads,
target_sequence_length,
source_sequence_length,
))
} }
None => attention_weights None => attention_weights,
}; };
let attention_weights = match key_padding_mask.as_ref() { let attention_weights = match key_padding_mask.as_ref() {
Some(mask) => { Some(mask) => attention_weights
attention_weights .view((
.view((bs, self.num_heads, target_sequence_length, source_sequence_length)) bs,
.masked_fill(&mask.unsqueeze(1).unsqueeze(2), std::f64::NEG_INFINITY) self.num_heads,
.view((bs * self.num_heads, target_sequence_length, source_sequence_length)) target_sequence_length,
} source_sequence_length,
None => attention_weights ))
.masked_fill(&mask.unsqueeze(1).unsqueeze(2), std::f64::NEG_INFINITY)
.view((
bs * self.num_heads,
target_sequence_length,
source_sequence_length,
)),
None => attention_weights,
}; };
let attention_weights = attention_weights.softmax(-1, Float); let attention_weights = attention_weights.softmax(-1, Float);
@ -159,16 +202,25 @@ impl SelfAttention {
.apply(&self.out_proj); .apply(&self.out_proj);
let attention_weights = if self.output_attentions { let attention_weights = if self.output_attentions {
Some(attention_weights.view((bs, self.num_heads, target_sequence_length, source_sequence_length))) Some(attention_weights.view((
} else { None }; bs,
self.num_heads,
target_sequence_length,
source_sequence_length,
)))
} else {
None
};
if self.store_cache { if self.store_cache {
if layer_state.is_some() { if layer_state.is_some() {
layer_state.as_mut().unwrap().prev_key = k.view((bs, self.num_heads, -1, self.head_dim)); layer_state.as_mut().unwrap().prev_key =
layer_state.as_mut().unwrap().prev_value = v.view((bs, self.num_heads, -1, self.head_dim)); k.view((bs, self.num_heads, -1, self.head_dim));
layer_state.as_mut().unwrap().prev_value =
v.view((bs, self.num_heads, -1, self.head_dim));
layer_state.as_mut().unwrap().prev_key_padding_mask = match key_padding_mask { layer_state.as_mut().unwrap().prev_key_padding_mask = match key_padding_mask {
Some(tensor) => Some(tensor), Some(tensor) => Some(tensor),
None => None None => None,
}; };
} else { } else {
layer_state = Some(LayerState { layer_state = Some(LayerState {
@ -176,7 +228,7 @@ impl SelfAttention {
prev_value: v.view((bs, self.num_heads, -1, self.head_dim)), prev_value: v.view((bs, self.num_heads, -1, self.head_dim)),
prev_key_padding_mask: match key_padding_mask { prev_key_padding_mask: match key_padding_mask {
Some(tensor) => Some(tensor), Some(tensor) => Some(tensor),
None => None None => None,
}, },
}) })
}; };
@ -185,17 +237,23 @@ impl SelfAttention {
(output, attention_weights, layer_state) (output, attention_weights, layer_state)
} }
fn use_saved_state(&self, fn use_saved_state(
layer_state: &Option<LayerState>, &self,
k: Option<Tensor>, layer_state: &Option<LayerState>,
v: Option<Tensor>, k: Option<Tensor>,
key_padding_mask: Option<&Tensor>, v: Option<Tensor>,
bs: i64) key_padding_mask: Option<&Tensor>,
-> (Tensor, Tensor, Option<Tensor>) { bs: i64,
) -> (Tensor, Tensor, Option<Tensor>) {
match &layer_state { match &layer_state {
Some(prev_state) => { Some(prev_state) => {
let prev_key = prev_state.prev_key.view((bs * self.num_heads, -1, self.head_dim)); let prev_key = prev_state
let prev_value = prev_state.prev_value.view((bs * self.num_heads, -1, self.head_dim)); .prev_key
.view((bs * self.num_heads, -1, self.head_dim));
let prev_value =
prev_state
.prev_value
.view((bs * self.num_heads, -1, self.head_dim));
let k = if self.encoder_decoder_attention { let k = if self.encoder_decoder_attention {
prev_key prev_key
} else { } else {
@ -207,39 +265,54 @@ impl SelfAttention {
Tensor::cat(&[prev_value, v.unwrap()], 1) Tensor::cat(&[prev_value, v.unwrap()], 1)
}; };
let key_padding_mask = self.use_saved_key_padding_mask(key_padding_mask, let key_padding_mask = self.use_saved_key_padding_mask(
&prev_state.prev_key_padding_mask, key_padding_mask,
bs, &prev_state.prev_key_padding_mask,
k.size()[1]); bs,
k.size()[1],
);
(k, v, key_padding_mask) (k, v, key_padding_mask)
} }
None => { None => {
let key_padding_mask = match key_padding_mask { let key_padding_mask = match key_padding_mask {
Some(value) => Some(value.copy()), Some(value) => Some(value.copy()),
None => None None => None,
}; };
(k.unwrap(), v.unwrap(), key_padding_mask) (k.unwrap(), v.unwrap(), key_padding_mask)
} }
} }
} }
fn use_saved_key_padding_mask(&self, key_padding_mask: Option<&Tensor>, prev_key_padding_mask: &Option<Tensor>, fn use_saved_key_padding_mask(
bs: i64, sequence_length: i64) -> Option<Tensor> { &self,
key_padding_mask: Option<&Tensor>,
prev_key_padding_mask: &Option<Tensor>,
bs: i64,
sequence_length: i64,
) -> Option<Tensor> {
if prev_key_padding_mask.is_some() { if prev_key_padding_mask.is_some() {
if self.encoder_decoder_attention { if self.encoder_decoder_attention {
Some(prev_key_padding_mask.as_ref().unwrap().copy()) Some(prev_key_padding_mask.as_ref().unwrap().copy())
} else { } else {
Some(Tensor::cat(&[prev_key_padding_mask.as_ref().unwrap(), key_padding_mask.as_ref().unwrap()], 1)) Some(Tensor::cat(
&[
prev_key_padding_mask.as_ref().unwrap(),
key_padding_mask.as_ref().unwrap(),
],
1,
))
} }
} else { } else {
match key_padding_mask { match key_padding_mask {
Some(key_padding_mask) => { Some(key_padding_mask) => {
let filler = Tensor::zeros(&[bs, sequence_length - key_padding_mask.size()[1]], let filler = Tensor::zeros(
(key_padding_mask.kind(), key_padding_mask.device())); &[bs, sequence_length - key_padding_mask.size()[1]],
(key_padding_mask.kind(), key_padding_mask.device()),
);
Some(Tensor::cat(&[filler, key_padding_mask.copy()], 1)) Some(Tensor::cat(&[filler, key_padding_mask.copy()], 1))
} }
None => None None => None,
} }
} }
} }
} }

View File

@ -11,18 +11,18 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use crate::Config;
use tch::{Tensor, nn};
use tch::kind::Kind::{Int64, Float};
use crate::bart::encoder::BartEncoder;
use crate::bart::decoder::BartDecoder;
use tch::nn::{embedding, EmbeddingConfig};
use crate::bart::attention::LayerState; use crate::bart::attention::LayerState;
use std::borrow::BorrowMut; use crate::bart::decoder::BartDecoder;
use crate::bart::encoder::BartEncoder;
use crate::common::dropout::Dropout; use crate::common::dropout::Dropout;
use crate::pipelines::generation::{Cache, LMHeadModel}; use crate::pipelines::generation::{Cache, LMHeadModel};
use crate::Config;
use serde::{Deserialize, Serialize};
use std::borrow::BorrowMut;
use std::collections::HashMap;
use tch::kind::Kind::{Float, Int64};
use tch::nn::{embedding, EmbeddingConfig};
use tch::{nn, Tensor};
/// # BART Pretrained model weight files /// # BART Pretrained model weight files
pub struct BartModelResources; pub struct BartModelResources;
@ -38,38 +38,74 @@ pub struct BartMergesResources;
impl BartModelResources { impl BartModelResources {
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format. /// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
pub const BART: (&'static str, &'static str) = ("bart/model.ot", "https://cdn.huggingface.co/facebook/bart-large/rust_model.ot"); pub const BART: (&'static str, &'static str) = (
"bart/model.ot",
"https://cdn.huggingface.co/facebook/bart-large/rust_model.ot",
);
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format. /// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
pub const BART_CNN: (&'static str, &'static str) = ("bart-cnn/model.ot", "https://cdn.huggingface.co/facebook/bart-large-cnn/rust_model.ot"); pub const BART_CNN: (&'static str, &'static str) = (
"bart-cnn/model.ot",
"https://cdn.huggingface.co/facebook/bart-large-cnn/rust_model.ot",
);
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format. /// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
pub const BART_XSUM: (&'static str, &'static str) = ("bart-xsum/model.ot", "https://cdn.huggingface.co/facebook/bart-large-xsum/rust_model.ot"); pub const BART_XSUM: (&'static str, &'static str) = (
"bart-xsum/model.ot",
"https://cdn.huggingface.co/facebook/bart-large-xsum/rust_model.ot",
);
} }
impl BartConfigResources { impl BartConfigResources {
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format. /// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
pub const BART: (&'static str, &'static str) = ("bart/config.json", "https://cdn.huggingface.co/facebook/bart-large/config.json"); pub const BART: (&'static str, &'static str) = (
"bart/config.json",
"https://cdn.huggingface.co/facebook/bart-large/config.json",
);
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format. /// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
pub const BART_CNN: (&'static str, &'static str) = ("bart-cnn/config.json", "https://cdn.huggingface.co/facebook/bart-large-cnn/config.json"); pub const BART_CNN: (&'static str, &'static str) = (
"bart-cnn/config.json",
"https://cdn.huggingface.co/facebook/bart-large-cnn/config.json",
);
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format. /// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
pub const BART_XSUM: (&'static str, &'static str) = ("bart-xsum/config.json", "https://cdn.huggingface.co/facebook/bart-large-xsum/config.json"); pub const BART_XSUM: (&'static str, &'static str) = (
"bart-xsum/config.json",
"https://cdn.huggingface.co/facebook/bart-large-xsum/config.json",
);
} }
impl BartVocabResources { impl BartVocabResources {
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format. /// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
pub const BART: (&'static str, &'static str) = ("bart/vocab.txt", "https://cdn.huggingface.co/roberta-large-vocab.json"); pub const BART: (&'static str, &'static str) = (
"bart/vocab.txt",
"https://cdn.huggingface.co/roberta-large-vocab.json",
);
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format. /// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
pub const BART_CNN: (&'static str, &'static str) = ("bart-cnn/vocab.txt", "https://cdn.huggingface.co/roberta-large-vocab.json"); pub const BART_CNN: (&'static str, &'static str) = (
"bart-cnn/vocab.txt",
"https://cdn.huggingface.co/roberta-large-vocab.json",
);
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format. /// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
pub const BART_XSUM: (&'static str, &'static str) = ("bart-xsum/vocab.txt", "https://cdn.huggingface.co/roberta-large-vocab.json"); pub const BART_XSUM: (&'static str, &'static str) = (
"bart-xsum/vocab.txt",
"https://cdn.huggingface.co/roberta-large-vocab.json",
);
} }
impl BartMergesResources { impl BartMergesResources {
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format. /// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
pub const BART: (&'static str, &'static str) = ("bart/merges.txt", "https://cdn.huggingface.co/roberta-large-merges.txt"); pub const BART: (&'static str, &'static str) = (
"bart/merges.txt",
"https://cdn.huggingface.co/roberta-large-merges.txt",
);
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format. /// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
pub const BART_CNN: (&'static str, &'static str) = ("bart-cnn/merges.txt", "https://cdn.huggingface.co/roberta-large-merges.txt"); pub const BART_CNN: (&'static str, &'static str) = (
"bart-cnn/merges.txt",
"https://cdn.huggingface.co/roberta-large-merges.txt",
);
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format. /// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
pub const BART_XSUM: (&'static str, &'static str) = ("bart-xsum/merges.txt", "https://cdn.huggingface.co/roberta-large-merges.txt"); pub const BART_XSUM: (&'static str, &'static str) = (
"bart-xsum/merges.txt",
"https://cdn.huggingface.co/roberta-large-merges.txt",
);
} }
#[allow(non_camel_case_types)] #[allow(non_camel_case_types)]
@ -130,14 +166,15 @@ pub struct BartConfig {
impl Config<BartConfig> for BartConfig {} impl Config<BartConfig> for BartConfig {}
fn _prepare_bart_decoder_inputs(pad_token_id: i64, fn _prepare_bart_decoder_inputs(
input_ids: &Tensor, pad_token_id: i64,
decoder_input_ids: Option<&Tensor>, input_ids: &Tensor,
decoder_padding_mask: Option<&Tensor>) decoder_input_ids: Option<&Tensor>,
-> (Tensor, Option<Tensor>, Option<Tensor>) { decoder_padding_mask: Option<&Tensor>,
) -> (Tensor, Option<Tensor>, Option<Tensor>) {
let decoder_input_ids = match decoder_input_ids { let decoder_input_ids = match decoder_input_ids {
Some(value) => value.copy(), Some(value) => value.copy(),
None => _shift_tokens_right(input_ids, pad_token_id) None => _shift_tokens_right(input_ids, pad_token_id),
}; };
let decoder_padding_mask = match decoder_padding_mask { let decoder_padding_mask = match decoder_padding_mask {
@ -159,7 +196,6 @@ fn _prepare_bart_decoder_inputs(pad_token_id: i64,
(decoder_input_ids, decoder_padding_mask, Some(causal_mask)) (decoder_input_ids, decoder_padding_mask, Some(causal_mask))
} }
fn _shift_tokens_right(input_ids: &Tensor, pad_token_id: i64) -> Tensor { fn _shift_tokens_right(input_ids: &Tensor, pad_token_id: i64) -> Tensor {
let index_eos: Tensor = input_ids.ne(pad_token_id).sum1(&[-1], true, Int64) - 1; let index_eos: Tensor = input_ids.ne(pad_token_id).sum1(&[-1], true, Int64) - 1;
let output = input_ids.empty_like().to_kind(Int64); let output = input_ids.empty_like().to_kind(Int64);
@ -200,10 +236,10 @@ impl BartModel {
/// # Example /// # Example
/// ///
/// ```no_run /// ```no_run
/// use tch::{nn, Device}; /// use rust_bert::bart::{BartConfig, BartModel};
/// use rust_bert::Config; /// use rust_bert::Config;
/// use std::path::Path; /// use std::path::Path;
/// use rust_bert::bart::{BartConfig, BartModel}; /// use tch::{nn, Device};
/// ///
/// let config_path = Path::new("path/to/config.json"); /// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu; /// let device = Device::Cpu;
@ -212,22 +248,32 @@ impl BartModel {
/// let generation_mode = true; /// let generation_mode = true;
/// let bart: BartModel = BartModel::new(&(&p.root() / "bart"), &config, generation_mode); /// let bart: BartModel = BartModel::new(&(&p.root() / "bart"), &config, generation_mode);
/// ``` /// ```
///
pub fn new(p: &nn::Path, config: &BartConfig, generation_mode: bool) -> BartModel { pub fn new(p: &nn::Path, config: &BartConfig, generation_mode: bool) -> BartModel {
let pad_token_id = match config.pad_token_id { let pad_token_id = match config.pad_token_id {
Some(value) => value, Some(value) => value,
None => 1 None => 1,
}; };
let embedding_config = EmbeddingConfig { padding_idx: pad_token_id, ..Default::default() }; let embedding_config = EmbeddingConfig {
let embeddings: nn::Embedding = embedding(p / "shared", padding_idx: pad_token_id,
config.vocab_size, ..Default::default()
config.d_model, };
embedding_config); let embeddings: nn::Embedding = embedding(
p / "shared",
config.vocab_size,
config.d_model,
embedding_config,
);
let encoder = BartEncoder::new(p / "encoder", config); let encoder = BartEncoder::new(p / "encoder", config);
let decoder = BartDecoder::new(p / "decoder", config, generation_mode); let decoder = BartDecoder::new(p / "decoder", config, generation_mode);
BartModel { encoder, decoder, generation_mode, pad_token_id, embeddings } BartModel {
encoder,
decoder,
generation_mode,
pad_token_id,
embeddings,
}
} }
/// Forward pass through the model /// Forward pass through the model
@ -257,82 +303,116 @@ impl BartModel {
/// # Example /// # Example
/// ///
/// ```no_run /// ```no_run
///# use tch::{nn, Device, Tensor, no_grad}; /// # use tch::{nn, Device, Tensor, no_grad};
///# use rust_bert::Config; /// # use rust_bert::Config;
///# use std::path::Path; /// # use std::path::Path;
///# use tch::kind::Kind::{Int64, Double}; /// # use tch::kind::Kind::{Int64, Double};
/// use rust_bert::bart::{BartConfig, BartModel}; /// use rust_bert::bart::{BartConfig, BartModel};
///# let config_path = Path::new("path/to/config.json"); /// # let config_path = Path::new("path/to/config.json");
///# let vocab_path = Path::new("path/to/vocab.txt"); /// # let vocab_path = Path::new("path/to/vocab.txt");
///# let device = Device::Cpu; /// # let device = Device::Cpu;
///# let vs = nn::VarStore::new(device); /// # let vs = nn::VarStore::new(device);
///# let config = BartConfig::from_file(config_path); /// # let config = BartConfig::from_file(config_path);
///# let bart_model: BartModel = BartModel::new(&vs.root(), &config, false); /// # let bart_model: BartModel = BartModel::new(&vs.root(), &config, false);
/// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56); /// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56);
/// let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device)); /// let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device));
/// let target_tensor = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device)); /// let target_tensor = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device));
/// let encoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device)); /// let encoder_attention_mask =
/// let decoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device)); /// Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
/// /// let decoder_attention_mask =
/// let (decoder_output, encoder_hidden_states, decoder_cache, /// Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
/// all_encoder_hidden_states, all_encoder_attentions,
/// all_decoder_hidden_states, all_decoder_attentions) = no_grad(|| {
/// bart_model
/// .forward_t(Some(&input_tensor),
/// Some(&encoder_attention_mask),
/// Some(&target_tensor),
/// None,
/// Some(&decoder_attention_mask),
/// None,
/// false)
/// });
/// ///
/// let (
/// decoder_output,
/// encoder_hidden_states,
/// decoder_cache,
/// all_encoder_hidden_states,
/// all_encoder_attentions,
/// all_decoder_hidden_states,
/// all_decoder_attentions,
/// ) = no_grad(|| {
/// bart_model.forward_t(
/// Some(&input_tensor),
/// Some(&encoder_attention_mask),
/// Some(&target_tensor),
/// None,
/// Some(&decoder_attention_mask),
/// None,
/// false,
/// )
/// });
/// ``` /// ```
/// pub fn forward_t(
pub fn forward_t(&self, &self,
input_ids: Option<&Tensor>, input_ids: Option<&Tensor>,
attention_mask: Option<&Tensor>, attention_mask: Option<&Tensor>,
decoder_input_ids: Option<&Tensor>, decoder_input_ids: Option<&Tensor>,
encoder_outputs: Option<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>)>, encoder_outputs: Option<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>)>,
decoder_attention_mask: Option<&Tensor>, decoder_attention_mask: Option<&Tensor>,
layer_states: Option<Vec<(Option<LayerState>, Option<LayerState>)>>, layer_states: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
train: bool) -> train: bool,
(Tensor, Tensor, Option<Vec<(Option<LayerState>, Option<LayerState>)>>, ) -> (
Option<Vec<Tensor>>, Option<Vec<Tensor>>, Tensor,
Option<Vec<Tensor>>, Option<Vec<Tensor>>) { Tensor,
Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
) {
let (decoder_input_ids, decoder_padding_mask, causal_mask) = if self.generation_mode { let (decoder_input_ids, decoder_padding_mask, causal_mask) = if self.generation_mode {
(decoder_input_ids.unwrap().copy(), None, None) (decoder_input_ids.unwrap().copy(), None, None)
} else { } else {
assert!(input_ids.is_some(), "input_ids must be provided when not in generation mode"); assert!(
_prepare_bart_decoder_inputs(self.pad_token_id, input_ids.unwrap(), decoder_input_ids, decoder_attention_mask) input_ids.is_some(),
"input_ids must be provided when not in generation mode"
);
_prepare_bart_decoder_inputs(
self.pad_token_id,
input_ids.unwrap(),
decoder_input_ids,
decoder_attention_mask,
)
}; };
let (encoder_hidden_states, let (encoder_hidden_states, all_encoder_hidden_states, all_encoder_attentions) =
all_encoder_hidden_states, match encoder_outputs {
all_encoder_attentions) = match encoder_outputs { Some(value) => value,
Some(value) => value, None => {
None => { assert!(
assert!(input_ids.is_some(), "input_ids must be provided when encoder output is not pre-computed"); input_ids.is_some(),
self.encoder.forward_t(input_ids.unwrap(), attention_mask, &self.embeddings, train) "input_ids must be provided when encoder output is not pre-computed"
} );
}; self.encoder.forward_t(
input_ids.unwrap(),
attention_mask,
&self.embeddings,
train,
)
}
};
let (decoder_outputs, let (decoder_outputs, decoder_cache, all_decoder_hidden_states, all_decoder_attentions) =
decoder_cache, self.decoder.forward_t(
&decoder_input_ids,
&encoder_hidden_states,
attention_mask,
decoder_padding_mask.as_ref(),
causal_mask.as_ref(),
&self.embeddings,
layer_states,
train,
);
(
decoder_outputs,
encoder_hidden_states,
decoder_cache.1,
all_decoder_hidden_states, all_decoder_hidden_states,
all_decoder_attentions) = self.decoder.forward_t(&decoder_input_ids, all_decoder_attentions,
&encoder_hidden_states, all_encoder_hidden_states,
attention_mask, all_encoder_attentions,
decoder_padding_mask.as_ref(), )
causal_mask.as_ref(),
&self.embeddings,
layer_states,
train);
(decoder_outputs, encoder_hidden_states, decoder_cache.1,
all_decoder_hidden_states, all_decoder_attentions,
all_encoder_hidden_states, all_encoder_attentions)
} }
} }
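The `encoder_outputs` and `layer_states` arguments exist so generation can avoid recomputing the encoder and the decoder's past keys/values on every step. A hedged sketch of how two steps would be chained with a generation-mode model (fragment only; setup is elided, and `input_tensor`, `encoder_attention_mask`, `decoder_start` and `next_token` are placeholder tensors):

```rust
// Fragment only: assumes a BartModel built with generation_mode = true.
let (_out, encoder_hidden_states, cache, _, _, _, _) = bart_model.forward_t(
    Some(&input_tensor),           // encoder input (first step only)
    Some(&encoder_attention_mask),
    Some(&decoder_start),          // current decoder prefix
    None,                          // encoder output not yet computed
    None,
    None,                          // no cached layer states yet
    false,
);
// Subsequent steps reuse the encoder output and the decoder cache:
let (_out, _, _cache, _, _, _, _) = bart_model.forward_t(
    None,
    Some(&encoder_attention_mask),
    Some(&next_token),
    Some((encoder_hidden_states, None, None)), // pre-computed encoder states
    None,
    cache,                         // past keys/values from the previous step
    false,
);
```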
/// # BART Model for conditional generation /// # BART Model for conditional generation
@ -356,20 +436,24 @@ impl BartForConditionalGeneration {
/// # Example /// # Example
/// ///
/// ```no_run /// ```no_run
/// use tch::{nn, Device}; /// use rust_bert::bart::{BartConfig, BartForConditionalGeneration};
/// use rust_bert::Config; /// use rust_bert::Config;
/// use std::path::Path; /// use std::path::Path;
/// use rust_bert::bart::{BartConfig, BartForConditionalGeneration}; /// use tch::{nn, Device};
/// ///
/// let config_path = Path::new("path/to/config.json"); /// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu; /// let device = Device::Cpu;
/// let p = nn::VarStore::new(device); /// let p = nn::VarStore::new(device);
/// let config = BartConfig::from_file(config_path); /// let config = BartConfig::from_file(config_path);
/// let generation_mode = true; /// let generation_mode = true;
/// let bart: BartForConditionalGeneration = BartForConditionalGeneration::new(&(&p.root() / "bart"), &config, generation_mode); /// let bart: BartForConditionalGeneration =
/// BartForConditionalGeneration::new(&(&p.root() / "bart"), &config, generation_mode);
/// ``` /// ```
/// pub fn new(
pub fn new(p: &nn::Path, config: &BartConfig, generation_mode: bool) -> BartForConditionalGeneration { p: &nn::Path,
config: &BartConfig,
generation_mode: bool,
) -> BartForConditionalGeneration {
let base_model = BartModel::new(&(p / "model"), config, generation_mode); let base_model = BartModel::new(&(p / "model"), config, generation_mode);
BartForConditionalGeneration { base_model } BartForConditionalGeneration { base_model }
} }
@ -398,17 +482,17 @@ impl BartForConditionalGeneration {
/// # Example /// # Example
/// ///
/// ```no_run /// ```no_run
///# use tch::{nn, Device, Tensor, no_grad}; /// # use tch::{nn, Device, Tensor, no_grad};
///# use rust_bert::Config; /// # use rust_bert::Config;
///# use std::path::Path; /// # use std::path::Path;
///# use tch::kind::Kind::{Int64, Double}; /// # use tch::kind::Kind::{Int64, Double};
/// use rust_bert::bart::{BartConfig, BartForConditionalGeneration}; /// use rust_bert::bart::{BartConfig, BartForConditionalGeneration};
///# let config_path = Path::new("path/to/config.json"); /// # let config_path = Path::new("path/to/config.json");
///# let vocab_path = Path::new("path/to/vocab.txt"); /// # let vocab_path = Path::new("path/to/vocab.txt");
///# let device = Device::Cpu; /// # let device = Device::Cpu;
///# let vs = nn::VarStore::new(device); /// # let vs = nn::VarStore::new(device);
///# let config = BartConfig::from_file(config_path); /// # let config = BartConfig::from_file(config_path);
///# let bart_model: BartForConditionalGeneration = BartForConditionalGeneration::new(&vs.root(), &config, false); /// # let bart_model: BartForConditionalGeneration = BartForConditionalGeneration::new(&vs.root(), &config, false);
/// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56); /// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56);
/// let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device)); /// let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device));
/// let target_tensor = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device)); /// let target_tensor = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device));
@ -427,37 +511,64 @@ impl BartForConditionalGeneration {
/// None, /// None,
/// false) /// false)
/// }); /// });
///
/// ``` /// ```
/// pub fn forward_t(
pub fn forward_t(&self, &self,
input_ids: Option<&Tensor>, input_ids: Option<&Tensor>,
attention_mask: Option<&Tensor>, attention_mask: Option<&Tensor>,
encoder_outputs: Option<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>)>, encoder_outputs: Option<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>)>,
decoder_input_ids: Option<&Tensor>, decoder_input_ids: Option<&Tensor>,
decoder_attention_mask: Option<&Tensor>, decoder_attention_mask: Option<&Tensor>,
old_layer_states: Option<Vec<(Option<LayerState>, Option<LayerState>)>>, old_layer_states: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
train: bool) train: bool,
-> (Tensor, Tensor, Option<Vec<(Option<LayerState>, Option<LayerState>)>>, ) -> (
Option<Vec<Tensor>>, Option<Vec<Tensor>>, Tensor,
Option<Vec<Tensor>>, Option<Vec<Tensor>>) Tensor,
{ Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
let (decoder_outputs, encoder_hidden_states, decoder_cache, Option<Vec<Tensor>>,
all_decoder_hidden_states, all_decoder_attentions, Option<Vec<Tensor>>,
all_encoder_hidden_states, all_encoder_attentions) = Option<Vec<Tensor>>,
self.base_model.forward_t(input_ids, attention_mask, decoder_input_ids, encoder_outputs, decoder_attention_mask, old_layer_states, train); Option<Vec<Tensor>>,
) {
let (
decoder_outputs,
encoder_hidden_states,
decoder_cache,
all_decoder_hidden_states,
all_decoder_attentions,
all_encoder_hidden_states,
all_encoder_attentions,
) = self.base_model.forward_t(
input_ids,
attention_mask,
decoder_input_ids,
encoder_outputs,
decoder_attention_mask,
old_layer_states,
train,
);
let lm_logits = decoder_outputs.linear::<Tensor>(&self.base_model.embeddings.ws, None); let lm_logits = decoder_outputs.linear::<Tensor>(&self.base_model.embeddings.ws, None);
(lm_logits, encoder_hidden_states, decoder_cache, (
all_decoder_hidden_states, all_decoder_attentions, lm_logits,
all_encoder_hidden_states, all_encoder_attentions) encoder_hidden_states,
decoder_cache,
all_decoder_hidden_states,
all_decoder_attentions,
all_encoder_hidden_states,
all_encoder_attentions,
)
} }
pub fn encode(&self, input_ids: &Tensor, attention_mask: Option<&Tensor>) -> Tensor { pub fn encode(&self, input_ids: &Tensor, attention_mask: Option<&Tensor>) -> Tensor {
let (encoder_hidden_states, _, _) = self.base_model.encoder.forward_t(input_ids, attention_mask, &self.base_model.embeddings, false); let (encoder_hidden_states, _, _) = self.base_model.encoder.forward_t(
input_ids,
attention_mask,
&self.base_model.embeddings,
false,
);
encoder_hidden_states encoder_hidden_states
} }
} }
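`forward_t` above computes the LM logits by projecting the decoder output with the shared embedding matrix and no bias (weight tying), rather than through a separate output layer. A standalone shape check of that call, with made-up dimensions:

```rust
use tch::kind::Kind::Float;
use tch::{Device, Tensor};

fn main() {
    let (vocab_size, d_model, batch, seq) = (50i64, 8, 2, 3);
    let embedding_weights = Tensor::rand(&[vocab_size, d_model], (Float, Device::Cpu));
    let decoder_output = Tensor::rand(&[batch, seq, d_model], (Float, Device::Cpu));
    // Same projection as `decoder_outputs.linear::<Tensor>(&embeddings.ws, None)` above.
    let lm_logits = decoder_output.linear::<Tensor>(&embedding_weights, None);
    assert_eq!(lm_logits.size(), vec![batch, seq, vocab_size]);
}
```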
pub struct BartClassificationHead { pub struct BartClassificationHead {
@ -468,16 +579,29 @@ pub struct BartClassificationHead {
impl BartClassificationHead { impl BartClassificationHead {
pub fn new(p: &nn::Path, config: &BartConfig) -> BartClassificationHead { pub fn new(p: &nn::Path, config: &BartConfig) -> BartClassificationHead {
let dense = nn::linear(&(p / "dense"), config.d_model, config.d_model, Default::default()); let dense = nn::linear(
&(p / "dense"),
config.d_model,
config.d_model,
Default::default(),
);
let dropout = Dropout::new(config.classif_dropout); let dropout = Dropout::new(config.classif_dropout);
let out_proj = nn::linear(&(p / "out_proj"), config.d_model, config.num_labels.unwrap(), Default::default()); let out_proj = nn::linear(
&(p / "out_proj"),
config.d_model,
config.num_labels.unwrap(),
Default::default(),
);
BartClassificationHead { dense, dropout, out_proj } BartClassificationHead {
dense,
dropout,
out_proj,
}
} }
pub fn forward_t(&self, x: &Tensor, train: bool) -> Tensor { pub fn forward_t(&self, x: &Tensor, train: bool) -> Tensor {
x x.apply_t(&self.dropout, train)
.apply_t(&self.dropout, train)
.apply(&self.dense) .apply(&self.dense)
.tanh() .tanh()
.apply_t(&self.dropout, train) .apply_t(&self.dropout, train)
@ -497,7 +621,6 @@ pub struct BartForSequenceClassification {
eos_token_id: i64, eos_token_id: i64,
} }
impl BartForSequenceClassification { impl BartForSequenceClassification {
/// Build a new `BartForSequenceClassification` /// Build a new `BartForSequenceClassification`
/// ///
@ -509,27 +632,31 @@ impl BartForSequenceClassification {
/// # Example /// # Example
/// ///
/// ```no_run /// ```no_run
/// use tch::{nn, Device}; /// use rust_bert::bart::{BartConfig, BartForSequenceClassification};
/// use rust_bert::Config; /// use rust_bert::Config;
/// use std::path::Path; /// use std::path::Path;
/// use rust_bert::bart::{BartConfig, BartForSequenceClassification}; /// use tch::{nn, Device};
/// ///
/// let config_path = Path::new("path/to/config.json"); /// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu; /// let device = Device::Cpu;
/// let p = nn::VarStore::new(device); /// let p = nn::VarStore::new(device);
/// let config = BartConfig::from_file(config_path); /// let config = BartConfig::from_file(config_path);
/// let bart: BartForSequenceClassification = BartForSequenceClassification::new(&(&p.root() / "bart"), &config); /// let bart: BartForSequenceClassification =
/// BartForSequenceClassification::new(&(&p.root() / "bart"), &config);
/// ``` /// ```
///
pub fn new(p: &nn::Path, config: &BartConfig) -> BartForSequenceClassification { pub fn new(p: &nn::Path, config: &BartConfig) -> BartForSequenceClassification {
let base_model = BartModel::new(&(p / "model"), config, false); let base_model = BartModel::new(&(p / "model"), config, false);
let classification_head = BartClassificationHead::new(&(p / "classification_head"), config); let classification_head = BartClassificationHead::new(&(p / "classification_head"), config);
let eos_token_id = match config.eos_token_id { let eos_token_id = match config.eos_token_id {
Some(value) => value, Some(value) => value,
None => 3 None => 3,
}; };
BartForSequenceClassification { base_model, classification_head, eos_token_id } BartForSequenceClassification {
base_model,
classification_head,
eos_token_id,
}
} }
/// Forward pass through the model /// Forward pass through the model
@ -556,17 +683,17 @@ impl BartForSequenceClassification {
/// # Example /// # Example
/// ///
/// ```no_run /// ```no_run
///# use tch::{nn, Device, Tensor, no_grad}; /// # use tch::{nn, Device, Tensor, no_grad};
///# use rust_bert::Config; /// # use rust_bert::Config;
///# use std::path::Path; /// # use std::path::Path;
///# use tch::kind::Kind::{Int64, Double}; /// # use tch::kind::Kind::{Int64, Double};
/// use rust_bert::bart::{BartConfig, BartForSequenceClassification}; /// use rust_bert::bart::{BartConfig, BartForSequenceClassification};
///# let config_path = Path::new("path/to/config.json"); /// # let config_path = Path::new("path/to/config.json");
///# let vocab_path = Path::new("path/to/vocab.txt"); /// # let vocab_path = Path::new("path/to/vocab.txt");
///# let device = Device::Cpu; /// # let device = Device::Cpu;
///# let vs = nn::VarStore::new(device); /// # let vs = nn::VarStore::new(device);
///# let config = BartConfig::from_file(config_path); /// # let config = BartConfig::from_file(config_path);
///# let mut bart_model: BartForSequenceClassification = BartForSequenceClassification::new(&vs.root(), &config); /// # let mut bart_model: BartForSequenceClassification = BartForSequenceClassification::new(&vs.root(), &config);
/// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56); /// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56);
/// let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device)); /// let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device));
/// let target_tensor = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device)); /// let target_tensor = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device));
@ -585,36 +712,63 @@ impl BartForSequenceClassification {
/// None, /// None,
/// false) /// false)
/// }); /// });
///
/// ``` /// ```
/// pub fn forward_t(
pub fn forward_t(&mut self, &mut self,
input_ids: &Tensor, input_ids: &Tensor,
attention_mask: Option<&Tensor>, attention_mask: Option<&Tensor>,
encoder_outputs: Option<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>)>, encoder_outputs: Option<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>)>,
decoder_input_ids: Option<&Tensor>, decoder_input_ids: Option<&Tensor>,
decoder_attention_mask: Option<&Tensor>, decoder_attention_mask: Option<&Tensor>,
train: bool) train: bool,
-> (Tensor, Tensor, ) -> (
Option<Vec<Tensor>>, Option<Vec<Tensor>>, Tensor,
Option<Vec<Tensor>>, Option<Vec<Tensor>>) { Tensor,
let (decoder_outputs, encoder_hidden_states, _, Option<Vec<Tensor>>,
all_decoder_hidden_states, all_decoder_attentions, Option<Vec<Tensor>>,
all_encoder_hidden_states, all_encoder_attentions) = Option<Vec<Tensor>>,
self.borrow_mut().base_model.forward_t(Some(input_ids), attention_mask, decoder_input_ids, encoder_outputs, decoder_attention_mask, None, train); Option<Vec<Tensor>>,
) {
let (
decoder_outputs,
encoder_hidden_states,
_,
all_decoder_hidden_states,
all_decoder_attentions,
all_encoder_hidden_states,
all_encoder_attentions,
) = self.borrow_mut().base_model.forward_t(
Some(input_ids),
attention_mask,
decoder_input_ids,
encoder_outputs,
decoder_attention_mask,
None,
train,
);
let eos_mask = input_ids.eq(self.eos_token_id); let eos_mask = input_ids.eq(self.eos_token_id);
let sentence_representation = decoder_outputs let sentence_representation = decoder_outputs
.index_select(0, &eos_mask) .index_select(0, &eos_mask)
.view((decoder_outputs.size()[0], -1, *decoder_outputs.size().last().unwrap())) .view((
decoder_outputs.size()[0],
-1,
*decoder_outputs.size().last().unwrap(),
))
.select(1, -1); .select(1, -1);
let logits = self.classification_head.forward_t(&sentence_representation, train); let logits = self
(logits, encoder_hidden_states, .classification_head
all_decoder_hidden_states, all_decoder_attentions, .forward_t(&sentence_representation, train);
all_encoder_hidden_states, all_encoder_attentions) (
logits,
encoder_hidden_states,
all_decoder_hidden_states,
all_decoder_attentions,
all_encoder_hidden_states,
all_encoder_attentions,
)
} }
} }
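The classification head does not pool the whole sequence: it reads the decoder's hidden state at the final `</s>` (EOS) position of each sequence. A hedged, standalone illustration of that selection with made-up sizes (not the library's exact indexing):

```rust
use tch::kind::Kind::Float;
use tch::{Device, Tensor};

fn main() {
    let eos_token_id = 2i64;
    // One sequence of length 4: tokens [5, 6, 2, 1] with EOS (= 2) at position 2.
    let input_ids = Tensor::of_slice(&[5i64, 6, 2, 1]).view((1, 4));
    let hidden = Tensor::rand(&[1, 4, 8], (Float, Device::Cpu)); // [batch, seq, d_model]

    // nonzero() yields one (batch, seq) index row per EOS match; take the last one.
    let eos_positions = input_ids.eq(eos_token_id).nonzero();
    let last = eos_positions.size()[0] - 1;
    let (b, s) = (
        eos_positions.int64_value(&[last, 0]),
        eos_positions.int64_value(&[last, 1]),
    );
    let sentence_representation = hidden.select(0, b).select(0, s); // [d_model]
    assert_eq!(sentence_representation.size(), vec![8]);
}
```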
impl LMHeadModel for BartForConditionalGeneration { impl LMHeadModel for BartForConditionalGeneration {
@ -645,18 +799,18 @@ impl LMHeadModel for BartForConditionalGeneration {
/// # Example /// # Example
/// ///
/// ```no_run /// ```no_run
///# use tch::{nn, Device, Tensor, no_grad}; /// # use tch::{nn, Device, Tensor, no_grad};
///# use rust_bert::Config; /// # use rust_bert::Config;
///# use std::path::Path; /// # use std::path::Path;
///# use tch::kind::Kind::{Int64, Double}; /// # use tch::kind::Kind::{Int64, Double};
/// use rust_bert::pipelines::generation::LMHeadModel; /// use rust_bert::pipelines::generation::LMHeadModel;
/// use rust_bert::bart::{BartForConditionalGeneration, BartConfig}; /// use rust_bert::bart::{BartForConditionalGeneration, BartConfig};
///# let config_path = Path::new("path/to/config.json"); /// # let config_path = Path::new("path/to/config.json");
///# let vocab_path = Path::new("path/to/vocab.txt"); /// # let vocab_path = Path::new("path/to/vocab.txt");
///# let device = Device::Cpu; /// # let device = Device::Cpu;
///# let vs = nn::VarStore::new(device); /// # let vs = nn::VarStore::new(device);
///# let config = BartConfig::from_file(config_path); /// # let config = BartConfig::from_file(config_path);
///# let bart_model: BartForConditionalGeneration = BartForConditionalGeneration::new(&vs.root(), &config, false); /// # let bart_model: BartForConditionalGeneration = BartForConditionalGeneration::new(&vs.root(), &config, false);
/// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56); /// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56);
/// let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device)); /// let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device));
/// let target_tensor = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device)); /// let target_tensor = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device));
@ -675,39 +829,58 @@ impl LMHeadModel for BartForConditionalGeneration {
/// None, /// None,
/// false) /// false)
/// }); /// });
///
/// ``` /// ```
/// fn forward_t(
fn forward_t(&self, &self,
input_ids: &Option<Tensor>, input_ids: &Option<Tensor>,
cache: Cache, cache: Cache,
attention_mask: &Option<Tensor>, attention_mask: &Option<Tensor>,
_token_type_ids: &Option<Tensor>, _token_type_ids: &Option<Tensor>,
_position_ids: &Option<Tensor>, _position_ids: &Option<Tensor>,
_input_embeds: &Option<Tensor>, _input_embeds: &Option<Tensor>,
encoder_outputs: Option<&Tensor>, encoder_outputs: Option<&Tensor>,
decoder_input_ids: &Option<Tensor>, decoder_input_ids: &Option<Tensor>,
train: bool) -> Result<(Tensor, Option<Tensor>, Cache, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> { train: bool,
) -> Result<
(
Tensor,
Option<Tensor>,
Cache,
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
),
&'static str,
> {
let (decoder_output, encoder_hidden_states, new_cache, _, _, _, _) = match cache { let (decoder_output, encoder_hidden_states, new_cache, _, _, _, _) = match cache {
Cache::BARTCache(cached_layer_states) => self.base_model.forward_t(input_ids.as_ref(), Cache::BARTCache(cached_layer_states) => self.base_model.forward_t(
attention_mask.as_ref(), input_ids.as_ref(),
decoder_input_ids.as_ref(), attention_mask.as_ref(),
Some((encoder_outputs.as_ref().unwrap().copy(), None, None)), decoder_input_ids.as_ref(),
None, Some((encoder_outputs.as_ref().unwrap().copy(), None, None)),
cached_layer_states, None,
train), cached_layer_states,
train,
),
Cache::None => self.base_model.forward_t(input_ids.as_ref(), Cache::None => self.base_model.forward_t(
attention_mask.as_ref(), input_ids.as_ref(),
decoder_input_ids.as_ref(), attention_mask.as_ref(),
Some((encoder_outputs.as_ref().unwrap().copy(), None, None)), decoder_input_ids.as_ref(),
None, Some((encoder_outputs.as_ref().unwrap().copy(), None, None)),
None, None,
train), None,
_ => Err("Cache not compatible with BART Model")? train,
),
_ => Err("Cache not compatible with BART Model")?,
}; };
let lm_logits = decoder_output.linear::<Tensor>(&self.base_model.embeddings.ws, None); let lm_logits = decoder_output.linear::<Tensor>(&self.base_model.embeddings.ws, None);
Ok((lm_logits, Some(encoder_hidden_states), Cache::BARTCache(new_cache), None, None)) Ok((
lm_logits,
Some(encoder_hidden_states),
Cache::BARTCache(new_cache),
None,
None,
))
} }
} }
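Through the `LMHeadModel` trait, the generation pipeline threads the returned `Cache` back into the next call. A hedged sketch of one cached step (fragment only; model setup and the first, cache-less call are elided, and `encoder_output`, `prefix` and `cached_state` are placeholders from earlier steps). The trait method is called explicitly because `BartForConditionalGeneration` also has an inherent `forward_t`:

```rust
// Fragment only; assumes `use rust_bert::pipelines::generation::{Cache, LMHeadModel};`
// and a surrounding function returning Result<_, &'static str>.
let (lm_logits, _enc, new_cache, _, _) = LMHeadModel::forward_t(
    &bart_model,
    &None,                          // no new encoder input
    Cache::BARTCache(cached_state), // past keys/values from the previous step
    &None,
    &None,                          // token type ids: unused by BART
    &None,                          // position ids: unused by BART
    &None,                          // input embeddings: unused by BART
    Some(&encoder_output),          // pre-computed encoder hidden states
    &Some(prefix),                  // current decoder prefix
    false,
)?;
// `new_cache` comes back already wrapped as a Cache and is passed as-is next step.
```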
@ -11,15 +11,17 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use crate::bart::attention::{SelfAttention, LayerState}; use crate::bart::attention::{LayerState, SelfAttention};
use tch::{nn, Tensor};
use crate::common::dropout::Dropout;
use crate::bart::BartConfig;
use crate::bart::bart::Activation; use crate::bart::bart::Activation;
use crate::common::activations::{_gelu, _relu, _swish, _gelu_new, _tanh}; use crate::bart::embeddings::{
use tch::kind::Kind::Int64; EmbeddingOption, LearnedPositionalEmbedding, SinusoidalPositionalEmbedding,
};
use crate::bart::BartConfig;
use crate::common::activations::{_gelu, _gelu_new, _relu, _swish, _tanh};
use crate::common::dropout::Dropout;
use std::borrow::BorrowMut; use std::borrow::BorrowMut;
use crate::bart::embeddings::{EmbeddingOption, LearnedPositionalEmbedding, SinusoidalPositionalEmbedding}; use tch::kind::Kind::Int64;
use tch::{nn, Tensor};
pub struct DecoderLayer { pub struct DecoderLayer {
self_attention: SelfAttention, self_attention: SelfAttention,
@ -36,51 +38,74 @@ pub struct DecoderLayer {
impl DecoderLayer { impl DecoderLayer {
pub fn new(p: nn::Path, config: &BartConfig) -> DecoderLayer { pub fn new(p: nn::Path, config: &BartConfig) -> DecoderLayer {
let layer_norm_config = nn::LayerNormConfig { eps: 1e-5, ..Default::default() }; let layer_norm_config = nn::LayerNormConfig {
eps: 1e-5,
..Default::default()
};
let output_attention = match config.output_attentions { let output_attention = match config.output_attentions {
Some(value) => value, Some(value) => value,
None => false None => false,
}; };
let self_attention = SelfAttention::new(&p / "self_attn", let self_attention = SelfAttention::new(
config.d_model, &p / "self_attn",
config.decoder_attention_heads, config.d_model,
config.attention_dropout, config.decoder_attention_heads,
false, config.attention_dropout,
true, false,
output_attention); true,
let encoder_attention = SelfAttention::new(&p / "encoder_attn", output_attention,
config.d_model, );
config.decoder_attention_heads, let encoder_attention = SelfAttention::new(
config.attention_dropout, &p / "encoder_attn",
true, config.d_model,
true, config.decoder_attention_heads,
output_attention); config.attention_dropout,
let self_attention_layer_norm = nn::layer_norm(&p / "self_attn_layer_norm", true,
vec![config.d_model], true,
layer_norm_config); output_attention,
let encoder_attention_layer_norm = nn::layer_norm(&p / "encoder_attn_layer_norm", );
vec![config.d_model], let self_attention_layer_norm = nn::layer_norm(
layer_norm_config); &p / "self_attn_layer_norm",
vec![config.d_model],
layer_norm_config,
);
let encoder_attention_layer_norm = nn::layer_norm(
&p / "encoder_attn_layer_norm",
vec![config.d_model],
layer_norm_config,
);
let dropout = Dropout::new(config.dropout); let dropout = Dropout::new(config.dropout);
let activation_dropout = Dropout::new(config.activation_dropout); let activation_dropout = Dropout::new(config.activation_dropout);
let activation_function = match &config.activation_function { let activation_function = match &config.activation_function {
Some(act_function) => act_function, Some(act_function) => act_function,
None => &Activation::gelu None => &Activation::gelu,
}; };
let activation = Box::new(match activation_function { let activation = Box::new(match activation_function {
Activation::gelu => _gelu, Activation::gelu => _gelu,
Activation::relu => _relu, Activation::relu => _relu,
Activation::swish => _swish, Activation::swish => _swish,
Activation::gelu_new => _gelu_new, Activation::gelu_new => _gelu_new,
Activation::tanh => _tanh Activation::tanh => _tanh,
}); });
let fc1 = nn::linear(&p / "fc1", config.d_model, config.decoder_ffn_dim, Default::default()); let fc1 = nn::linear(
let fc2 = nn::linear(&p / "fc2", config.decoder_ffn_dim, config.d_model, Default::default()); &p / "fc1",
config.d_model,
config.decoder_ffn_dim,
Default::default(),
);
let fc2 = nn::linear(
&p / "fc2",
config.decoder_ffn_dim,
config.d_model,
Default::default(),
);
let final_layer_norm = nn::layer_norm(&p / "final_layer_norm", let final_layer_norm = nn::layer_norm(
vec![config.d_model], &p / "final_layer_norm",
layer_norm_config); vec![config.d_model],
layer_norm_config,
);
DecoderLayer { DecoderLayer {
self_attention, self_attention,
@ -96,18 +121,38 @@ impl DecoderLayer {
} }
} }
pub fn forward_t(&self, pub fn forward_t(
x: &Tensor, &self,
encoder_hidden_states: &Tensor, x: &Tensor,
encoder_attn_mask: Option<&Tensor>, encoder_hidden_states: &Tensor,
causal_mask: Option<&Tensor>, encoder_attn_mask: Option<&Tensor>,
decoder_padding_mask: Option<&Tensor>, causal_mask: Option<&Tensor>,
layer_states: (Option<LayerState>, Option<LayerState>), decoder_padding_mask: Option<&Tensor>,
train: bool) -> (Tensor, Option<Tensor>, (Option<LayerState>, Option<LayerState>)) { layer_states: (Option<LayerState>, Option<LayerState>),
let (output, attention_weights, new_self_layer_states) = self.self_attention.forward_t(x, Some(x), decoder_padding_mask, causal_mask, layer_states.0, train); train: bool,
) -> (
Tensor,
Option<Tensor>,
(Option<LayerState>, Option<LayerState>),
) {
let (output, attention_weights, new_self_layer_states) = self.self_attention.forward_t(
x,
Some(x),
decoder_padding_mask,
causal_mask,
layer_states.0,
train,
);
let output: Tensor = output.apply_t(&self.dropout, train) + x; let output: Tensor = output.apply_t(&self.dropout, train) + x;
let output = output.apply(&self.self_attention_layer_norm); let output = output.apply(&self.self_attention_layer_norm);
let (output1, _, new_encoder_layer_states) = self.encoder_attention.forward_t(&output, Some(encoder_hidden_states), encoder_attn_mask, None, layer_states.1, train); let (output1, _, new_encoder_layer_states) = self.encoder_attention.forward_t(
&output,
Some(encoder_hidden_states),
encoder_attn_mask,
None,
layer_states.1,
train,
);
let output1: Tensor = output1.apply_t(&self.dropout, train) + output; let output1: Tensor = output1.apply_t(&self.dropout, train) + output;
let output1 = output1.apply(&self.encoder_attention_layer_norm); let output1 = output1.apply(&self.encoder_attention_layer_norm);
let output2 = (self.activation)(&output1.apply(&self.fc1)); let output2 = (self.activation)(&output1.apply(&self.fc1));
@ -116,7 +161,11 @@ impl DecoderLayer {
.apply(&self.fc2) .apply(&self.fc2)
.apply_t(&self.dropout, train); .apply_t(&self.dropout, train);
let output2: Tensor = output2 + output1; let output2: Tensor = output2 + output1;
(output2.apply(&self.final_layer_norm), attention_weights, (new_self_layer_states, new_encoder_layer_states)) (
output2.apply(&self.final_layer_norm),
attention_weights,
(new_self_layer_states, new_encoder_layer_states),
)
} }
} }
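Both sublayers above follow the same post-norm residual recipe: `output = LayerNorm(x + Dropout(Sublayer(x)))`. A standalone sketch of that pattern with tch (an identity-like stand-in for the sublayer, dropout omitted as in inference, made-up sizes):

```rust
use tch::kind::Kind::Float;
use tch::{nn, Device, Tensor};

fn main() {
    let vs = nn::VarStore::new(Device::Cpu);
    let layer_norm = nn::layer_norm(&vs.root() / "ln", vec![8], Default::default());
    let x = Tensor::rand(&[2, 4, 8], (Float, Device::Cpu));
    let sublayer_output = &x * 0.5; // stand-in for attention / feed-forward
    let out = (sublayer_output + &x).apply(&layer_norm); // residual, then norm
    assert_eq!(out.size(), x.size());
}
```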
@ -136,61 +185,76 @@ impl BartDecoder {
pub fn new(p: nn::Path, config: &BartConfig, generation_mode: bool) -> BartDecoder { pub fn new(p: nn::Path, config: &BartConfig, generation_mode: bool) -> BartDecoder {
let output_past = match config.output_past { let output_past = match config.output_past {
Some(value) => value, Some(value) => value,
None => true None => true,
}; };
let output_attentions = match config.output_attentions { let output_attentions = match config.output_attentions {
Some(value) => value, Some(value) => value,
None => false None => false,
}; };
let output_hidden_states = match config.output_hidden_states { let output_hidden_states = match config.output_hidden_states {
Some(value) => value, Some(value) => value,
None => false None => false,
}; };
let normalize_embedding = match config.normalize_embedding { let normalize_embedding = match config.normalize_embedding {
Some(value) => value, Some(value) => value,
None => true None => true,
}; };
let static_position_embeddings = match config.static_position_embeddings { let static_position_embeddings = match config.static_position_embeddings {
Some(value) => value, Some(value) => value,
None => false None => false,
}; };
let scale_embedding = match config.scale_embedding { let scale_embedding = match config.scale_embedding {
Some(value) => if value { (config.d_model as f64).sqrt() } else { 1.0 }, Some(value) => {
None => 1.0 if value {
(config.d_model as f64).sqrt()
} else {
1.0
}
}
None => 1.0,
}; };
let dropout = Dropout::new(config.dropout); let dropout = Dropout::new(config.dropout);
let layer_norm_embedding = if normalize_embedding { let layer_norm_embedding = if normalize_embedding {
let layer_norm_config = nn::LayerNormConfig { eps: 1e-5, ..Default::default() }; let layer_norm_config = nn::LayerNormConfig {
Some(nn::layer_norm(&p / "layernorm_embedding", eps: 1e-5,
vec![config.d_model], ..Default::default()
layer_norm_config)) };
Some(nn::layer_norm(
&p / "layernorm_embedding",
vec![config.d_model],
layer_norm_config,
))
} else { } else {
None None
}; };
let pad_token_id = match config.pad_token_id { let pad_token_id = match config.pad_token_id {
Some(value) => value, Some(value) => value,
None => 1 None => 1,
}; };
let embed_positions = if static_position_embeddings { let embed_positions = if static_position_embeddings {
EmbeddingOption::SinusoidalPositionalEmbedding(SinusoidalPositionalEmbedding::new(&p / "embed_positions", EmbeddingOption::SinusoidalPositionalEmbedding(SinusoidalPositionalEmbedding::new(
config.max_position_embeddings, &p / "embed_positions",
config.d_model)) config.max_position_embeddings,
config.d_model,
))
} else { } else {
EmbeddingOption::LearnedPositionalEmbedding(LearnedPositionalEmbedding::new(&p / "embed_positions", EmbeddingOption::LearnedPositionalEmbedding(LearnedPositionalEmbedding::new(
config.max_position_embeddings, &p / "embed_positions",
config.d_model, config.max_position_embeddings,
pad_token_id)) config.d_model,
pad_token_id,
))
}; };
let mut layers: Vec<DecoderLayer> = vec!(); let mut layers: Vec<DecoderLayer> = vec![];
let p_layers = &p / "layers"; let p_layers = &p / "layers";
for layer_index in 0..config.decoder_layers { for layer_index in 0..config.decoder_layers {
layers.push(DecoderLayer::new(&p_layers / layer_index, config)); layers.push(DecoderLayer::new(&p_layers / layer_index, config));
}; }
BartDecoder { BartDecoder {
dropout, dropout,
@ -205,44 +269,68 @@ impl BartDecoder {
} }
} }
pub fn forward_t(&self, pub fn forward_t(
input_ids: &Tensor, &self,
encoder_hidden_states: &Tensor, input_ids: &Tensor,
encoder_padding_mask: Option<&Tensor>, encoder_hidden_states: &Tensor,
decoder_padding_mask: Option<&Tensor>, encoder_padding_mask: Option<&Tensor>,
decoder_causal_mask: Option<&Tensor>, decoder_padding_mask: Option<&Tensor>,
embeddings: &nn::Embedding, decoder_causal_mask: Option<&Tensor>,
old_layer_states: Option<Vec<(Option<LayerState>, Option<LayerState>)>>, embeddings: &nn::Embedding,
train: bool) old_layer_states: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
-> (Tensor, train: bool,
(Option<Tensor>, Option<Vec<(Option<LayerState>, Option<LayerState>)>>), ) -> (
Option<Vec<Tensor>>, Tensor,
Option<Vec<Tensor>>) { (
Option<Tensor>,
Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
),
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
) {
let encoder_padding_mask = match encoder_padding_mask { let encoder_padding_mask = match encoder_padding_mask {
Some(mask) => Some(mask.eq(0).to_kind(Int64)), Some(mask) => Some(mask.eq(0).to_kind(Int64)),
None => None None => None,
}; };
let positions = self.embed_positions.forward(input_ids, self.generation_mode); let positions = self
.embed_positions
.forward(input_ids, self.generation_mode);
let x: Tensor = if self.generation_mode { let x: Tensor = if self.generation_mode {
let end_inputs = input_ids.size()[1]; let end_inputs = input_ids.size()[1];
let end_positions = positions.size()[1]; let end_positions = positions.size()[1];
input_ids.narrow(1, end_inputs - 1, 1).apply(embeddings) * self.scale_embedding + positions.narrow(1, end_positions - 1, 1) input_ids.narrow(1, end_inputs - 1, 1).apply(embeddings) * self.scale_embedding
+ positions.narrow(1, end_positions - 1, 1)
} else { } else {
input_ids.apply(embeddings) * self.scale_embedding + positions input_ids.apply(embeddings) * self.scale_embedding + positions
}; };
let x = if let Some(layer_norm_embedding) = &self.layer_norm_embedding { x.apply(layer_norm_embedding) } else { x }; let x = if let Some(layer_norm_embedding) = &self.layer_norm_embedding {
let mut hidden_state = x x.apply(layer_norm_embedding)
.apply_t(&self.dropout, train) } else {
.transpose(0, 1); x
let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states { Some(Vec::with_capacity(self.layers.len())) } else { None }; };
let mut all_attentions: Option<Vec<Tensor>> = if self.output_attentions { Some(Vec::with_capacity(self.layers.len())) } else { None }; let mut hidden_state = x.apply_t(&self.dropout, train).transpose(0, 1);
let mut next_decoder_cache: Option<Vec<(Option<LayerState>, Option<LayerState>)>> = if self.output_past { let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states {
if old_layer_states.is_some() { old_layer_states } else { Some(vec!((None, None); self.layers.len())) } Some(Vec::with_capacity(self.layers.len()))
} else { } else {
None None
}; };
let mut all_attentions: Option<Vec<Tensor>> = if self.output_attentions {
Some(Vec::with_capacity(self.layers.len()))
} else {
None
};
let mut next_decoder_cache: Option<Vec<(Option<LayerState>, Option<LayerState>)>> =
if self.output_past {
if old_layer_states.is_some() {
old_layer_states
} else {
Some(vec![(None, None); self.layers.len()])
}
} else {
None
};
let encoder_hidden_states = encoder_hidden_states.transpose(0, 1); let encoder_hidden_states = encoder_hidden_states.transpose(0, 1);
let mut attention_weights: Option<Tensor>; let mut attention_weights: Option<Tensor>;
let mut layers = self.layers.iter().enumerate(); let mut layers = self.layers.iter().enumerate();
@ -252,15 +340,17 @@ impl BartDecoder {
Some((layer_idx, layer)) => { Some((layer_idx, layer)) => {
let layer_state = match &next_decoder_cache { let layer_state = match &next_decoder_cache {
Some(values) => values[layer_idx].to_owned(), Some(values) => values[layer_idx].to_owned(),
None => (None, None) None => (None, None),
}; };
let temp = layer.forward_t(&hidden_state, let temp = layer.forward_t(
&encoder_hidden_states, &hidden_state,
encoder_padding_mask.as_ref(), &encoder_hidden_states,
decoder_causal_mask, encoder_padding_mask.as_ref(),
decoder_padding_mask, decoder_causal_mask,
layer_state, decoder_padding_mask,
train); layer_state,
train,
);
hidden_state = temp.0; hidden_state = temp.0;
attention_weights = temp.1; attention_weights = temp.1;
if let Some(hidden_states) = all_hidden_states.borrow_mut() { if let Some(hidden_states) = all_hidden_states.borrow_mut() {
@ -269,15 +359,19 @@ impl BartDecoder {
if let Some(attentions) = all_attentions.borrow_mut() { if let Some(attentions) = all_attentions.borrow_mut() {
attentions.push(attention_weights.as_ref().unwrap().copy()); attentions.push(attention_weights.as_ref().unwrap().copy());
}; };
if let Some(value) = &mut next_decoder_cache { value[layer_idx] = temp.2 }; if let Some(value) = &mut next_decoder_cache {
value[layer_idx] = temp.2
};
} }
None => break None => break,
}; };
}; }
(hidden_state.transpose(0, 1), (
(encoder_padding_mask, next_decoder_cache), hidden_state.transpose(0, 1),
all_hidden_states, (encoder_padding_mask, next_decoder_cache),
all_attentions) all_hidden_states,
all_attentions,
)
} }
} }
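A note on shapes for readers of the decoder (and encoder) loops: the layer stacks work time-major internally, hence the `transpose(0, 1)` on entry and exit, while the public API stays batch-major. A standalone round-trip check:

```rust
use tch::kind::Kind::Float;
use tch::{Device, Tensor};

fn main() {
    let batch_major = Tensor::rand(&[2, 5, 8], (Float, Device::Cpu)); // [batch, seq, d_model]
    let time_major = batch_major.transpose(0, 1); // [seq, batch, d_model]
    assert_eq!(time_major.size(), vec![5, 2, 8]);
    assert_eq!(time_major.transpose(0, 1).size(), batch_major.size());
}
```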
@ -11,10 +11,9 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use tch::{nn, Tensor};
use tch::nn::{EmbeddingConfig, embedding};
use tch::kind::Kind::Int64; use tch::kind::Kind::Int64;
use tch::nn::{embedding, EmbeddingConfig};
use tch::{nn, Tensor};
/// # Abstraction that holds an embedding configuration /// # Abstraction that holds an embedding configuration
pub enum EmbeddingOption { pub enum EmbeddingOption {
@ -27,8 +26,12 @@ impl EmbeddingOption {
/// Interface method dispatching to forward() of the concrete embedding type. /// Interface method dispatching to forward() of the concrete embedding type.
pub fn forward(&self, input: &Tensor, generation_mode: bool) -> Tensor { pub fn forward(&self, input: &Tensor, generation_mode: bool) -> Tensor {
match *self { match *self {
Self::LearnedPositionalEmbedding(ref embeddings) => embeddings.forward(input, generation_mode), Self::LearnedPositionalEmbedding(ref embeddings) => {
Self::SinusoidalPositionalEmbedding(ref embeddings) => embeddings.forward(input, generation_mode) embeddings.forward(input, generation_mode)
}
Self::SinusoidalPositionalEmbedding(ref embeddings) => {
embeddings.forward(input, generation_mode)
}
} }
} }
} }
@ -40,15 +43,24 @@ pub struct LearnedPositionalEmbedding {
} }
impl LearnedPositionalEmbedding { impl LearnedPositionalEmbedding {
pub fn new(p: nn::Path, num_embeddings: i64, embedding_dim: i64, padding_index: i64) -> LearnedPositionalEmbedding { pub fn new(
let embedding_config = EmbeddingConfig { padding_idx: padding_index, ..Default::default() }; p: nn::Path,
num_embeddings: i64,
embedding_dim: i64,
padding_index: i64,
) -> LearnedPositionalEmbedding {
let embedding_config = EmbeddingConfig {
padding_idx: padding_index,
..Default::default()
};
let num_embeddings = num_embeddings + padding_index + 1; let num_embeddings = num_embeddings + padding_index + 1;
let embedding: nn::Embedding = embedding(p, let embedding: nn::Embedding =
num_embeddings, embedding(p, num_embeddings, embedding_dim, embedding_config);
embedding_dim, LearnedPositionalEmbedding {
embedding_config); embedding,
LearnedPositionalEmbedding { embedding, padding_index } padding_index,
}
} }
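The `num_embeddings + padding_index + 1` sizing follows the fairseq convention (an inference from the offset, not spelled out in the diff): position ids start at `padding_index + 1` so the padding index keeps a dedicated, position-less slot. A standalone sketch of the position ids such a table is indexed with:

```rust
use tch::kind::Kind::Int64;
use tch::{Device, Tensor};

fn main() {
    let (padding_index, seq_len) = (1i64, 4);
    let positions = Tensor::arange(seq_len, (Int64, Device::Cpu)) + padding_index + 1;
    assert_eq!(Vec::<i64>::from(&positions), vec![2, 3, 4, 5]);
}
```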
pub fn forward(&self, input: &Tensor, generation_mode: bool) -> Tensor { pub fn forward(&self, input: &Tensor, generation_mode: bool) -> Tensor {
@ -74,11 +86,13 @@ pub struct SinusoidalPositionalEmbedding {
} }
impl SinusoidalPositionalEmbedding { impl SinusoidalPositionalEmbedding {
pub fn new(p: nn::Path, num_embeddings: i64, embedding_dim: i64) -> SinusoidalPositionalEmbedding { pub fn new(
let embedding: nn::Embedding = embedding(p, p: nn::Path,
num_embeddings, num_embeddings: i64,
embedding_dim, embedding_dim: i64,
Default::default()); ) -> SinusoidalPositionalEmbedding {
let embedding: nn::Embedding =
embedding(p, num_embeddings, embedding_dim, Default::default());
SinusoidalPositionalEmbedding { embedding } SinusoidalPositionalEmbedding { embedding }
} }
@ -86,8 +100,8 @@ impl SinusoidalPositionalEmbedding {
let positions = if generation_mode { let positions = if generation_mode {
Tensor::full(&[1, 1], input.size()[1] - 1, (Int64, input.device())) Tensor::full(&[1, 1], input.size()[1] - 1, (Int64, input.device()))
} else { } else {
Tensor::arange(input.size()[1],(Int64, input.device())) Tensor::arange(input.size()[1], (Int64, input.device()))
}; };
positions.apply(&self.embedding) positions.apply(&self.embedding)
} }
} }
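During generation only the newest token is embedded per step, so the position lookup in `forward` collapses to the single index `seq_len - 1`, as a quick check shows:

```rust
use tch::kind::Kind::Int64;
use tch::{Device, Tensor};

fn main() {
    let seq_len = 7i64; // running sequence length during generation
    let positions = Tensor::full(&[1, 1], seq_len - 1, (Int64, Device::Cpu));
    assert_eq!(positions.int64_value(&[0, 0]), 6);
}
```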
@ -12,14 +12,16 @@
// limitations under the License. // limitations under the License.
use crate::bart::attention::SelfAttention; use crate::bart::attention::SelfAttention;
use tch::{nn, Tensor};
use crate::common::dropout::Dropout;
use crate::bart::BartConfig;
use crate::bart::bart::Activation; use crate::bart::bart::Activation;
use crate::common::activations::{_gelu, _relu, _swish, _gelu_new, _tanh}; use crate::bart::embeddings::{
use crate::bart::embeddings::{EmbeddingOption, LearnedPositionalEmbedding, SinusoidalPositionalEmbedding}; EmbeddingOption, LearnedPositionalEmbedding, SinusoidalPositionalEmbedding,
use tch::kind::Kind::Bool; };
use crate::bart::BartConfig;
use crate::common::activations::{_gelu, _gelu_new, _relu, _swish, _tanh};
use crate::common::dropout::Dropout;
use std::borrow::BorrowMut; use std::borrow::BorrowMut;
use tch::kind::Kind::Bool;
use tch::{nn, Tensor};
pub struct EncoderLayer { pub struct EncoderLayer {
self_attention: SelfAttention, self_attention: SelfAttention,
@ -34,46 +36,81 @@ pub struct EncoderLayer {
impl EncoderLayer { impl EncoderLayer {
pub fn new(p: nn::Path, config: &BartConfig) -> EncoderLayer { pub fn new(p: nn::Path, config: &BartConfig) -> EncoderLayer {
let layer_norm_config = nn::LayerNormConfig { eps: 1e-5, ..Default::default() }; let layer_norm_config = nn::LayerNormConfig {
eps: 1e-5,
..Default::default()
};
let output_attention = match config.output_attentions { let output_attention = match config.output_attentions {
Some(value) => value, Some(value) => value,
None => false None => false,
}; };
let self_attention = SelfAttention::new(&p / "self_attn", let self_attention = SelfAttention::new(
config.d_model, &p / "self_attn",
config.encoder_attention_heads, config.d_model,
config.attention_dropout, config.encoder_attention_heads,
false, config.attention_dropout,
false, false,
output_attention); false,
let self_attention_layer_norm = nn::layer_norm(&p / "self_attn_layer_norm", output_attention,
vec![config.d_model], );
layer_norm_config); let self_attention_layer_norm = nn::layer_norm(
&p / "self_attn_layer_norm",
vec![config.d_model],
layer_norm_config,
);
let dropout = Dropout::new(config.dropout); let dropout = Dropout::new(config.dropout);
let activation_dropout = Dropout::new(config.activation_dropout); let activation_dropout = Dropout::new(config.activation_dropout);
let activation_function = match &config.activation_function { let activation_function = match &config.activation_function {
Some(act_function) => act_function, Some(act_function) => act_function,
None => &Activation::gelu None => &Activation::gelu,
}; };
let activation = Box::new(match activation_function { let activation = Box::new(match activation_function {
Activation::gelu => _gelu, Activation::gelu => _gelu,
Activation::relu => _relu, Activation::relu => _relu,
Activation::swish => _swish, Activation::swish => _swish,
Activation::gelu_new => _gelu_new, Activation::gelu_new => _gelu_new,
Activation::tanh => _tanh Activation::tanh => _tanh,
}); });
let fc1 = nn::linear(&p / "fc1", config.d_model, config.encoder_ffn_dim, Default::default()); let fc1 = nn::linear(
let fc2 = nn::linear(&p / "fc2", config.encoder_ffn_dim, config.d_model, Default::default()); &p / "fc1",
config.d_model,
config.encoder_ffn_dim,
Default::default(),
);
let fc2 = nn::linear(
&p / "fc2",
config.encoder_ffn_dim,
config.d_model,
Default::default(),
);
let final_layer_norm = nn::layer_norm(&p / "final_layer_norm", let final_layer_norm = nn::layer_norm(
vec![config.d_model], &p / "final_layer_norm",
layer_norm_config); vec![config.d_model],
layer_norm_config,
);
EncoderLayer { self_attention, self_attention_layer_norm, dropout, activation_dropout, activation, fc1, fc2, final_layer_norm } EncoderLayer {
self_attention,
self_attention_layer_norm,
dropout,
activation_dropout,
activation,
fc1,
fc2,
final_layer_norm,
}
} }
pub fn forward_t(&self, x: &Tensor, encoder_padding_mask: Option<&Tensor>, train: bool) -> (Tensor, Option<Tensor>) { pub fn forward_t(
let (output, attention_weights, _) = self.self_attention.forward_t(x, None, encoder_padding_mask, None, None, train); &self,
x: &Tensor,
encoder_padding_mask: Option<&Tensor>,
train: bool,
) -> (Tensor, Option<Tensor>) {
let (output, attention_weights, _) =
self.self_attention
.forward_t(x, None, encoder_padding_mask, None, None, train);
let output: Tensor = output.apply_t(&self.dropout, train) + x; let output: Tensor = output.apply_t(&self.dropout, train) + x;
let output = output.apply(&self.self_attention_layer_norm); let output = output.apply(&self.self_attention_layer_norm);
@ -102,57 +139,72 @@ impl BartEncoder {
pub fn new(p: nn::Path, config: &BartConfig) -> BartEncoder { pub fn new(p: nn::Path, config: &BartConfig) -> BartEncoder {
let output_attentions = match config.output_attentions { let output_attentions = match config.output_attentions {
Some(value) => value, Some(value) => value,
None => false None => false,
}; };
let output_hidden_states = match config.output_hidden_states { let output_hidden_states = match config.output_hidden_states {
Some(value) => value, Some(value) => value,
None => false None => false,
}; };
let normalize_embedding = match config.normalize_embedding { let normalize_embedding = match config.normalize_embedding {
Some(value) => value, Some(value) => value,
None => true None => true,
}; };
let static_position_embeddings = match config.static_position_embeddings { let static_position_embeddings = match config.static_position_embeddings {
Some(value) => value, Some(value) => value,
None => false None => false,
}; };
let scale_embedding = match config.scale_embedding { let scale_embedding = match config.scale_embedding {
Some(value) => if value { (config.d_model as f64).sqrt() } else { 1.0 }, Some(value) => {
None => 1.0 if value {
(config.d_model as f64).sqrt()
} else {
1.0
}
}
None => 1.0,
}; };
let dropout = Dropout::new(config.dropout); let dropout = Dropout::new(config.dropout);
let layer_norm_embedding = if normalize_embedding { let layer_norm_embedding = if normalize_embedding {
let layer_norm_config = nn::LayerNormConfig { eps: 1e-5, ..Default::default() }; let layer_norm_config = nn::LayerNormConfig {
Some(nn::layer_norm(&p / "layernorm_embedding", eps: 1e-5,
vec![config.d_model], ..Default::default()
layer_norm_config)) };
Some(nn::layer_norm(
&p / "layernorm_embedding",
vec![config.d_model],
layer_norm_config,
))
} else { } else {
None None
}; };
let pad_token_id = match config.pad_token_id { let pad_token_id = match config.pad_token_id {
Some(value) => value, Some(value) => value,
None => 1 None => 1,
}; };
let embed_positions = if static_position_embeddings { let embed_positions = if static_position_embeddings {
EmbeddingOption::SinusoidalPositionalEmbedding(SinusoidalPositionalEmbedding::new(&p / "embed_positions", EmbeddingOption::SinusoidalPositionalEmbedding(SinusoidalPositionalEmbedding::new(
config.max_position_embeddings, &p / "embed_positions",
config.d_model)) config.max_position_embeddings,
config.d_model,
))
} else { } else {
EmbeddingOption::LearnedPositionalEmbedding(LearnedPositionalEmbedding::new(&p / "embed_positions", EmbeddingOption::LearnedPositionalEmbedding(LearnedPositionalEmbedding::new(
config.max_position_embeddings, &p / "embed_positions",
config.d_model, config.max_position_embeddings,
pad_token_id)) config.d_model,
pad_token_id,
))
}; };
let mut layers: Vec<EncoderLayer> = vec!(); let mut layers: Vec<EncoderLayer> = vec![];
let p_layers = &p / "layers"; let p_layers = &p / "layers";
for layer_index in 0..config.encoder_layers { for layer_index in 0..config.encoder_layers {
layers.push(EncoderLayer::new(&p_layers / layer_index, config)); layers.push(EncoderLayer::new(&p_layers / layer_index, config));
}; }
BartEncoder { BartEncoder {
dropout, dropout,
@ -165,26 +217,37 @@ impl BartEncoder {
} }
} }
pub fn forward_t(&self, pub fn forward_t(
input_ids: &Tensor, &self,
attention_mask: Option<&Tensor>, input_ids: &Tensor,
embeddings: &nn::Embedding, attention_mask: Option<&Tensor>,
train: bool) embeddings: &nn::Embedding,
-> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) { train: bool,
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
let attention_mask = match attention_mask { let attention_mask = match attention_mask {
Some(mask) => Some(mask.eq(0).to_kind(Bool)), Some(mask) => Some(mask.eq(0).to_kind(Bool)),
None => None None => None,
}; };
let x = input_ids.apply(embeddings) * self.scale_embedding; let x = input_ids.apply(embeddings) * self.scale_embedding;
let x: Tensor = x + &self.embed_positions.forward(input_ids, false); let x: Tensor = x + &self.embed_positions.forward(input_ids, false);
let x = if let Some(layer_norm_embedding) = &self.layer_norm_embedding { x.apply(layer_norm_embedding) } else { x }; let x = if let Some(layer_norm_embedding) = &self.layer_norm_embedding {
let x = x x.apply(layer_norm_embedding)
.apply_t(&self.dropout, train) } else {
.transpose(0, 1); x
};
let x = x.apply_t(&self.dropout, train).transpose(0, 1);
let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states { Some(vec!()) } else { None }; let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states {
let mut all_attentions: Option<Vec<Tensor>> = if self.output_attentions { Some(vec!()) } else { None }; Some(vec![])
} else {
None
};
let mut all_attentions: Option<Vec<Tensor>> = if self.output_attentions {
Some(vec![])
} else {
None
};
let mut hidden_state = x.copy(); let mut hidden_state = x.copy();
let mut attention_weights: Option<Tensor>; let mut attention_weights: Option<Tensor>;
@ -204,13 +267,17 @@ impl BartEncoder {
attentions.push(attention_weights.as_ref().unwrap().copy()); attentions.push(attention_weights.as_ref().unwrap().copy());
}; };
} }
None => break None => break,
}; };
}; }
if let Some(hidden_states) = all_hidden_states.borrow_mut() { if let Some(hidden_states) = all_hidden_states.borrow_mut() {
hidden_states.push(hidden_state.as_ref().copy().transpose(0, 1)); hidden_states.push(hidden_state.as_ref().copy().transpose(0, 1));
}; };
(hidden_state.transpose(0, 1), all_hidden_states, all_attentions) (
hidden_state.transpose(0, 1),
all_hidden_states,
all_attentions,
)
} }
} }
@ -15,19 +15,27 @@
//! Pretrained models are available and can be downloaded using RemoteResources. //! Pretrained models are available and can be downloaded using RemoteResources.
//! //!
//! ```no_run //! ```no_run
//!# fn main() -> failure::Fallible<()> { //! # fn main() -> failure::Fallible<()> {
//!# //! #
//! use rust_tokenizers::RobertaTokenizer; //! use rust_tokenizers::RobertaTokenizer;
//! use tch::{nn, Device}; //! use tch::{nn, Device};
//!# use std::path::PathBuf; //! # use std::path::PathBuf;
//! use rust_bert::Config;
//! use rust_bert::bart::{BartConfig, BartModel}; //! use rust_bert::bart::{BartConfig, BartModel};
//! use rust_bert::resources::{Resource, download_resource, LocalResource}; //! use rust_bert::resources::{download_resource, LocalResource, Resource};
//! use rust_bert::Config;
//! //!
//! let config_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/config.json")}); //! let config_resource = Resource::Local(LocalResource {
//! let vocab_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/vocab.txt")}); //! local_path: PathBuf::from("path/to/config.json"),
//! let merges_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/merges.txt")}); //! let vocab_resource = Resource::Local(LocalResource {
//! let weights_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/model.ot")}); //! let vocab_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/vocab.txt"),
//! });
//! let merges_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/vocab.txt"),
//! });
//! let weights_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/model.ot"),
//! });
//! let config_path = download_resource(&config_resource)?; //! let config_path = download_resource(&config_resource)?;
//! let vocab_path = download_resource(&vocab_resource)?; //! let vocab_path = download_resource(&vocab_resource)?;
//! let merges_path = download_resource(&merges_resource)?; //! let merges_path = download_resource(&merges_resource)?;
@ -35,21 +43,28 @@
//! //!
//! let device = Device::cuda_if_available(); //! let device = Device::cuda_if_available();
//! let mut vs = nn::VarStore::new(device); //! let mut vs = nn::VarStore::new(device);
//! let tokenizer: RobertaTokenizer = RobertaTokenizer::from_file(vocab_path.to_str().unwrap(), merges_path.to_str().unwrap(), true); //! let tokenizer: RobertaTokenizer = RobertaTokenizer::from_file(
//! vocab_path.to_str().unwrap(),
//! merges_path.to_str().unwrap(),
//! true,
//! );
//! let config = BartConfig::from_file(config_path); //! let config = BartConfig::from_file(config_path);
//! let bart_model = BartModel::new(&vs.root(), &config, false); //! let bart_model = BartModel::new(&vs.root(), &config, false);
//! vs.load(weights_path)?; //! vs.load(weights_path)?;
//! //!
//!# Ok(()) //! # Ok(())
//!# } //! # }
//! ``` //! ```
mod bart;
mod attention; mod attention;
mod encoder; mod bart;
mod decoder; mod decoder;
mod embeddings; mod embeddings;
mod encoder;
pub use bart::{BartModelResources, BartConfigResources, BartVocabResources, BartMergesResources, pub use attention::LayerState;
BartConfig, Activation, BartModel, BartForSequenceClassification, BartForConditionalGeneration}; pub use bart::{
pub use attention::LayerState; Activation, BartConfig, BartConfigResources, BartForConditionalGeneration,
BartForSequenceClassification, BartMergesResources, BartModel, BartModelResources,
BartVocabResources,
};
@ -11,11 +11,11 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use crate::common::dropout::Dropout;
use tch::{nn, Tensor};
use tch::kind::Kind::Float;
use crate::bert::bert::{Activation, BertConfig}; use crate::bert::bert::{Activation, BertConfig};
use crate::common::activations::{_gelu, _relu, _mish}; use crate::common::activations::{_gelu, _mish, _relu};
use crate::common::dropout::Dropout;
use tch::kind::Kind::Float;
use tch::{nn, Tensor};
#[derive(Debug)] #[derive(Debug)]
pub struct BertSelfAttention { pub struct BertSelfAttention {
@ -30,17 +30,36 @@ pub struct BertSelfAttention {
impl BertSelfAttention { impl BertSelfAttention {
pub fn new(p: nn::Path, config: &BertConfig) -> BertSelfAttention { pub fn new(p: nn::Path, config: &BertConfig) -> BertSelfAttention {
assert_eq!(config.hidden_size % config.num_attention_heads, 0, "Hidden size not a multiple of the number of attention heads"); assert_eq!(
config.hidden_size % config.num_attention_heads,
0,
"Hidden size not a multiple of the number of attention heads"
);
let query = nn::linear(&p / "query", config.hidden_size, config.hidden_size, Default::default()); let query = nn::linear(
let key = nn::linear(&p / "key", config.hidden_size, config.hidden_size, Default::default()); &p / "query",
let value = nn::linear(&p / "value", config.hidden_size, config.hidden_size, Default::default()); config.hidden_size,
config.hidden_size,
Default::default(),
);
let key = nn::linear(
&p / "key",
config.hidden_size,
config.hidden_size,
Default::default(),
);
let value = nn::linear(
&p / "value",
config.hidden_size,
config.hidden_size,
Default::default(),
);
let dropout = Dropout::new(config.attention_probs_dropout_prob); let dropout = Dropout::new(config.attention_probs_dropout_prob);
let attention_head_size = config.hidden_size / config.num_attention_heads; let attention_head_size = config.hidden_size / config.num_attention_heads;
let output_attentions = match config.output_attentions { let output_attentions = match config.output_attentions {
Some(value) => value, Some(value) => value,
None => false None => false,
}; };
BertSelfAttention { BertSelfAttention {
@ -55,35 +74,44 @@ impl BertSelfAttention {
} }
fn split_heads(&self, x: Tensor, bs: i64, dim_per_head: i64) -> Tensor { fn split_heads(&self, x: Tensor, bs: i64, dim_per_head: i64) -> Tensor {
x.view((bs, -1, self.num_attention_heads, dim_per_head)).transpose(1, 2) x.view((bs, -1, self.num_attention_heads, dim_per_head))
.transpose(1, 2)
} }
fn flatten(&self, x: Tensor, bs: i64, dim_per_head: i64) -> Tensor { fn flatten(&self, x: Tensor, bs: i64, dim_per_head: i64) -> Tensor {
x.transpose(1, 2).contiguous().view((bs, -1, &self.num_attention_heads * dim_per_head)) x.transpose(1, 2)
.contiguous()
.view((bs, -1, &self.num_attention_heads * dim_per_head))
} }
pub fn forward_t(&self, pub fn forward_t(
hidden_states: &Tensor, &self,
mask: &Option<Tensor>, hidden_states: &Tensor,
encoder_hidden_states: &Option<Tensor>, mask: &Option<Tensor>,
encoder_mask: &Option<Tensor>, encoder_hidden_states: &Option<Tensor>,
train: bool) -> (Tensor, Option<Tensor>) { encoder_mask: &Option<Tensor>,
train: bool,
) -> (Tensor, Option<Tensor>) {
let (key_layer, value_layer, mask) = match encoder_hidden_states { let (key_layer, value_layer, mask) = match encoder_hidden_states {
Some(encoder_hidden_state_values) => { Some(encoder_hidden_state_values) => (
(encoder_hidden_state_values.apply(&self.key), encoder_hidden_state_values.apply(&self.key),
encoder_hidden_state_values.apply(&self.value), encoder_hidden_state_values.apply(&self.value),
encoder_mask) encoder_mask,
} ),
None => { None => (
(hidden_states.apply(&self.key), hidden_states.apply(&self.key),
hidden_states.apply(&self.value), hidden_states.apply(&self.value),
mask) mask,
} ),
}; };
let bs = hidden_states.size()[0]; let bs = hidden_states.size()[0];
let query_layer = self.split_heads(hidden_states.apply(&self.query), bs, self.attention_head_size); let query_layer = self.split_heads(
hidden_states.apply(&self.query),
bs,
self.attention_head_size,
);
let key_layer = self.split_heads(key_layer, bs, self.attention_head_size); let key_layer = self.split_heads(key_layer, bs, self.attention_head_size);
let value_layer = self.split_heads(value_layer, bs, self.attention_head_size); let value_layer = self.split_heads(value_layer, bs, self.attention_head_size);
let query_layer: Tensor = query_layer / (self.attention_head_size as f64).sqrt(); let query_layer: Tensor = query_layer / (self.attention_head_size as f64).sqrt();
@ -114,16 +142,32 @@ pub struct BertSelfOutput {
impl BertSelfOutput { impl BertSelfOutput {
pub fn new(p: &nn::Path, config: &BertConfig) -> BertSelfOutput { pub fn new(p: &nn::Path, config: &BertConfig) -> BertSelfOutput {
let linear = nn::linear(p / "dense", config.hidden_size, config.hidden_size, Default::default()); let linear = nn::linear(
let layer_norm_config = nn::LayerNormConfig { eps: 1e-12, ..Default::default() }; p / "dense",
let layer_norm = nn::layer_norm(p / "LayerNorm", vec![config.hidden_size], layer_norm_config); config.hidden_size,
config.hidden_size,
Default::default(),
);
let layer_norm_config = nn::LayerNormConfig {
eps: 1e-12,
..Default::default()
};
let layer_norm =
nn::layer_norm(p / "LayerNorm", vec![config.hidden_size], layer_norm_config);
let dropout = Dropout::new(config.hidden_dropout_prob); let dropout = Dropout::new(config.hidden_dropout_prob);
BertSelfOutput { linear, layer_norm, dropout } BertSelfOutput {
linear,
layer_norm,
dropout,
}
} }
pub fn forward_t(&self, hidden_states: &Tensor, input_tensor: &Tensor, train: bool) -> Tensor { pub fn forward_t(&self, hidden_states: &Tensor, input_tensor: &Tensor, train: bool) -> Tensor {
let hidden_states: Tensor = input_tensor + hidden_states.apply(&self.linear).apply_t(&self.dropout, train); let hidden_states: Tensor = input_tensor
+ hidden_states
.apply(&self.linear)
.apply_t(&self.dropout, train);
hidden_states.apply(&self.layer_norm) hidden_states.apply(&self.layer_norm)
} }
} }
@ -141,14 +185,21 @@ impl BertAttention {
BertAttention { _self, output } BertAttention { _self, output }
} }
pub fn forward_t(&self, pub fn forward_t(
hidden_states: &Tensor, &self,
mask: &Option<Tensor>, hidden_states: &Tensor,
encoder_hidden_states: &Option<Tensor>, mask: &Option<Tensor>,
encoder_mask: &Option<Tensor>, encoder_hidden_states: &Option<Tensor>,
train: bool) -> (Tensor, Option<Tensor>) { encoder_mask: &Option<Tensor>,
let (self_output, attention_weights) = self._self. train: bool,
forward_t(hidden_states, mask, encoder_hidden_states, encoder_mask, train); ) -> (Tensor, Option<Tensor>) {
let (self_output, attention_weights) = self._self.forward_t(
hidden_states,
mask,
encoder_hidden_states,
encoder_mask,
train,
);
let self_output = self.output.forward_t(&self_output, hidden_states, train); let self_output = self.output.forward_t(&self_output, hidden_states, train);
(self_output, attention_weights) (self_output, attention_weights)
@ -162,11 +213,16 @@ pub struct BertIntermediate {
impl BertIntermediate { impl BertIntermediate {
pub fn new(p: &nn::Path, config: &BertConfig) -> BertIntermediate { pub fn new(p: &nn::Path, config: &BertConfig) -> BertIntermediate {
let lin = nn::linear(p / "dense", config.hidden_size, config.intermediate_size, Default::default()); let lin = nn::linear(
p / "dense",
config.hidden_size,
config.intermediate_size,
Default::default(),
);
let activation = Box::new(match &config.hidden_act { let activation = Box::new(match &config.hidden_act {
Activation::gelu => _gelu, Activation::gelu => _gelu,
Activation::relu => _relu, Activation::relu => _relu,
Activation::mish => _mish Activation::mish => _mish,
}); });
BertIntermediate { lin, activation } BertIntermediate { lin, activation }
} }
@ -184,17 +240,30 @@ pub struct BertOutput {
impl BertOutput { impl BertOutput {
pub fn new(p: &nn::Path, config: &BertConfig) -> BertOutput { pub fn new(p: &nn::Path, config: &BertConfig) -> BertOutput {
let lin = nn::linear(p / "dense", config.intermediate_size, config.hidden_size, Default::default()); let lin = nn::linear(
let layer_norm_config = nn::LayerNormConfig { eps: 1e-12, ..Default::default() }; p / "dense",
let layer_norm = nn::layer_norm(p / "LayerNorm", vec![config.hidden_size], layer_norm_config); config.intermediate_size,
config.hidden_size,
Default::default(),
);
let layer_norm_config = nn::LayerNormConfig {
eps: 1e-12,
..Default::default()
};
let layer_norm =
nn::layer_norm(p / "LayerNorm", vec![config.hidden_size], layer_norm_config);
let dropout = Dropout::new(config.hidden_dropout_prob); let dropout = Dropout::new(config.hidden_dropout_prob);
BertOutput { lin, layer_norm, dropout } BertOutput {
lin,
layer_norm,
dropout,
}
} }
pub fn forward_t(&self, hidden_states: &Tensor, input_tensor: &Tensor, train: bool) -> Tensor { pub fn forward_t(&self, hidden_states: &Tensor, input_tensor: &Tensor, train: bool) -> Tensor {
let hidden_states: Tensor = input_tensor + hidden_states.apply(&self.lin).apply_t(&self.dropout, train); let hidden_states: Tensor =
input_tensor + hidden_states.apply(&self.lin).apply_t(&self.dropout, train);
hidden_states.apply(&self.layer_norm) hidden_states.apply(&self.layer_norm)
} }
} }
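For intuition about the two reshaping helpers above, here is a minimal standalone sketch (toy dimensions assumed, `tch` only) of the shape round-trip that `split_heads` and `flatten` perform in `BertSelfAttention`:

```rust
use tch::{Device, Kind, Tensor};

fn main() {
    // Toy dimensions, assumed only for this sketch.
    let (bs, seq_len, num_heads, dim_per_head): (i64, i64, i64, i64) = (2, 4, 3, 5);
    let hidden = num_heads * dim_per_head;
    let x = Tensor::rand(&[bs, seq_len, hidden], (Kind::Float, Device::Cpu));

    // `split_heads`: [bs, seq, hidden] -> [bs, heads, seq, dim_per_head]
    let split = x.view((bs, -1, num_heads, dim_per_head)).transpose(1, 2);
    assert_eq!(split.size(), &[bs, num_heads, seq_len, dim_per_head]);

    // `flatten` undoes it: back to [bs, seq, hidden]
    let flat = split
        .transpose(1, 2)
        .contiguous()
        .view((bs, -1, num_heads * dim_per_head));
    assert_eq!(flat.size(), x.size());
}
```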

File diff suppressed because it is too large

View File

@@ -11,22 +11,24 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use tch::{nn, Tensor, Kind};
use tch::nn::{EmbeddingConfig, embedding};
use crate::common::dropout::Dropout;
use crate::bert::bert::BertConfig; use crate::bert::bert::BertConfig;
use crate::common::dropout::Dropout;
use tch::nn::{embedding, EmbeddingConfig};
use tch::{nn, Kind, Tensor};
/// # BertEmbedding trait (for use in BertModel or RoBERTaModel) /// # BertEmbedding trait (for use in BertModel or RoBERTaModel)
/// Defines an interface for the embedding layers in BERT-based models /// Defines an interface for the embedding layers in BERT-based models
pub trait BertEmbedding { pub trait BertEmbedding {
fn new(p: &nn::Path, config: &BertConfig) -> Self; fn new(p: &nn::Path, config: &BertConfig) -> Self;
fn forward_t(&self, fn forward_t(
input_ids: Option<Tensor>, &self,
token_type_ids: Option<Tensor>, input_ids: Option<Tensor>,
position_ids: Option<Tensor>, token_type_ids: Option<Tensor>,
input_embeds: Option<Tensor>, position_ids: Option<Tensor>,
train: bool) -> Result<Tensor, &'static str>; input_embeds: Option<Tensor>,
train: bool,
) -> Result<Tensor, &'static str>;
} }
#[derive(Debug)] #[derive(Debug)]
@ -51,10 +53,10 @@ impl BertEmbedding for BertEmbeddings {
/// # Example /// # Example
/// ///
/// ```no_run /// ```no_run
/// use rust_bert::bert::{BertConfig, BertEmbeddings, BertEmbedding}; /// use rust_bert::bert::{BertConfig, BertEmbedding, BertEmbeddings};
/// use tch::{nn, Device};
/// use rust_bert::Config; /// use rust_bert::Config;
/// use std::path::Path; /// use std::path::Path;
/// use tch::{nn, Device};
/// ///
/// let config_path = Path::new("path/to/config.json"); /// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu; /// let device = Device::Cpu;
@ -62,29 +64,47 @@ impl BertEmbedding for BertEmbeddings {
/// let config = BertConfig::from_file(config_path); /// let config = BertConfig::from_file(config_path);
/// let bert_embeddings = BertEmbeddings::new(&(&p.root() / "bert_embeddings"), &config); /// let bert_embeddings = BertEmbeddings::new(&(&p.root() / "bert_embeddings"), &config);
/// ``` /// ```
///
fn new(p: &nn::Path, config: &BertConfig) -> BertEmbeddings { fn new(p: &nn::Path, config: &BertConfig) -> BertEmbeddings {
let embedding_config = EmbeddingConfig { padding_idx: 0, ..Default::default() }; let embedding_config = EmbeddingConfig {
padding_idx: 0,
..Default::default()
};
let word_embeddings: nn::Embedding = embedding(p / "word_embeddings", let word_embeddings: nn::Embedding = embedding(
config.vocab_size, p / "word_embeddings",
config.hidden_size, config.vocab_size,
embedding_config); config.hidden_size,
embedding_config,
);
let position_embeddings: nn::Embedding = embedding(p / "position_embeddings", let position_embeddings: nn::Embedding = embedding(
config.max_position_embeddings, p / "position_embeddings",
config.hidden_size, config.max_position_embeddings,
Default::default()); config.hidden_size,
Default::default(),
);
let token_type_embeddings: nn::Embedding = embedding(p / "token_type_embeddings", let token_type_embeddings: nn::Embedding = embedding(
config.type_vocab_size, p / "token_type_embeddings",
config.hidden_size, config.type_vocab_size,
Default::default()); config.hidden_size,
Default::default(),
);
let layer_norm_config = nn::LayerNormConfig { eps: 1e-12, ..Default::default() }; let layer_norm_config = nn::LayerNormConfig {
let layer_norm: nn::LayerNorm = nn::layer_norm(p / "LayerNorm", vec![config.hidden_size], layer_norm_config); eps: 1e-12,
..Default::default()
};
let layer_norm: nn::LayerNorm =
nn::layer_norm(p / "LayerNorm", vec![config.hidden_size], layer_norm_config);
let dropout: Dropout = Dropout::new(config.hidden_dropout_prob); let dropout: Dropout = Dropout::new(config.hidden_dropout_prob);
BertEmbeddings { word_embeddings, position_embeddings, token_type_embeddings, layer_norm, dropout } BertEmbeddings {
word_embeddings,
position_embeddings,
token_type_embeddings,
layer_norm,
dropout,
}
} }
/// Forward pass through the embedding layer /// Forward pass through the embedding layer
@ -104,50 +124,62 @@ impl BertEmbedding for BertEmbeddings {
/// # Example /// # Example
/// ///
/// ```no_run /// ```no_run
///# use rust_bert::bert::{BertConfig, BertEmbeddings, BertEmbedding}; /// # use rust_bert::bert::{BertConfig, BertEmbeddings, BertEmbedding};
///# use tch::{nn, Device, Tensor, no_grad}; /// # use tch::{nn, Device, Tensor, no_grad};
///# use rust_bert::Config; /// # use rust_bert::Config;
///# use std::path::Path; /// # use std::path::Path;
///# use tch::kind::Kind::Int64; /// # use tch::kind::Kind::Int64;
///# let config_path = Path::new("path/to/config.json"); /// # let config_path = Path::new("path/to/config.json");
///# let vocab_path = Path::new("path/to/vocab.txt"); /// # let vocab_path = Path::new("path/to/vocab.txt");
///# let device = Device::Cpu; /// # let device = Device::Cpu;
///# let vs = nn::VarStore::new(device); /// # let vs = nn::VarStore::new(device);
///# let config = BertConfig::from_file(config_path); /// # let config = BertConfig::from_file(config_path);
///# let bert_embeddings = BertEmbeddings::new(&vs.root(), &config); /// # let bert_embeddings = BertEmbeddings::new(&vs.root(), &config);
/// let (batch_size, sequence_length) = (64, 128); /// let (batch_size, sequence_length) = (64, 128);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device)); /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device)); /// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true); /// let position_ids = Tensor::arange(sequence_length, (Int64, device))
/// .expand(&[batch_size, sequence_length], true);
/// ///
/// let embedded_output = no_grad(|| { /// let embedded_output = no_grad(|| {
/// bert_embeddings /// bert_embeddings
/// .forward_t(Some(input_tensor), /// .forward_t(
/// Some(token_type_ids), /// Some(input_tensor),
/// Some(position_ids), /// Some(token_type_ids),
/// None, /// Some(position_ids),
/// false).unwrap() /// None,
/// }); /// false,
/// )
/// .unwrap()
/// });
/// ``` /// ```
/// fn forward_t(
fn forward_t(&self, &self,
input_ids: Option<Tensor>, input_ids: Option<Tensor>,
token_type_ids: Option<Tensor>, token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>, position_ids: Option<Tensor>,
input_embeds: Option<Tensor>, input_embeds: Option<Tensor>,
train: bool) -> Result<Tensor, &'static str> { train: bool,
) -> Result<Tensor, &'static str> {
let (input_embeddings, input_shape) = match input_ids { let (input_embeddings, input_shape) = match input_ids {
Some(input_value) => match input_embeds { Some(input_value) => match input_embeds {
Some(_) => { return Err("Only one of input ids or input embeddings may be set"); } Some(_) => {
None => (input_value.apply_t(&self.word_embeddings, train), input_value.size()) return Err("Only one of input ids or input embeddings may be set");
} }
None => (
input_value.apply_t(&self.word_embeddings, train),
input_value.size(),
),
},
None => match input_embeds { None => match input_embeds {
Some(embeds) => { Some(embeds) => {
let size = vec!(embeds.size()[0], embeds.size()[1]); let size = vec![embeds.size()[0], embeds.size()[1]];
(embeds, size) (embeds, size)
}, }
None => { return Err("Only one of input ids or input embeddings may be set"); } None => {
} return Err("Only one of input ids or input embeddings may be set");
}
},
}; };
let seq_length = input_embeddings.as_ref().size()[1].to_owned(); let seq_length = input_embeddings.as_ref().size()[1].to_owned();
@ -155,19 +187,22 @@ impl BertEmbedding for BertEmbeddings {
let position_ids = match position_ids { let position_ids = match position_ids {
Some(value) => value, Some(value) => value,
None => Tensor::arange(seq_length, (Kind::Int64, input_embeddings.device())) None => Tensor::arange(seq_length, (Kind::Int64, input_embeddings.device()))
.unsqueeze(0). .unsqueeze(0)
expand(&input_shape, true) .expand(&input_shape, true),
}; };
let token_type_ids = match token_type_ids { let token_type_ids = match token_type_ids {
Some(value) => value, Some(value) => value,
None => Tensor::zeros(&input_shape, (Kind::Int64, input_embeddings.device())) None => Tensor::zeros(&input_shape, (Kind::Int64, input_embeddings.device())),
}; };
let position_embeddings = position_ids.apply(&self.position_embeddings); let position_embeddings = position_ids.apply(&self.position_embeddings);
let token_type_embeddings = token_type_ids.apply(&self.token_type_embeddings); let token_type_embeddings = token_type_ids.apply(&self.token_type_embeddings);
let input_embeddings: Tensor = input_embeddings + position_embeddings + token_type_embeddings; let input_embeddings: Tensor =
Ok(input_embeddings.apply(&self.layer_norm).apply_t(&self.dropout, train)) input_embeddings + position_embeddings + token_type_embeddings;
Ok(input_embeddings
.apply(&self.layer_norm)
.apply_t(&self.dropout, train))
} }
} }
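A quick sketch of the default inputs that `forward_t` synthesizes when `position_ids` or `token_type_ids` are omitted (toy batch and sequence sizes assumed, `tch` only):

```rust
use tch::{Device, Kind, Tensor};

fn main() {
    // Toy sizes, assumed for illustration.
    let (batch_size, seq_length): (i64, i64) = (2, 4);
    let input_shape = [batch_size, seq_length];

    // Default position ids, built as in `forward_t` above:
    // [0, 1, 2, 3] broadcast to every sequence in the batch.
    let position_ids = Tensor::arange(seq_length, (Kind::Int64, Device::Cpu))
        .unsqueeze(0)
        .expand(&input_shape, true);
    assert_eq!(position_ids.size(), &input_shape);

    // Default token type ids: all zeros, i.e. a single-segment input.
    let token_type_ids = Tensor::zeros(&input_shape, (Kind::Int64, Device::Cpu));
    assert_eq!(token_type_ids.size(), &input_shape);
}
```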

View File

@@ -11,10 +11,10 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use tch::{Tensor, nn};
use crate::bert::attention::{BertAttention, BertIntermediate, BertOutput}; use crate::bert::attention::{BertAttention, BertIntermediate, BertOutput};
use std::borrow::BorrowMut;
use crate::bert::bert::BertConfig; use crate::bert::bert::BertConfig;
use std::borrow::BorrowMut;
use tch::{nn, Tensor};
pub struct BertLayer { pub struct BertLayer {
attention: BertAttention, attention: BertAttention,
@ -30,37 +30,57 @@ impl BertLayer {
let (is_decoder, cross_attention) = match config.is_decoder { let (is_decoder, cross_attention) = match config.is_decoder {
Some(value) => { Some(value) => {
if value == true { if value == true {
(value, Some(BertAttention::new(&(p / "cross_attention"), &config))) (
value,
Some(BertAttention::new(&(p / "cross_attention"), &config)),
)
} else { } else {
(value, None) (value, None)
} }
} }
None => (false, None) None => (false, None),
}; };
let intermediate = BertIntermediate::new(&(p / "intermediate"), &config); let intermediate = BertIntermediate::new(&(p / "intermediate"), &config);
let output = BertOutput::new(&(p / "output"), &config); let output = BertOutput::new(&(p / "output"), &config);
BertLayer { attention, is_decoder, cross_attention, intermediate, output } BertLayer {
attention,
is_decoder,
cross_attention,
intermediate,
output,
}
} }
pub fn forward_t(&self, pub fn forward_t(
hidden_states: &Tensor, &self,
mask: &Option<Tensor>, hidden_states: &Tensor,
encoder_hidden_states: &Option<Tensor>, mask: &Option<Tensor>,
encoder_mask: &Option<Tensor>, encoder_hidden_states: &Option<Tensor>,
train: bool) -> (Tensor, Option<Tensor>, Option<Tensor>) { encoder_mask: &Option<Tensor>,
let (attention_output, attention_weights, cross_attention_weights) = if self.is_decoder & encoder_hidden_states.is_some() { train: bool,
let (attention_output, attention_weights) = ) -> (Tensor, Option<Tensor>, Option<Tensor>) {
self.attention.forward_t(hidden_states, mask, &None, &None, train); let (attention_output, attention_weights, cross_attention_weights) =
let (attention_output, cross_attention_weights) = if self.is_decoder & encoder_hidden_states.is_some() {
self.cross_attention.as_ref().unwrap().forward_t(&attention_output, mask, encoder_hidden_states, encoder_mask, train); let (attention_output, attention_weights) =
(attention_output, attention_weights, cross_attention_weights) self.attention
} else { .forward_t(hidden_states, mask, &None, &None, train);
let (attention_output, attention_weights) = let (attention_output, cross_attention_weights) =
self.attention.forward_t(hidden_states, mask, &None, &None, train); self.cross_attention.as_ref().unwrap().forward_t(
(attention_output, attention_weights, None) &attention_output,
}; mask,
encoder_hidden_states,
encoder_mask,
train,
);
(attention_output, attention_weights, cross_attention_weights)
} else {
let (attention_output, attention_weights) =
self.attention
.forward_t(hidden_states, mask, &None, &None, train);
(attention_output, attention_weights, None)
};
let output = self.intermediate.forward(&attention_output); let output = self.intermediate.forward(&attention_output);
let output = self.output.forward_t(&output, &attention_output, train); let output = self.output.forward_t(&output, &attention_output, train);
@ -78,26 +98,47 @@ pub struct BertEncoder {
impl BertEncoder { impl BertEncoder {
pub fn new(p: &nn::Path, config: &BertConfig) -> BertEncoder { pub fn new(p: &nn::Path, config: &BertConfig) -> BertEncoder {
let p = &(p / "layer"); let p = &(p / "layer");
let output_attentions = if let Some(value) = config.output_attentions { value } else { false }; let output_attentions = if let Some(value) = config.output_attentions {
let output_hidden_states = if let Some(value) = config.output_hidden_states { value } else { false }; value
} else {
let mut layers: Vec<BertLayer> = vec!(); false
for layer_index in 0..config.num_hidden_layers { };
layers.push(BertLayer::new(&(p / layer_index), config)); let output_hidden_states = if let Some(value) = config.output_hidden_states {
value
} else {
false
}; };
BertEncoder { output_attentions, output_hidden_states, layers } let mut layers: Vec<BertLayer> = vec![];
for layer_index in 0..config.num_hidden_layers {
layers.push(BertLayer::new(&(p / layer_index), config));
}
BertEncoder {
output_attentions,
output_hidden_states,
layers,
}
} }
pub fn forward_t(&self, pub fn forward_t(
hidden_states: &Tensor, &self,
mask: &Option<Tensor>, hidden_states: &Tensor,
encoder_hidden_states: &Option<Tensor>, mask: &Option<Tensor>,
encoder_mask: &Option<Tensor>, encoder_hidden_states: &Option<Tensor>,
train: bool) encoder_mask: &Option<Tensor>,
-> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) { train: bool,
let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states { Some(vec!()) } else { None }; ) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
let mut all_attentions: Option<Vec<Tensor>> = if self.output_attentions { Some(vec!()) } else { None }; let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states {
Some(vec![])
} else {
None
};
let mut all_attentions: Option<Vec<Tensor>> = if self.output_attentions {
Some(vec![])
} else {
None
};
let mut hidden_state = hidden_states.copy(); let mut hidden_state = hidden_states.copy();
let mut attention_weights: Option<Tensor>; let mut attention_weights: Option<Tensor>;
@ -109,16 +150,22 @@ impl BertEncoder {
hidden_states.push(hidden_state.as_ref().copy()); hidden_states.push(hidden_state.as_ref().copy());
}; };
let temp = layer.forward_t(&hidden_state, &mask, encoder_hidden_states, encoder_mask, train); let temp = layer.forward_t(
&hidden_state,
&mask,
encoder_hidden_states,
encoder_mask,
train,
);
hidden_state = temp.0; hidden_state = temp.0;
attention_weights = temp.1; attention_weights = temp.1;
if let Some(attentions) = all_attentions.borrow_mut() { if let Some(attentions) = all_attentions.borrow_mut() {
attentions.push(attention_weights.as_ref().unwrap().copy()); attentions.push(attention_weights.as_ref().unwrap().copy());
}; };
} }
None => break None => break,
}; };
}; }
(hidden_state, all_hidden_states, all_attentions) (hidden_state, all_hidden_states, all_attentions)
} }
@ -130,14 +177,16 @@ pub struct BertPooler {
impl BertPooler { impl BertPooler {
pub fn new(p: &nn::Path, config: &BertConfig) -> BertPooler { pub fn new(p: &nn::Path, config: &BertConfig) -> BertPooler {
let lin = nn::linear(&(p / "dense"), config.hidden_size, config.hidden_size, Default::default()); let lin = nn::linear(
&(p / "dense"),
config.hidden_size,
config.hidden_size,
Default::default(),
);
BertPooler { lin } BertPooler { lin }
} }
pub fn forward(&self, hidden_states: &Tensor) -> Tensor { pub fn forward(&self, hidden_states: &Tensor) -> Tensor {
hidden_states hidden_states.select(1, 0).apply(&self.lin).tanh()
.select(1, 0)
.apply(&self.lin)
.tanh()
} }
} }
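What `BertPooler::forward` extracts with `select(1, 0)` is the hidden state of each sequence's first ([CLS]) token, which then passes through the linear layer and tanh; a tiny shape check (toy dimensions assumed):

```rust
use tch::{Device, Kind, Tensor};

fn main() {
    // Toy dimensions, assumed for this sketch.
    let (batch, seq, hidden): (i64, i64, i64) = (2, 4, 8);
    let hidden_states = Tensor::rand(&[batch, seq, hidden], (Kind::Float, Device::Cpu));

    // `select(1, 0)` drops the sequence dimension, keeping position 0.
    let cls = hidden_states.select(1, 0);
    assert_eq!(cls.size(), &[batch, hidden]);
}
```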

View File

@@ -19,18 +19,24 @@
//! Pretrained models are available and can be downloaded using RemoteResources. //! Pretrained models are available and can be downloaded using RemoteResources.
//! //!
//! ```no_run //! ```no_run
//!# fn main() -> failure::Fallible<()> { //! # fn main() -> failure::Fallible<()> {
//!# //! #
//! use rust_tokenizers::BertTokenizer; //! use rust_tokenizers::BertTokenizer;
//! use tch::{nn, Device}; //! use tch::{nn, Device};
//!# use std::path::PathBuf; //! # use std::path::PathBuf;
//! use rust_bert::bert::{BertForMaskedLM, BertConfig}; //! use rust_bert::bert::{BertConfig, BertForMaskedLM};
//! use rust_bert::resources::{download_resource, LocalResource, Resource};
//! use rust_bert::Config; //! use rust_bert::Config;
//! use rust_bert::resources::{Resource, download_resource, LocalResource};
//! //!
//! let config_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/config.json")}); //! let config_resource = Resource::Local(LocalResource {
//! let vocab_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/vocab.txt")}); //! local_path: PathBuf::from("path/to/config.json"),
//! let weights_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/model.ot")}); //! });
//! let vocab_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/vocab.txt"),
//! });
//! let weights_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/model.ot"),
//! });
//! let config_path = download_resource(&config_resource)?; //! let config_path = download_resource(&config_resource)?;
//! let vocab_path = download_resource(&vocab_resource)?; //! let vocab_path = download_resource(&vocab_resource)?;
//! let weights_path = download_resource(&weights_resource)?; //! let weights_path = download_resource(&weights_resource)?;
@@ -41,17 +47,18 @@
//! let bert_model = BertForMaskedLM::new(&vs.root(), &config); //! let bert_model = BertForMaskedLM::new(&vs.root(), &config);
//! vs.load(weights_path)?; //! vs.load(weights_path)?;
//! //!
//!# Ok(()) //! # Ok(())
//!# } //! # }
//! ``` //! ```
mod attention;
mod bert; mod bert;
mod embeddings; mod embeddings;
mod attention;
pub(crate) mod encoder; pub(crate) mod encoder;
pub use bert::{BertModelResources, BertConfigResources, BertVocabResources, pub use bert::{
BertConfig, Activation, BertModel, BertForTokenClassification, BertForMultipleChoice, Activation, BertConfig, BertConfigResources, BertForMaskedLM, BertForMultipleChoice,
BertForMaskedLM, BertForSequenceClassification, BertForQuestionAnswering}; BertForQuestionAnswering, BertForSequenceClassification, BertForTokenClassification, BertModel,
pub use embeddings::{BertEmbedding, BertEmbeddings}; BertModelResources, BertVocabResources,
};
pub use embeddings::{BertEmbedding, BertEmbeddings};

View File

@@ -1,14 +1,26 @@
use tch::Tensor;
use std::f64::consts::PI; use std::f64::consts::PI;
use tch::Tensor;
pub fn _gelu(x: &Tensor) -> Tensor { x * 0.5 * (1.0 + (x / ((2.0 as f64).sqrt())).erf()) } pub fn _gelu(x: &Tensor) -> Tensor {
x * 0.5 * (1.0 + (x / ((2.0 as f64).sqrt())).erf())
}
pub fn _relu(x: &Tensor) -> Tensor { x.relu() } pub fn _relu(x: &Tensor) -> Tensor {
x.relu()
}
pub fn _swish(x: &Tensor) -> Tensor { x * x.sigmoid() } pub fn _swish(x: &Tensor) -> Tensor {
x * x.sigmoid()
}
pub fn _mish(x: &Tensor) -> Tensor { x * (x.softplus().tanh()) } pub fn _mish(x: &Tensor) -> Tensor {
x * (x.softplus().tanh())
}
pub fn _gelu_new(x: &Tensor) -> Tensor { x * 0.5 * (((x.pow(3.0f64) * 0.044715 + x) * ((2f64 / PI).sqrt())).tanh() + 1) } pub fn _gelu_new(x: &Tensor) -> Tensor {
x * 0.5 * (((x.pow(3.0f64) * 0.044715 + x) * ((2f64 / PI).sqrt())).tanh() + 1)
}
pub fn _tanh(x: &Tensor) -> Tensor { x.tanh() } pub fn _tanh(x: &Tensor) -> Tensor {
x.tanh()
}
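For reference, the closed forms these helpers implement (standard definitions; the constants match the code above):

```latex
\begin{aligned}
\mathrm{gelu}(x)      &= 0.5\,x\left(1 + \operatorname{erf}\!\left(x/\sqrt{2}\right)\right) \\
\mathrm{gelu\_new}(x) &= 0.5\,x\left(1 + \tanh\!\left(\sqrt{2/\pi}\,\bigl(x + 0.044715\,x^{3}\bigr)\right)\right) \\
\mathrm{swish}(x)     &= x\,\sigma(x) \\
\mathrm{mish}(x)      &= x\,\tanh\bigl(\operatorname{softplus}(x)\bigr),
\qquad \operatorname{softplus}(x) = \ln\!\left(1 + e^{x}\right)
\end{aligned}
```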

View File

@@ -9,15 +9,16 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use serde::Deserialize;
use std::path::Path;
use std::fs::File; use std::fs::File;
use std::io::BufReader; use std::io::BufReader;
use serde::Deserialize; use std::path::Path;
/// # Utility to deserialize JSON config files /// # Utility to deserialize JSON config files
pub trait Config<T> pub trait Config<T>
where for<'de> T: Deserialize<'de> { where
for<'de> T: Deserialize<'de>,
{
/// Loads a `Config` object from a JSON file. The format is expected to be aligned with the [Transformers library](https://github.com/huggingface/transformers) configuration files for each model. /// Loads a `Config` object from a JSON file. The format is expected to be aligned with the [Transformers library](https://github.com/huggingface/transformers) configuration files for each model.
/// The parsing will fail if non-optional keys expected by the model are missing. /// The parsing will fail if non-optional keys expected by the model are missing.
/// ///
@ -28,18 +29,17 @@ pub trait Config<T>
/// # Example /// # Example
/// ///
/// ```no_run /// ```no_run
/// use rust_bert::gpt2::Gpt2Config;
/// use rust_bert::Config; /// use rust_bert::Config;
/// use std::path::Path; /// use std::path::Path;
/// use rust_bert::gpt2::Gpt2Config;
/// ///
/// let config_path = Path::new("path/to/config.json"); /// let config_path = Path::new("path/to/config.json");
/// let config = Gpt2Config::from_file(config_path); /// let config = Gpt2Config::from_file(config_path);
/// ``` /// ```
///
fn from_file(path: &Path) -> T { fn from_file(path: &Path) -> T {
let f = File::open(path).expect("Could not open configuration file."); let f = File::open(path).expect("Could not open configuration file.");
let br = BufReader::new(f); let br = BufReader::new(f);
let config: T = serde_json::from_reader(br).expect("could not parse configuration"); let config: T = serde_json::from_reader(br).expect("could not parse configuration");
config config
} }
} }
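Because `from_file` ships with a default body, hooking a new configuration type into this trait is a single empty impl. A minimal sketch with a hypothetical `MyConfig` (the struct and its fields are made up for illustration; like the crate's own doc examples, this expects a real JSON file at the given path):

```rust
use rust_bert::Config;
use serde::Deserialize;
use std::path::Path;

// Hypothetical configuration type; any `Deserialize`-able struct works.
#[derive(Debug, Deserialize)]
pub struct MyConfig {
    pub hidden_size: i64,
    pub num_layers: i64,
}

// The default `from_file` implementation is inherited as-is.
impl Config<MyConfig> for MyConfig {}

fn main() {
    let config = MyConfig::from_file(Path::new("path/to/config.json"));
    println!("{} layers", config.num_layers);
}
```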

View File

@ -27,4 +27,4 @@ impl ModuleT for Dropout {
fn forward_t(&self, input: &Tensor, train: bool) -> Tensor { fn forward_t(&self, input: &Tensor, train: bool) -> Tensor {
input.dropout(self.dropout_prob, train) input.dropout(self.dropout_prob, train)
} }
} }
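The `train` flag is what makes this wrapper safe to call at inference time: a small sketch of the underlying `Tensor::dropout` behaviour that `forward_t` delegates to (`tch` only):

```rust
use tch::{Device, Kind, Tensor};

fn main() {
    let x = Tensor::ones(&[4], (Kind::Float, Device::Cpu));

    // With `train == false`, dropout is the identity...
    let eval_out = x.dropout(0.5, false);
    assert_eq!(Vec::<f64>::from(&eval_out), vec![1.0; 4]);

    // ...with `train == true`, surviving entries are rescaled by 1/(1-p),
    // so each element here is either 0.0 or 2.0.
    let train_out = x.dropout(0.5, true);
    for v in Vec::<f64>::from(&train_out) {
        assert!(v == 0.0 || v == 2.0);
    }
}
```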

View File

@@ -10,9 +10,9 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use tch::nn::{Init, Path, Module};
use tch::Tensor;
use std::borrow::Borrow; use std::borrow::Borrow;
use tch::nn::{Init, Module, Path};
use tch::Tensor;
#[derive(Debug, Clone, Copy)] #[derive(Debug, Clone, Copy)]
pub struct LinearNoBiasConfig { pub struct LinearNoBiasConfig {
@ -32,7 +32,6 @@ pub struct LinearNoBias {
pub ws: Tensor, pub ws: Tensor,
} }
pub fn linear_no_bias<'a, T: Borrow<Path<'a>>>( pub fn linear_no_bias<'a, T: Borrow<Path<'a>>>(
vs: T, vs: T,
in_dim: i64, in_dim: i64,
@ -49,4 +48,4 @@ impl Module for LinearNoBias {
fn forward(&self, xs: &Tensor) -> Tensor { fn forward(&self, xs: &Tensor) -> Tensor {
xs.matmul(&self.ws.tr()) xs.matmul(&self.ws.tr())
} }
} }
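`LinearNoBias::forward` is exactly y = x·Wᵀ with no bias term added; a toy shape check (dimensions assumed):

```rust
use tch::{Device, Kind, Tensor};

fn main() {
    let x = Tensor::rand(&[2, 3], (Kind::Float, Device::Cpu)); // [batch, in_dim]
    let ws = Tensor::rand(&[4, 3], (Kind::Float, Device::Cpu)); // [out_dim, in_dim]

    // Same computation as `forward` above: matmul against the transposed weights.
    let y = x.matmul(&ws.tr());
    assert_eq!(y.size(), &[2, 4]);
}
```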

View File

@@ -1,7 +1,7 @@
pub mod config;
pub mod resources;
pub(crate) mod dropout;
pub(crate) mod activations; pub(crate) mod activations;
pub mod config;
pub(crate) mod dropout;
pub(crate) mod linear; pub(crate) mod linear;
pub mod resources;
pub use config::Config; pub use config::Config;

View File

@@ -18,9 +18,9 @@
//! pre-trained models in each model module. //! pre-trained models in each model module.
use lazy_static::lazy_static; use lazy_static::lazy_static;
use std::path::PathBuf;
use reqwest::Client; use reqwest::Client;
use std::{fs, env}; use std::path::PathBuf;
use std::{env, fs};
use tokio::prelude::*; use tokio::prelude::*;
use tokio::runtime::Runtime; use tokio::runtime::Runtime;
use tokio::task; use tokio::task;
@ -47,12 +47,13 @@ impl Resource {
/// # Example /// # Example
/// ///
/// ```no_run /// ```no_run
/// use rust_bert::resources::{Resource, LocalResource}; /// use rust_bert::resources::{LocalResource, Resource};
/// use std::path::PathBuf; /// use std::path::PathBuf;
/// let config_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/config.json")}); /// let config_resource = Resource::Local(LocalResource {
/// local_path: PathBuf::from("path/to/config.json"),
/// });
/// let config_path = config_resource.get_local_path(); /// let config_path = config_resource.get_local_path();
/// ``` /// ```
///
pub fn get_local_path(&self) -> &PathBuf { pub fn get_local_path(&self) -> &PathBuf {
match self { match self {
Resource::Local(resource) => &resource.local_path, Resource::Local(resource) => &resource.local_path,
@ -65,7 +66,7 @@ impl Resource {
#[derive(PartialEq, Clone)] #[derive(PartialEq, Clone)]
pub struct LocalResource { pub struct LocalResource {
/// Local path for the resource /// Local path for the resource
pub local_path: PathBuf pub local_path: PathBuf,
} }
/// # Remote resource /// # Remote resource
@ -93,13 +94,18 @@ impl RemoteResource {
/// # Example /// # Example
/// ///
/// ```no_run /// ```no_run
/// use rust_bert::resources::{Resource, RemoteResource}; /// use rust_bert::resources::{RemoteResource, Resource};
/// use std::path::PathBuf; /// use std::path::PathBuf;
/// let config_resource = Resource::Remote(RemoteResource::new("http://config_json_location", PathBuf::from("path/to/config.json"))); /// let config_resource = Resource::Remote(RemoteResource::new(
/// "http://config_json_location",
/// PathBuf::from("path/to/config.json"),
/// ));
/// ``` /// ```
///
pub fn new(url: &str, target: PathBuf) -> RemoteResource { pub fn new(url: &str, target: PathBuf) -> RemoteResource {
RemoteResource { url: url.to_string(), local_path: target } RemoteResource {
url: url.to_string(),
local_path: target,
}
} }
/// Creates a new RemoteResource from a URL and local name. Will define a local path pointing to /// Creates a new RemoteResource from a URL and local name. Will define a local path pointing to
@ -117,14 +123,12 @@ impl RemoteResource {
/// # Example /// # Example
/// ///
/// ```no_run /// ```no_run
/// use rust_bert::resources::{Resource, RemoteResource}; /// use rust_bert::resources::{RemoteResource, Resource};
/// let model_resource = Resource::Remote(RemoteResource::from_pretrained( /// let model_resource = Resource::Remote(RemoteResource::from_pretrained((
/// ("distilbert-sst2/model.ot", /// "distilbert-sst2/model.ot",
/// "https://cdn.huggingface.co/distilbert-base-uncased-finetuned-sst-2-english-rust_model.ot" /// "https://cdn.huggingface.co/distilbert-base-uncased-finetuned-sst-2-english-rust_model.ot",
/// ) /// )));
/// ));
/// ``` /// ```
///
pub fn from_pretrained(name_url_tuple: (&str, &str)) -> RemoteResource { pub fn from_pretrained(name_url_tuple: (&str, &str)) -> RemoteResource {
let name = name_url_tuple.0; let name = name_url_tuple.0;
let url = name_url_tuple.1.to_string(); let url = name_url_tuple.1.to_string();
@ -171,15 +175,13 @@ fn _get_cache_directory() -> PathBuf {
/// # Example /// # Example
/// ///
/// ```no_run /// ```no_run
/// use rust_bert::resources::{Resource, RemoteResource, download_resource}; /// use rust_bert::resources::{download_resource, RemoteResource, Resource};
/// let model_resource = Resource::Remote(RemoteResource::from_pretrained( /// let model_resource = Resource::Remote(RemoteResource::from_pretrained((
/// ("distilbert-sst2/model.ot", /// "distilbert-sst2/model.ot",
/// "https://cdn.huggingface.co/distilbert-base-uncased-finetuned-sst-2-english-rust_model.ot" /// "https://cdn.huggingface.co/distilbert-base-uncased-finetuned-sst-2-english-rust_model.ot",
/// ) /// )));
/// ));
/// let local_path = download_resource(&model_resource); /// let local_path = download_resource(&model_resource);
/// ``` /// ```
///
pub fn download_resource(resource: &Resource) -> failure::Fallible<&PathBuf> { pub fn download_resource(resource: &Resource) -> failure::Fallible<&PathBuf> {
match resource { match resource {
Resource::Remote(remote_resource) => { Resource::Remote(remote_resource) => {
@ -202,8 +204,6 @@ pub fn download_resource(resource: &Resource) -> failure::Fallible<&PathBuf> {
Ok(resource.get_local_path()) Ok(resource.get_local_path())
} }
Resource::Local(_) => { Resource::Local(_) => Ok(resource.get_local_path()),
Ok(resource.get_local_path())
}
} }
} }
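Putting the two resource kinds together, a small end-to-end sketch (the URL is one of the checkpoints referenced above; the cache name and local path are illustrative):

```rust
use rust_bert::resources::{download_resource, LocalResource, RemoteResource, Resource};
use std::path::PathBuf;

fn main() -> failure::Fallible<()> {
    // A remote resource is fetched into the cache on first use...
    let remote = Resource::Remote(RemoteResource::new(
        "https://cdn.huggingface.co/distilbert-base-uncased-rust_model.ot",
        PathBuf::from("distilbert/model.ot"),
    ));
    let remote_path = download_resource(&remote)?;

    // ...while a local resource resolves to its path without any I/O.
    let local = Resource::Local(LocalResource {
        local_path: PathBuf::from("path/to/model.ot"),
    });
    let local_path = download_resource(&local)?;

    println!("remote: {:?}, local: {:?}", remote_path, local_path);
    Ok(())
}
```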

View File

@ -16,7 +16,11 @@ extern crate tch;
pub fn main() -> failure::Fallible<()> { pub fn main() -> failure::Fallible<()> {
let args: Vec<_> = std::env::args().collect(); let args: Vec<_> = std::env::args().collect();
ensure!(args.len() == 3, "usage: {} source.npz destination.ot", args[0]); ensure!(
args.len() == 3,
"usage: {} source.npz destination.ot",
args[0]
);
let source_file = &args[1]; let source_file = &args[1];
let destination_file = &args[2]; let destination_file = &args[2];
@ -24,4 +28,4 @@ pub fn main() -> failure::Fallible<()> {
tch::Tensor::save_multi(&tensors, destination_file)?; tch::Tensor::save_multi(&tensors, destination_file)?;
Ok(()) Ok(())
} }
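The same conversion expressed as library calls rather than CLI arguments, assuming `tch`'s `Tensor::read_npz`/`Tensor::save_multi` pair (file names illustrative):

```rust
use tch::Tensor;

// Sketch of an .npz -> .ot conversion with tch's archive helpers,
// mirroring what the binary above does with its two arguments.
fn convert(source: &str, destination: &str) -> failure::Fallible<()> {
    let tensors = Tensor::read_npz(source)?; // Vec<(String, Tensor)>
    Tensor::save_multi(&tensors, destination)?;
    Ok(())
}

fn main() -> failure::Fallible<()> {
    convert("source.npz", "destination.ot")
}
```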

View File

@@ -10,11 +10,10 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use tch::{nn, Tensor}; use crate::common::dropout::Dropout;
use crate::distilbert::distilbert::DistilBertConfig; use crate::distilbert::distilbert::DistilBertConfig;
use tch::kind::Kind::Float; use tch::kind::Kind::Float;
use crate::common::dropout::Dropout; use tch::{nn, Tensor};
#[derive(Debug)] #[derive(Debug)]
pub struct MultiHeadSelfAttention { pub struct MultiHeadSelfAttention {
@ -39,7 +38,7 @@ impl MultiHeadSelfAttention {
let output_attentions = match config.output_attentions { let output_attentions = match config.output_attentions {
Some(value) => value, Some(value) => value,
None => false None => false,
}; };
MultiHeadSelfAttention { MultiHeadSelfAttention {
@ -59,10 +58,19 @@ impl MultiHeadSelfAttention {
} }
fn flatten(&self, x: Tensor, bs: i64, dim_per_head: i64) -> Tensor { fn flatten(&self, x: Tensor, bs: i64, dim_per_head: i64) -> Tensor {
x.transpose(1, 2).contiguous().view((bs, -1, &self.n_heads * dim_per_head)) x.transpose(1, 2)
.contiguous()
.view((bs, -1, &self.n_heads * dim_per_head))
} }
pub fn forward_t(&self, query: &Tensor, key: &Tensor, value: &Tensor, mask: &Option<Tensor>, train: bool) -> (Tensor, Option<Tensor>) { pub fn forward_t(
&self,
query: &Tensor,
key: &Tensor,
value: &Tensor,
mask: &Option<Tensor>,
train: bool,
) -> (Tensor, Option<Tensor>) {
let bs = query.size()[0]; let bs = query.size()[0];
let k_length = key.size()[1]; let k_length = key.size()[1];
@ -73,14 +81,19 @@ impl MultiHeadSelfAttention {
let scores = if let Some(mask) = mask { let scores = if let Some(mask) = mask {
let unmasked_scores = q.matmul(&k.transpose(2, 3)); let unmasked_scores = q.matmul(&k.transpose(2, 3));
let mask = mask.le1(&(mask.zeros_like() + 0.1)).view((bs, 1i64, 1i64, k_length)).expand_as(&unmasked_scores); let mask = mask
.le1(&(mask.zeros_like() + 0.1))
.view((bs, 1i64, 1i64, k_length))
.expand_as(&unmasked_scores);
unmasked_scores.masked_fill(&mask, std::f64::NEG_INFINITY) unmasked_scores.masked_fill(&mask, std::f64::NEG_INFINITY)
} else { } else {
q.matmul(&k.transpose(2, 3)) q.matmul(&k.transpose(2, 3))
}; };
let weights = scores.softmax(-1, Float).apply_t(&self.dropout, train); let weights = scores.softmax(-1, Float).apply_t(&self.dropout, train);
let context = self.flatten(weights.matmul(&v), bs, self.dim_per_head).apply(&self.out_lin); let context = self
.flatten(weights.matmul(&v), bs, self.dim_per_head)
.apply(&self.out_lin);
if !self.output_attentions { if !self.output_attentions {
(context, None) (context, None)
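A toy illustration of the masking trick above: positions whose mask is (near) zero are filled with `-inf` before the softmax, so they receive exactly zero attention weight (made-up scores, `tch` only):

```rust
use tch::{Device, Kind, Tensor};

fn main() {
    // One query over four key positions; the last two are padding (mask 0).
    let scores = Tensor::of_slice(&[1.0f32, 2.0, 3.0, 4.0]).view((1, 4));
    let mask = Tensor::of_slice(&[1.0f32, 1.0, 0.0, 0.0]).view((1, 4));

    // Same comparison as `MultiHeadSelfAttention` above: mask values <= 0.1
    // mark positions to be filled with -inf.
    let masked = scores.masked_fill(
        &mask.le1(&(mask.zeros_like() + 0.1)),
        std::f64::NEG_INFINITY,
    );
    let weights = masked.softmax(-1, Kind::Float);

    // The padded positions end up with zero attention weight.
    let w = Vec::<f64>::from(&weights);
    assert_eq!(w[2], 0.0);
    assert_eq!(w[3], 0.0);
}
```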

View File

@@ -12,13 +12,13 @@
extern crate tch; extern crate tch;
use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use crate::distilbert::embeddings::DistilBertEmbedding;
use crate::distilbert::transformer::Transformer;
use self::tch::{nn, Tensor}; use self::tch::{nn, Tensor};
use crate::common::dropout::Dropout; use crate::common::dropout::Dropout;
use crate::distilbert::embeddings::DistilBertEmbedding;
use crate::distilbert::transformer::Transformer;
use crate::Config; use crate::Config;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// # DistilBERT Pretrained model weight files /// # DistilBERT Pretrained model weight files
pub struct DistilBertModelResources; pub struct DistilBertModelResources;
@ -31,29 +31,56 @@ pub struct DistilBertVocabResources;
impl DistilBertModelResources { impl DistilBertModelResources {
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format. /// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
pub const DISTIL_BERT_SST2: (&'static str, &'static str) = ("distilbert-sst2/model.ot", "https://cdn.huggingface.co/distilbert-base-uncased-finetuned-sst-2-english-rust_model.ot"); pub const DISTIL_BERT_SST2: (&'static str, &'static str) = (
"distilbert-sst2/model.ot",
"https://cdn.huggingface.co/distilbert-base-uncased-finetuned-sst-2-english-rust_model.ot",
);
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format. /// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
pub const DISTIL_BERT: (&'static str, &'static str) = ("distilbert/model.ot", "https://cdn.huggingface.co/distilbert-base-uncased-rust_model.ot"); pub const DISTIL_BERT: (&'static str, &'static str) = (
"distilbert/model.ot",
"https://cdn.huggingface.co/distilbert-base-uncased-rust_model.ot",
);
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format. /// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
pub const DISTIL_BERT_SQUAD: (&'static str, &'static str) = ("distilbert-qa/model.ot", "https://cdn.huggingface.co/distilbert-base-cased-distilled-squad-rust_model.ot"); pub const DISTIL_BERT_SQUAD: (&'static str, &'static str) = (
"distilbert-qa/model.ot",
"https://cdn.huggingface.co/distilbert-base-cased-distilled-squad-rust_model.ot",
);
} }
impl DistilBertConfigResources { impl DistilBertConfigResources {
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format. /// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
pub const DISTIL_BERT_SST2: (&'static str, &'static str) = ("distilbert-sst2/config.json", "https://cdn.huggingface.co/distilbert-base-uncased-finetuned-sst-2-english-config.json"); pub const DISTIL_BERT_SST2: (&'static str, &'static str) = (
"distilbert-sst2/config.json",
"https://cdn.huggingface.co/distilbert-base-uncased-finetuned-sst-2-english-config.json",
);
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format. /// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
pub const DISTIL_BERT: (&'static str, &'static str) = ("distilbert/config.json", "https://cdn.huggingface.co/distilbert-base-uncased-config.json"); pub const DISTIL_BERT: (&'static str, &'static str) = (
"distilbert/config.json",
"https://cdn.huggingface.co/distilbert-base-uncased-config.json",
);
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format. /// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
pub const DISTIL_BERT_SQUAD: (&'static str, &'static str) = ("distilbert-qa/config.json", "https://cdn.huggingface.co/distilbert-base-cased-distilled-squad-config.json"); pub const DISTIL_BERT_SQUAD: (&'static str, &'static str) = (
"distilbert-qa/config.json",
"https://cdn.huggingface.co/distilbert-base-cased-distilled-squad-config.json",
);
} }
impl DistilBertVocabResources { impl DistilBertVocabResources {
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format. /// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
pub const DISTIL_BERT_SST2: (&'static str, &'static str) = ("distilbert-sst2/vocab.txt", "https://cdn.huggingface.co/distilbert-base-uncased-finetuned-sst-2-english-vocab.txt"); pub const DISTIL_BERT_SST2: (&'static str, &'static str) = (
"distilbert-sst2/vocab.txt",
"https://cdn.huggingface.co/distilbert-base-uncased-finetuned-sst-2-english-vocab.txt",
);
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format. /// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
pub const DISTIL_BERT: (&'static str, &'static str) = ("distilbert/vocab.txt", "https://cdn.huggingface.co/bert-base-uncased-vocab.txt"); pub const DISTIL_BERT: (&'static str, &'static str) = (
"distilbert/vocab.txt",
"https://cdn.huggingface.co/bert-base-uncased-vocab.txt",
);
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format. /// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
pub const DISTIL_BERT_SQUAD: (&'static str, &'static str) = ("distilbert-qa/vocab.txt", "https://cdn.huggingface.co/bert-large-cased-vocab.txt"); pub const DISTIL_BERT_SQUAD: (&'static str, &'static str) = (
"distilbert-qa/vocab.txt",
"https://cdn.huggingface.co/bert-large-cased-vocab.txt",
);
} }
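Each of these constants is a `("cache-relative-name", "url")` pair meant to be handed to `RemoteResource::from_pretrained`; a short usage sketch:

```rust
use rust_bert::distilbert::DistilBertModelResources;
use rust_bert::resources::{download_resource, RemoteResource, Resource};

fn main() -> failure::Fallible<()> {
    // `from_pretrained` turns the (name, url) tuple into a cached RemoteResource.
    let weights = Resource::Remote(RemoteResource::from_pretrained(
        DistilBertModelResources::DISTIL_BERT_SST2,
    ));
    let weights_path = download_resource(&weights)?;
    println!("weights cached at {:?}", weights_path);
    Ok(())
}
```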
#[allow(non_camel_case_types)] #[allow(non_camel_case_types)]
@ -118,10 +145,10 @@ impl DistilBertModel {
/// # Example /// # Example
/// ///
/// ```no_run /// ```no_run
/// use tch::{nn, Device}; /// use rust_bert::distilbert::{DistilBertConfig, DistilBertModel};
/// use rust_bert::Config; /// use rust_bert::Config;
/// use std::path::Path; /// use std::path::Path;
/// use rust_bert::distilbert::{DistilBertConfig, DistilBertModel}; /// use tch::{nn, Device};
/// ///
/// let config_path = Path::new("path/to/config.json"); /// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu; /// let device = Device::Cpu;
@ -129,12 +156,14 @@ impl DistilBertModel {
/// let config = DistilBertConfig::from_file(config_path); /// let config = DistilBertConfig::from_file(config_path);
/// let distil_bert: DistilBertModel = DistilBertModel::new(&(&p.root() / "distilbert"), &config); /// let distil_bert: DistilBertModel = DistilBertModel::new(&(&p.root() / "distilbert"), &config);
/// ``` /// ```
///
pub fn new(p: &nn::Path, config: &DistilBertConfig) -> DistilBertModel { pub fn new(p: &nn::Path, config: &DistilBertConfig) -> DistilBertModel {
let p = &(p / "distilbert"); let p = &(p / "distilbert");
let embeddings = DistilBertEmbedding::new(&(p / "embeddings"), config); let embeddings = DistilBertEmbedding::new(&(p / "embeddings"), config);
let transformer = Transformer::new(&(p / "transformer"), config); let transformer = Transformer::new(&(p / "transformer"), config);
DistilBertModel { embeddings, transformer } DistilBertModel {
embeddings,
transformer,
}
} }
/// Forward pass through the model /// Forward pass through the model
@ -155,45 +184,49 @@ impl DistilBertModel {
/// # Example /// # Example
/// ///
/// ```no_run /// ```no_run
///# use tch::{nn, Device, Tensor, no_grad}; /// # use tch::{nn, Device, Tensor, no_grad};
///# use rust_bert::Config; /// # use rust_bert::Config;
///# use std::path::Path; /// # use std::path::Path;
///# use tch::kind::Kind::Int64; /// # use tch::kind::Kind::Int64;
/// use rust_bert::distilbert::{DistilBertConfig, DistilBertModel}; /// use rust_bert::distilbert::{DistilBertConfig, DistilBertModel};
///# let config_path = Path::new("path/to/config.json"); /// # let config_path = Path::new("path/to/config.json");
///# let vocab_path = Path::new("path/to/vocab.txt"); /// # let vocab_path = Path::new("path/to/vocab.txt");
///# let device = Device::Cpu; /// # let device = Device::Cpu;
///# let vs = nn::VarStore::new(device); /// # let vs = nn::VarStore::new(device);
///# let config = DistilBertConfig::from_file(config_path); /// # let config = DistilBertConfig::from_file(config_path);
///# let distilbert_model: DistilBertModel = DistilBertModel::new(&vs.root(), &config); /// # let distilbert_model: DistilBertModel = DistilBertModel::new(&vs.root(), &config);
/// let (batch_size, sequence_length) = (64, 128); /// let (batch_size, sequence_length) = (64, 128);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device)); /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device)); /// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
///
/// let (output, all_hidden_states, all_attentions) = no_grad(|| {
/// distilbert_model
/// .forward_t(Some(input_tensor),
/// Some(mask),
/// None,
/// false).unwrap()
/// });
/// ///
/// let (output, all_hidden_states, all_attentions) = no_grad(|| {
/// distilbert_model
/// .forward_t(Some(input_tensor), Some(mask), None, false)
/// .unwrap()
/// });
/// ``` /// ```
/// pub fn forward_t(
pub fn forward_t(&self, input: Option<Tensor>, mask: Option<Tensor>, input_embeds: Option<Tensor>, train: bool) &self,
-> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> { input: Option<Tensor>,
mask: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool,
) -> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
let input_embeddings = match input { let input_embeddings = match input {
Some(input_value) => match input_embeds { Some(input_value) => match input_embeds {
Some(_) => { return Err("Only one of input ids or input embeddings may be set"); } Some(_) => {
None => input_value.apply_t(&self.embeddings, train) return Err("Only one of input ids or input embeddings may be set");
} }
None => input_value.apply_t(&self.embeddings, train),
},
None => match input_embeds { None => match input_embeds {
Some(embeds) => embeds, Some(embeds) => embeds,
None => { return Err("At least one of input ids or input embeddings must be set"); } None => {
} return Err("At least one of input ids or input embeddings must be set");
}
},
}; };
let transformer_output = (&self.transformer).forward_t(&input_embeddings, mask, train); let transformer_output = (&self.transformer).forward_t(&input_embeddings, mask, train);
Ok(transformer_output) Ok(transformer_output)
} }
@ -223,28 +256,47 @@ impl DistilBertModelClassifier {
/// # Example /// # Example
/// ///
/// ```no_run /// ```no_run
/// use tch::{nn, Device}; /// use rust_bert::distilbert::{DistilBertConfig, DistilBertModelClassifier};
/// use rust_bert::Config; /// use rust_bert::Config;
/// use std::path::Path; /// use std::path::Path;
/// use rust_bert::distilbert::{DistilBertConfig, DistilBertModelClassifier}; /// use tch::{nn, Device};
/// ///
/// let config_path = Path::new("path/to/config.json"); /// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu; /// let device = Device::Cpu;
/// let p = nn::VarStore::new(device); /// let p = nn::VarStore::new(device);
/// let config = DistilBertConfig::from_file(config_path); /// let config = DistilBertConfig::from_file(config_path);
/// let distil_bert: DistilBertModelClassifier = DistilBertModelClassifier::new(&(&p.root() / "distilbert"), &config); /// let distil_bert: DistilBertModelClassifier =
/// DistilBertModelClassifier::new(&(&p.root() / "distilbert"), &config);
/// ``` /// ```
///
pub fn new(p: &nn::Path, config: &DistilBertConfig) -> DistilBertModelClassifier { pub fn new(p: &nn::Path, config: &DistilBertConfig) -> DistilBertModelClassifier {
let distil_bert_model = DistilBertModel::new(&p, config); let distil_bert_model = DistilBertModel::new(&p, config);
let num_labels = config.id2label.as_ref().expect("id2label must be provided for classifiers").len() as i64; let num_labels = config
.id2label
.as_ref()
.expect("id2label must be provided for classifiers")
.len() as i64;
let pre_classifier = nn::linear(&(p / "pre_classifier"), config.dim, config.dim, Default::default()); let pre_classifier = nn::linear(
let classifier = nn::linear(&(p / "classifier"), config.dim, num_labels, Default::default()); &(p / "pre_classifier"),
config.dim,
config.dim,
Default::default(),
);
let classifier = nn::linear(
&(p / "classifier"),
config.dim,
num_labels,
Default::default(),
);
let dropout = Dropout::new(config.seq_classif_dropout); let dropout = Dropout::new(config.seq_classif_dropout);
DistilBertModelClassifier { distil_bert_model, pre_classifier, classifier, dropout } DistilBertModelClassifier {
distil_bert_model,
pre_classifier,
classifier,
dropout,
}
} }
    /// Forward pass through the model
@@ -265,17 +317,17 @@ impl DistilBertModelClassifier {
    /// # Example
    ///
    /// ```no_run
    /// # use tch::{nn, Device, Tensor, no_grad};
    /// # use rust_bert::Config;
    /// # use std::path::Path;
    /// # use tch::kind::Kind::Int64;
    /// use rust_bert::distilbert::{DistilBertConfig, DistilBertModelClassifier};
    /// # let config_path = Path::new("path/to/config.json");
    /// # let vocab_path = Path::new("path/to/vocab.txt");
    /// # let device = Device::Cpu;
    /// # let vs = nn::VarStore::new(device);
    /// # let config = DistilBertConfig::from_file(config_path);
    /// # let distilbert_model: DistilBertModelClassifier = DistilBertModelClassifier::new(&vs.root(), &config);
    /// let (batch_size, sequence_length) = (64, 128);
    /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
    /// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
@@ -287,15 +339,22 @@ impl DistilBertModelClassifier {
    ///     None,
    ///     false).unwrap()
    /// });
    /// ```
    pub fn forward_t(
        &self,
        input: Option<Tensor>,
        mask: Option<Tensor>,
        input_embeds: Option<Tensor>,
        train: bool,
    ) -> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
        let (output, all_hidden_states, all_attentions) =
            match self
                .distil_bert_model
                .forward_t(input, mask, input_embeds, train)
            {
                Ok(value) => value,
                Err(err) => return Err(err),
            };
        let output = output
            .select(1, 0)
@@ -322,7 +381,6 @@ pub struct DistilBertModelMaskedLM {
    vocab_projector: nn::Linear,
}
impl DistilBertModelMaskedLM {
    /// Build a new `DistilBertModelMaskedLM` for masked language modeling
    ///
@@ -334,10 +392,10 @@ impl DistilBertModelMaskedLM {
    /// # Example
    ///
    /// ```no_run
    /// use rust_bert::distilbert::{DistilBertConfig, DistilBertModelMaskedLM};
    /// use rust_bert::Config;
    /// use std::path::Path;
    /// use tch::{nn, Device};
    ///
    /// let config_path = Path::new("path/to/config.json");
    /// let device = Device::Cpu;
@@ -345,15 +403,33 @@ impl DistilBertModelMaskedLM {
    /// let config = DistilBertConfig::from_file(config_path);
    /// let distil_bert = DistilBertModelMaskedLM::new(&(&p.root() / "distilbert"), &config);
    /// ```
    pub fn new(p: &nn::Path, config: &DistilBertConfig) -> DistilBertModelMaskedLM {
        let distil_bert_model = DistilBertModel::new(&p, config);
        let vocab_transform = nn::linear(
            &(p / "vocab_transform"),
            config.dim,
            config.dim,
            Default::default(),
        );
        let layer_norm_config = nn::LayerNormConfig {
            eps: 1e-12,
            ..Default::default()
        };
        let vocab_layer_norm =
            nn::layer_norm(p / "vocab_layer_norm", vec![config.dim], layer_norm_config);
        let vocab_projector = nn::linear(
            &(p / "vocab_projector"),
            config.dim,
            config.vocab_size,
            Default::default(),
        );
        DistilBertModelMaskedLM {
            distil_bert_model,
            vocab_transform,
            vocab_layer_norm,
            vocab_projector,
        }
    }

    /// Forward pass through the model
@@ -374,37 +450,42 @@ impl DistilBertModelMaskedLM {
    /// # Example
    ///
    /// ```no_run
    /// # use tch::{nn, Device, Tensor, no_grad};
    /// # use rust_bert::Config;
    /// # use std::path::Path;
    /// # use tch::kind::Kind::Int64;
    /// use rust_bert::distilbert::{DistilBertConfig, DistilBertModelMaskedLM};
    /// # let config_path = Path::new("path/to/config.json");
    /// # let vocab_path = Path::new("path/to/vocab.txt");
    /// # let device = Device::Cpu;
    /// # let vs = nn::VarStore::new(device);
    /// # let config = DistilBertConfig::from_file(config_path);
    /// # let distilbert_model = DistilBertModelMaskedLM::new(&vs.root(), &config);
    /// let (batch_size, sequence_length) = (64, 128);
    /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
    /// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
    ///
    /// let (output, all_hidden_states, all_attentions) = no_grad(|| {
    ///     distilbert_model
    ///         .forward_t(Some(input_tensor), Some(mask), None, false)
    ///         .unwrap()
    /// });
    /// ```
    pub fn forward_t(
        &self,
        input: Option<Tensor>,
        mask: Option<Tensor>,
        input_embeds: Option<Tensor>,
        train: bool,
    ) -> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
        let (output, all_hidden_states, all_attentions) =
            match self
                .distil_bert_model
                .forward_t(input, mask, input_embeds, train)
            {
                Ok(value) => value,
                Err(err) => return Err(err),
            };

        let output = output
            .apply(&self.vocab_transform)
@@ -440,10 +521,10 @@ impl DistilBertForQuestionAnswering {
    /// # Example
    ///
    /// ```no_run
    /// use rust_bert::distilbert::{DistilBertConfig, DistilBertForQuestionAnswering};
    /// use rust_bert::Config;
    /// use std::path::Path;
    /// use tch::{nn, Device};
    ///
    /// let config_path = Path::new("path/to/config.json");
    /// let device = Device::Cpu;
@@ -451,13 +532,16 @@ impl DistilBertForQuestionAnswering {
    /// let config = DistilBertConfig::from_file(config_path);
    /// let distil_bert = DistilBertForQuestionAnswering::new(&(&p.root() / "distilbert"), &config);
    /// ```
    pub fn new(p: &nn::Path, config: &DistilBertConfig) -> DistilBertForQuestionAnswering {
        let distil_bert_model = DistilBertModel::new(&p, config);
        let qa_outputs = nn::linear(&(p / "qa_outputs"), config.dim, 2, Default::default());
        let dropout = Dropout::new(config.qa_dropout);
        DistilBertForQuestionAnswering {
            distil_bert_model,
            qa_outputs,
            dropout,
        }
    }

    /// Forward pass through the model
@@ -479,52 +563,50 @@ impl DistilBertForQuestionAnswering {
    /// # Example
    ///
    /// ```no_run
    /// # use tch::{nn, Device, Tensor, no_grad};
    /// # use rust_bert::Config;
    /// # use std::path::Path;
    /// # use tch::kind::Kind::Int64;
    /// use rust_bert::distilbert::{DistilBertConfig, DistilBertForQuestionAnswering};
    /// # let config_path = Path::new("path/to/config.json");
    /// # let vocab_path = Path::new("path/to/vocab.txt");
    /// # let device = Device::Cpu;
    /// # let vs = nn::VarStore::new(device);
    /// # let config = DistilBertConfig::from_file(config_path);
    /// # let distilbert_model = DistilBertForQuestionAnswering::new(&vs.root(), &config);
    /// let (batch_size, sequence_length) = (64, 128);
    /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
    /// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
    ///
    /// let (start_scores, end_scores, all_hidden_states, all_attentions) = no_grad(|| {
    ///     distilbert_model
    ///         .forward_t(Some(input_tensor), Some(mask), None, false)
    ///         .unwrap()
    /// });
    /// ```
    pub fn forward_t(
        &self,
        input: Option<Tensor>,
        mask: Option<Tensor>,
        input_embeds: Option<Tensor>,
        train: bool,
    ) -> Result<(Tensor, Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
        let (output, all_hidden_states, all_attentions) =
            match self
                .distil_bert_model
                .forward_t(input, mask, input_embeds, train)
            {
                Ok(value) => value,
                Err(err) => return Err(err),
            };

        let output = output.apply_t(&self.dropout, train).apply(&self.qa_outputs);

        let logits = output.split(1, -1);
        let (start_logits, end_logits) = (&logits[0], &logits[1]);
        let start_logits = start_logits.squeeze1(-1);
        let end_logits = end_logits.squeeze1(-1);
        Ok((start_logits, end_logits, all_hidden_states, all_attentions))
    }
}
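A minimal decoding sketch for the outputs above (not part of this commit; it assumes the tch API of this era, i.e. `Tensor::argmax(dim, keepdim)` and scalar `i64::from` conversions):

use tch::Tensor;

// Hypothetical helper: returns (start, end) token indices for batch item 0.
// `start_logits` and `end_logits` have shape (batch_size, sequence_length),
// as produced by DistilBertForQuestionAnswering::forward_t.
fn best_span(start_logits: &Tensor, end_logits: &Tensor) -> (i64, i64) {
    let start = i64::from(start_logits.get(0).argmax(0, false));
    let end = i64::from(end_logits.get(0).argmax(0, false));
    (start, end)
}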
@@ -552,10 +634,10 @@ impl DistilBertForTokenClassification {
    /// # Example
    ///
    /// ```no_run
    /// use rust_bert::distilbert::{DistilBertConfig, DistilBertForTokenClassification};
    /// use rust_bert::Config;
    /// use std::path::Path;
    /// use tch::{nn, Device};
    ///
    /// let config_path = Path::new("path/to/config.json");
    /// let device = Device::Cpu;
@@ -563,16 +645,28 @@ impl DistilBertForTokenClassification {
    /// let config = DistilBertConfig::from_file(config_path);
    /// let distil_bert = DistilBertForTokenClassification::new(&(&p.root() / "distilbert"), &config);
    /// ```
    pub fn new(p: &nn::Path, config: &DistilBertConfig) -> DistilBertForTokenClassification {
        let distil_bert_model = DistilBertModel::new(&p, config);
        let num_labels = config
            .id2label
            .as_ref()
            .expect("id2label must be provided for classifiers")
            .len() as i64;
        let classifier = nn::linear(
            &(p / "classifier"),
            config.dim,
            num_labels,
            Default::default(),
        );
        let dropout = Dropout::new(config.seq_classif_dropout);
        DistilBertForTokenClassification {
            distil_bert_model,
            classifier,
            dropout,
        }
    }

    /// Forward pass through the model
@@ -593,41 +687,44 @@ impl DistilBertForTokenClassification {
    /// # Example
    ///
    /// ```no_run
    /// # use tch::{nn, Device, Tensor, no_grad};
    /// # use rust_bert::Config;
    /// # use std::path::Path;
    /// # use tch::kind::Kind::Int64;
    /// use rust_bert::distilbert::{DistilBertConfig, DistilBertForTokenClassification};
    /// # let config_path = Path::new("path/to/config.json");
    /// # let vocab_path = Path::new("path/to/vocab.txt");
    /// # let device = Device::Cpu;
    /// # let vs = nn::VarStore::new(device);
    /// # let config = DistilBertConfig::from_file(config_path);
    /// # let distilbert_model = DistilBertForTokenClassification::new(&vs.root(), &config);
    /// let (batch_size, sequence_length) = (64, 128);
    /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
    /// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
    ///
    /// let (output, all_hidden_states, all_attentions) = no_grad(|| {
    ///     distilbert_model
    ///         .forward_t(Some(input_tensor), Some(mask), None, false)
    ///         .unwrap()
    /// });
    /// ```
    pub fn forward_t(
        &self,
        input: Option<Tensor>,
        mask: Option<Tensor>,
        input_embeds: Option<Tensor>,
        train: bool,
    ) -> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
        let (output, all_hidden_states, all_attentions) =
            match self
                .distil_bert_model
                .forward_t(input, mask, input_embeds, train)
            {
                Ok(value) => value,
                Err(err) => return Err(err),
            };

        let output = output.apply_t(&self.dropout, train).apply(&self.classifier);

        Ok((output, all_hidden_states, all_attentions))
    }
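As a usage sketch (not part of this commit): mapping the per-token logits back to label strings through the same `id2label` map used above to size the classifier. This assumes the map is keyed by `i64`, as the `num_labels` computation suggests.

use std::collections::HashMap;
use tch::Tensor;

// Hypothetical post-processing: argmax over the label dimension of the
// (batch_size, sequence_length, num_labels) logits, then look up each id.
fn decode_labels(logits: &Tensor, id2label: &HashMap<i64, String>) -> Vec<Vec<String>> {
    let ids = logits.argmax(-1, false); // (batch_size, sequence_length)
    let (batch_size, seq_len) = (ids.size()[0], ids.size()[1]);
    let mut output = Vec::with_capacity(batch_size as usize);
    for b in 0..batch_size {
        let mut row = Vec::with_capacity(seq_len as usize);
        for s in 0..seq_len {
            let id = i64::from(ids.get(b).get(s));
            row.push(id2label.get(&id).cloned().unwrap_or_default());
        }
        output.push(row);
    }
    output
}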

@@ -10,22 +10,26 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use crate::common::dropout::Dropout;
use crate::distilbert::distilbert::DistilBertConfig;
use tch::kind::Kind::Float;
use tch::nn::{embedding, EmbeddingConfig, ModuleT};
use tch::{nn, Device, Kind, Tensor};
fn create_sinusoidal_embeddings(config: &DistilBertConfig, device: Device) -> nn::Embedding {
    let mut sinusoidal_embedding: Vec<Tensor> =
        Vec::with_capacity(config.max_position_embeddings as usize);
    for pos in 0..config.max_position_embeddings {
        let mut temp_vec: Vec<f64> = Vec::with_capacity(config.dim as usize);
        for j in 0..config.dim {
            if j % 2 == 0 {
                temp_vec.push(
                    (pos as f64 / 10000f64.powf((2 * (j / 2)) as f64 / config.dim as f64)).sin(),
                );
            } else {
                temp_vec.push(
                    (pos as f64 / 10000f64.powf((2 * (j / 2)) as f64 / config.dim as f64)).cos(),
                );
            }
        }
        let temp_vec = Tensor::of_slice(&temp_vec);
@@ -35,17 +39,21 @@ fn create_sinusoidal_embeddings(config: &DistilBertConfig, device: Device) -> nn
        .to_kind(Float)
        .to_device(device);
    let embedding_config = EmbeddingConfig {
        padding_idx: 0,
        ..Default::default()
    };
    let mut embeddings = embedding(
        &nn::VarStore::new(device).root(),
        config.max_position_embeddings,
        config.dim,
        embedding_config,
    );
    embeddings.ws = sinusoidal_embedding;
    embeddings
}
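The nested loop above is the standard sinusoidal position encoding; writing $d$ for `config.dim`, each position $pos$ receives

PE(pos, 2i) = \sin\left(\frac{pos}{10000^{2i/d}}\right), \qquad PE(pos, 2i+1) = \cos\left(\frac{pos}{10000^{2i/d}}\right)

so adjacent dimensions share a frequency, and the resulting matrix is loaded as fixed weights via `embeddings.ws`.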
#[derive(Debug)]
pub struct DistilBertEmbedding {
    word_embeddings: nn::Embedding,
@@ -56,24 +64,40 @@ pub struct DistilBertEmbedding {
impl DistilBertEmbedding {
    pub fn new(p: &nn::Path, config: &DistilBertConfig) -> DistilBertEmbedding {
        let embedding_config = EmbeddingConfig {
            padding_idx: 0,
            ..Default::default()
        };

        let word_embeddings: nn::Embedding = embedding(
            p / "word_embeddings",
            config.vocab_size,
            config.dim,
            embedding_config,
        );
        let position_embeddings: nn::Embedding = match config.sinusoidal_pos_embds {
            false => embedding(
                p / "position_embeddings",
                config.max_position_embeddings,
                config.dim,
                embedding_config,
            ),
            true => create_sinusoidal_embeddings(&config, p.device()),
        };
        let layer_norm_config = nn::LayerNormConfig {
            eps: 1e-12,
            ..Default::default()
        };
        let layer_norm: nn::LayerNorm =
            nn::layer_norm(p / "LayerNorm", vec![config.dim], layer_norm_config);
        let dropout: Dropout = Dropout::new(config.dropout);
        DistilBertEmbedding {
            word_embeddings,
            position_embeddings,
            layer_norm,
            dropout,
        }
    }

    pub fn _get_word_embeddings(&self) -> &nn::Embedding {
@@ -94,10 +118,12 @@ impl ModuleT for DistilBertEmbedding {
        let word_embed = input.apply(&self.word_embeddings);
        let position_embed = position_ids.apply(&self.position_embeddings);
        // position_embed.get(0).get(0).print();
        let embeddings = word_embed + position_embed;
        let embeddings = embeddings
            .apply(&self.layer_norm)
            .apply_t(&self.dropout, train);
        embeddings
    }
}
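In equation form, the forward pass above computes (a restatement of the code, not an addition to it)

e = \mathrm{Dropout}\big(\mathrm{LayerNorm}(E_{\mathrm{word}}[\mathrm{input}] + E_{\mathrm{pos}}[\mathrm{positions}])\big)

with dropout only active when `train` is true.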

@@ -18,18 +18,27 @@
//! Pretrained models are available and can be downloaded using RemoteResources.
//!
//! ```no_run
//! # fn main() -> failure::Fallible<()> {
//! #
//! use rust_tokenizers::BertTokenizer;
//! use tch::{nn, Device};
//! # use std::path::PathBuf;
//! use rust_bert::distilbert::{
//!     DistilBertConfig, DistilBertConfigResources, DistilBertModelMaskedLM,
//!     DistilBertModelResources, DistilBertVocabResources,
//! };
//! use rust_bert::resources::{download_resource, LocalResource, RemoteResource, Resource};
//! use rust_bert::Config;
//!
//! let config_resource = Resource::Local(LocalResource {
//!     local_path: PathBuf::from("path/to/config.json"),
//! });
//! let vocab_resource = Resource::Local(LocalResource {
//!     local_path: PathBuf::from("path/to/vocab.txt"),
//! });
//! let weights_resource = Resource::Local(LocalResource {
//!     local_path: PathBuf::from("path/to/model.ot"),
//! });
//! let config_path = download_resource(&config_resource)?;
//! let vocab_path = download_resource(&vocab_resource)?;
//! let weights_path = download_resource(&weights_resource)?;
@@ -40,17 +49,17 @@
//! let bert_model = DistilBertModelMaskedLM::new(&vs.root(), &config);
//! vs.load(weights_path)?;
//!
//! # Ok(())
//! # }
//! ```
mod attention;
mod distilbert;
mod embeddings;
mod transformer;

pub use distilbert::{
    Activation, DistilBertConfig, DistilBertConfigResources, DistilBertForQuestionAnswering,
    DistilBertForTokenClassification, DistilBertModel, DistilBertModelClassifier,
    DistilBertModelMaskedLM, DistilBertModelResources, DistilBertVocabResources,
};

@@ -10,13 +10,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use crate::common::activations::{_gelu, _relu};
use crate::common::dropout::Dropout;
use crate::distilbert::attention::MultiHeadSelfAttention;
use crate::distilbert::distilbert::{Activation, DistilBertConfig};
use std::borrow::BorrowMut;
use tch::nn::LayerNorm;
use tch::{nn, Tensor};

pub struct FeedForwardNetwork {
    lin1: nn::Linear,
@@ -27,18 +27,35 @@ pub struct FeedForwardNetwork {
impl FeedForwardNetwork {
    pub fn new(p: nn::Path, config: &DistilBertConfig) -> FeedForwardNetwork {
        let lin1 = nn::linear(
            &p / "lin1",
            config.dim,
            config.hidden_dim,
            Default::default(),
        );
        let lin2 = nn::linear(
            &p / "lin2",
            config.hidden_dim,
            config.dim,
            Default::default(),
        );
        let dropout = Dropout::new(config.dropout);
        let activation = Box::new(match &config.activation {
            Activation::gelu => _gelu,
            Activation::relu => _relu,
        });
        FeedForwardNetwork {
            lin1,
            lin2,
            dropout,
            activation,
        }
    }

    pub fn forward_t(&self, input: &Tensor, train: bool) -> Tensor {
        (self.activation)(&input.apply(&self.lin1))
            .apply(&self.lin2)
            .apply_t(&self.dropout, train)
    }
}
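Restated as a formula, `forward_t` above computes, with $g$ the configured activation (GELU or ReLU),

\mathrm{FFN}(x) = \mathrm{Dropout}\big(W_2 \, g(W_1 x + b_1) + b_2\big)

where $W_1, b_1$ come from `lin1` and $W_2, b_2$ from `lin2`.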
@@ -52,10 +69,15 @@ pub struct TransformerBlock {
impl TransformerBlock {
    pub fn new(p: &nn::Path, config: &DistilBertConfig) -> TransformerBlock {
        let attention = MultiHeadSelfAttention::new(p / "attention", &config);
        let layer_norm_config = nn::LayerNormConfig {
            eps: 1e-12,
            ..Default::default()
        };
        let sa_layer_norm =
            nn::layer_norm(p / "sa_layer_norm", vec![config.dim], layer_norm_config);
        let ffn = FeedForwardNetwork::new(p / "ffn", &config);
        let output_layer_norm =
            nn::layer_norm(p / "output_layer_norm", vec![config.dim], layer_norm_config);

        TransformerBlock {
            attention,
@@ -65,8 +87,15 @@ impl TransformerBlock {
        }
    }

    pub fn forward_t(
        &self,
        input: &Tensor,
        mask: &Option<Tensor>,
        train: bool,
    ) -> (Tensor, Option<Tensor>) {
        let (output, sa_weights) = self
            .attention
            .forward_t(&input, &input, &input, mask, train);
        let output = (input + &output).apply(&self.sa_layer_norm);
        let output = (&output + self.ffn.forward_t(&output, train)).apply(&self.output_layer_norm);
        (output, sa_weights)
@@ -84,25 +113,41 @@ impl Transformer {
        let p = &(p / "layer");
        let output_attentions = match config.output_attentions {
            Some(value) => value,
            None => false,
        };
        let output_hidden_states = match config.output_hidden_states {
            Some(value) => value,
            None => false,
        };

        let mut layers: Vec<TransformerBlock> = vec![];
        for layer_index in 0..config.n_layers {
            layers.push(TransformerBlock::new(&(p / layer_index), config));
        }

        Transformer {
            output_attentions,
            output_hidden_states,
            layers,
        }
    }

    pub fn forward_t(
        &self,
        input: &Tensor,
        mask: Option<Tensor>,
        train: bool,
    ) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
        let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states {
            Some(vec![])
        } else {
            None
        };
        let mut all_attentions: Option<Vec<Tensor>> = if self.output_attentions {
            Some(vec![])
        } else {
            None
        };

        let mut hidden_state = input.copy();
        let mut attention_weights: Option<Tensor>;
@@ -121,10 +166,10 @@ impl Transformer {
                        attentions.push(attention_weights.as_ref().unwrap().copy());
                    };
                }
                None => break,
            };
        }

        (hidden_state, all_hidden_states, all_attentions)
    }
}
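For reference, each block applied in this loop is the post-norm residual pair from `TransformerBlock::forward_t` above,

s = \mathrm{LayerNorm}_{\mathrm{sa}}\big(x + \mathrm{SelfAttn}(x)\big), \qquad y = \mathrm{LayerNorm}_{\mathrm{out}}\big(s + \mathrm{FFN}(s)\big)

repeated `n_layers` times, with hidden states and attention maps collected along the way when the corresponding config flags are set.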

@@ -12,15 +12,15 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use crate::bert::encoder::BertEncoder;
use crate::bert::{Activation, BertConfig};
use crate::common::activations::{_gelu, _mish, _relu};
use crate::common::dropout::Dropout;
use crate::electra::embeddings::ElectraEmbeddings;
use crate::Config;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use tch::{nn, Kind, Tensor};
/// # Electra Pretrained model weight files
pub struct ElectraModelResources;
@@ -33,23 +33,41 @@ pub struct ElectraVocabResources;
impl ElectraModelResources {
    /// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/electra. Modified with conversion to C-array format.
    pub const BASE_GENERATOR: (&'static str, &'static str) = (
        "electra-base-generator/model.ot",
        "https://cdn.huggingface.co/google/electra-base-generator/rust_model.ot",
    );
    /// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/electra. Modified with conversion to C-array format.
    pub const BASE_DISCRIMINATOR: (&'static str, &'static str) = (
        "electra-base-discriminator/model.ot",
        "https://cdn.huggingface.co/google/electra-base-discriminator/rust_model.ot",
    );
}

impl ElectraConfigResources {
    /// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/electra. Modified with conversion to C-array format.
    pub const BASE_GENERATOR: (&'static str, &'static str) = (
        "electra-base-generator/config.json",
        "https://cdn.huggingface.co/google/electra-base-generator/config.json",
    );
    /// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/electra. Modified with conversion to C-array format.
    pub const BASE_DISCRIMINATOR: (&'static str, &'static str) = (
        "electra-base-discriminator/config.json",
        "https://cdn.huggingface.co/google/electra-base-discriminator/config.json",
    );
}

impl ElectraVocabResources {
    /// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/electra. Modified with conversion to C-array format.
    pub const BASE_GENERATOR: (&'static str, &'static str) = (
        "electra-base-generator/vocab.txt",
        "https://cdn.huggingface.co/google/electra-base-generator/vocab.txt",
    );
    /// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/electra. Modified with conversion to C-array format.
    pub const BASE_DISCRIMINATOR: (&'static str, &'static str) = (
        "electra-base-discriminator/vocab.txt",
        "https://cdn.huggingface.co/google/electra-base-discriminator/vocab.txt",
    );
}
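A small usage sketch for these constants (not part of this diff; it assumes `download_resource` returns the cached `PathBuf`, as the DistilBert module-level example suggests):

use rust_bert::electra::ElectraModelResources;
use rust_bert::resources::{download_resource, RemoteResource, Resource};
use std::path::PathBuf;

// Hypothetical helper: resolve the base generator weights to a local file,
// downloading and caching them on first use. The tuple constant supplies the
// (cache name, remote URL) pair expected by RemoteResource::from_pretrained.
fn generator_weights() -> failure::Fallible<PathBuf> {
    let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
        ElectraModelResources::BASE_GENERATOR,
    ));
    download_resource(&weights_resource)
}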
#[derive(Debug, Serialize, Deserialize)]
@@ -103,10 +121,10 @@ impl ElectraModel {
    /// # Example
    ///
    /// ```no_run
    /// use rust_bert::electra::{ElectraConfig, ElectraModel};
    /// use rust_bert::Config;
    /// use std::path::Path;
    /// use tch::{nn, Device};
    ///
    /// let config_path = Path::new("path/to/config.json");
    /// let device = Device::Cpu;
@@ -114,11 +132,15 @@ impl ElectraModel {
    /// let config = ElectraConfig::from_file(config_path);
    /// let electra_model: ElectraModel = ElectraModel::new(&(&p.root() / "electra"), &config);
    /// ```
    pub fn new(p: &nn::Path, config: &ElectraConfig) -> ElectraModel {
        let embeddings = ElectraEmbeddings::new(&(p / "embeddings"), config);
        let embeddings_project = if config.embedding_size != config.hidden_size {
            Some(nn::linear(
                &(p / "embeddings_project"),
                config.embedding_size,
                config.hidden_size,
                Default::default(),
            ))
        } else {
            None
        };
@@ -141,7 +163,11 @@ impl ElectraModel {
            label2id: config.label2id.clone(),
        };
        let encoder = BertEncoder::new(&(p / "encoder"), &bert_config);
        ElectraModel {
            embeddings,
            embeddings_project,
            encoder,
        }
    }
    /// Forward pass through the model
@@ -164,80 +190,98 @@ impl ElectraModel {
    /// # Example
    ///
    /// ```no_run
    /// # use rust_bert::electra::{ElectraModel, ElectraConfig};
    /// # use tch::{nn, Device, Tensor, no_grad};
    /// # use rust_bert::Config;
    /// # use std::path::Path;
    /// # use tch::kind::Kind::Int64;
    /// # let config_path = Path::new("path/to/config.json");
    /// # let device = Device::Cpu;
    /// # let vs = nn::VarStore::new(device);
    /// # let config = ElectraConfig::from_file(config_path);
    /// # let electra_model: ElectraModel = ElectraModel::new(&vs.root(), &config);
    /// let (batch_size, sequence_length) = (64, 128);
    /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
    /// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
    /// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
    /// let position_ids = Tensor::arange(sequence_length, (Int64, device))
    ///     .expand(&[batch_size, sequence_length], true);
    ///
    /// let (output, all_hidden_states, all_attentions) = no_grad(|| {
    ///     electra_model
    ///         .forward_t(
    ///             Some(input_tensor),
    ///             Some(mask),
    ///             Some(token_type_ids),
    ///             Some(position_ids),
    ///             None,
    ///             false,
    ///         )
    ///         .unwrap()
    /// });
    /// ```
    pub fn forward_t(
        &self,
        input_ids: Option<Tensor>,
        mask: Option<Tensor>,
        token_type_ids: Option<Tensor>,
        position_ids: Option<Tensor>,
        input_embeds: Option<Tensor>,
        train: bool,
    ) -> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
        let (input_shape, device) = match &input_ids {
            Some(input_value) => match &input_embeds {
                Some(_) => {
                    return Err("Only one of input ids or input embeddings may be set");
                }
                None => (input_value.size(), input_value.device()),
            },
            None => match &input_embeds {
                Some(embeds) => (vec![embeds.size()[0], embeds.size()[1]], embeds.device()),
                None => {
                    return Err("At least one of input ids or input embeddings must be set");
                }
            },
        };

        let mask = match mask {
            Some(value) => value,
            None => Tensor::ones(&input_shape, (Kind::Int64, device)),
        };

        let extended_attention_mask = match mask.dim() {
            3 => mask.unsqueeze(1),
            2 => mask.unsqueeze(1).unsqueeze(1),
            _ => {
                return Err("Invalid attention mask dimension, must be 2 or 3");
            }
        };

        let hidden_states = match self.embeddings.forward_t(
            input_ids,
            token_type_ids,
            position_ids,
            input_embeds,
            train,
        ) {
            Ok(value) => value,
            Err(e) => {
                return Err(e);
            }
        };

        let hidden_states = match &self.embeddings_project {
            Some(layer) => hidden_states.apply(layer),
            None => hidden_states,
        };

        let (hidden_state, all_hidden_states, all_attentions) = self.encoder.forward_t(
            &hidden_states,
            &Some(extended_attention_mask),
            &None,
            &None,
            train,
        );

        Ok((hidden_state, all_hidden_states, all_attentions))
    }
@@ -268,9 +312,9 @@ impl ElectraDiscriminatorHead {
    ///
    /// ```no_run
    /// use rust_bert::electra::{ElectraConfig, ElectraDiscriminatorHead};
    /// use rust_bert::Config;
    /// use std::path::Path;
    /// use tch::{nn, Device};
    ///
    /// let config_path = Path::new("path/to/config.json");
    /// let device = Device::Cpu;
@@ -278,16 +322,29 @@ impl ElectraDiscriminatorHead {
    /// let config = ElectraConfig::from_file(config_path);
    /// let discriminator_head = ElectraDiscriminatorHead::new(&(&p.root() / "electra"), &config);
    /// ```
    pub fn new(p: &nn::Path, config: &ElectraConfig) -> ElectraDiscriminatorHead {
        let dense = nn::linear(
            &(p / "dense"),
            config.hidden_size,
            config.hidden_size,
            Default::default(),
        );
        let dense_prediction = nn::linear(
            &(p / "dense_prediction"),
            config.hidden_size,
            1,
            Default::default(),
        );
        let activation = Box::new(match &config.hidden_act {
            Activation::gelu => _gelu,
            Activation::relu => _relu,
            Activation::mish => _mish,
        });
        ElectraDiscriminatorHead {
            dense,
            dense_prediction,
            activation,
        }
    }

    /// Forward pass through the discriminator head
@@ -305,25 +362,24 @@ impl ElectraDiscriminatorHead {
    /// # Example
    ///
    /// ```no_run
    /// # use rust_bert::electra::{ElectraConfig, ElectraDiscriminatorHead};
    /// # use tch::{nn, Device, Tensor, no_grad};
    /// # use rust_bert::Config;
    /// # use std::path::Path;
    /// # use tch::kind::Kind::Float;
    /// # let config_path = Path::new("path/to/config.json");
    /// # let device = Device::Cpu;
    /// # let vs = nn::VarStore::new(device);
    /// # let config = ElectraConfig::from_file(config_path);
    /// # let discriminator_head = ElectraDiscriminatorHead::new(&vs.root(), &config);
    /// let (batch_size, sequence_length) = (64, 128);
    /// let input_tensor = Tensor::rand(
    ///     &[batch_size, sequence_length, config.hidden_size],
    ///     (Float, device),
    /// );
    ///
    /// let output = no_grad(|| discriminator_head.forward(&input_tensor));
    /// ```
    pub fn forward(&self, encoder_hidden_states: &Tensor) -> Tensor {
        let output = encoder_hidden_states.apply(&self.dense);
        let output = (self.activation)(&output);
@@ -356,9 +412,9 @@ impl ElectraGeneratorHead {
    ///
    /// ```no_run
    /// use rust_bert::electra::{ElectraConfig, ElectraGeneratorHead};
    /// use rust_bert::Config;
    /// use std::path::Path;
    /// use tch::{nn, Device};
    ///
    /// let config_path = Path::new("path/to/config.json");
    /// let device = Device::Cpu;
@@ -366,13 +422,25 @@ impl ElectraGeneratorHead {
    /// let config = ElectraConfig::from_file(config_path);
    /// let generator_head = ElectraGeneratorHead::new(&(&p.root() / "electra"), &config);
    /// ```
    pub fn new(p: &nn::Path, config: &ElectraConfig) -> ElectraGeneratorHead {
        let layer_norm = nn::layer_norm(
            p / "LayerNorm",
            vec![config.embedding_size],
            Default::default(),
        );
        let dense = nn::linear(
            &(p / "dense"),
            config.hidden_size,
            config.embedding_size,
            Default::default(),
        );
        let activation = Box::new(_gelu);

        ElectraGeneratorHead {
            layer_norm,
            dense,
            activation,
        }
    }

    /// Forward pass through the generator head
@@ -388,25 +456,24 @@ impl ElectraGeneratorHead {
    /// # Example
    ///
    /// ```no_run
    /// # use rust_bert::electra::{ElectraConfig, ElectraGeneratorHead};
    /// # use tch::{nn, Device, Tensor, no_grad};
    /// # use rust_bert::Config;
    /// # use std::path::Path;
    /// # use tch::kind::Kind::Float;
    /// # let config_path = Path::new("path/to/config.json");
    /// # let device = Device::Cpu;
    /// # let vs = nn::VarStore::new(device);
    /// # let config = ElectraConfig::from_file(config_path);
    /// # let generator_head = ElectraGeneratorHead::new(&vs.root(), &config);
    /// let (batch_size, sequence_length) = (64, 128);
    /// let input_tensor = Tensor::rand(
    ///     &[batch_size, sequence_length, config.hidden_size],
    ///     (Float, device),
    /// );
    ///
    /// let output = no_grad(|| generator_head.forward(&input_tensor));
    /// ```
    pub fn forward(&self, encoder_hidden_states: &Tensor) -> Tensor {
        let output = encoder_hidden_states.apply(&self.dense);
        let output = (self.activation)(&output);
@@ -438,10 +505,10 @@ impl ElectraForMaskedLM {
    /// # Example
    ///
    /// ```no_run
    /// use rust_bert::electra::{ElectraConfig, ElectraForMaskedLM};
    /// use rust_bert::Config;
    /// use std::path::Path;
    /// use tch::{nn, Device};
    ///
    /// let config_path = Path::new("path/to/config.json");
    /// let device = Device::Cpu;
@@ -449,13 +516,21 @@ impl ElectraForMaskedLM {
    /// let config = ElectraConfig::from_file(config_path);
    /// let electra_model: ElectraForMaskedLM = ElectraForMaskedLM::new(&p.root(), &config);
    /// ```
    pub fn new(p: &nn::Path, config: &ElectraConfig) -> ElectraForMaskedLM {
        let electra = ElectraModel::new(&(p / "electra"), config);
        let generator_head = ElectraGeneratorHead::new(&(p / "generator_predictions"), config);
        let lm_head = nn::linear(
            &(p / "generator_lm_head"),
            config.embedding_size,
            config.vocab_size,
            Default::default(),
        );
        ElectraForMaskedLM {
            electra,
            generator_head,
            lm_head,
        }
    }

    /// Forward pass through the model
@@ -478,46 +553,53 @@ impl ElectraForMaskedLM {
    /// # Example
    ///
    /// ```no_run
    /// # use rust_bert::electra::{ElectraForMaskedLM, ElectraConfig};
    /// # use tch::{nn, Device, Tensor, no_grad};
    /// # use rust_bert::Config;
    /// # use std::path::Path;
    /// # use tch::kind::Kind::Int64;
    /// # let config_path = Path::new("path/to/config.json");
    /// # let device = Device::Cpu;
    /// # let vs = nn::VarStore::new(device);
    /// # let config = ElectraConfig::from_file(config_path);
    /// # let electra_model: ElectraForMaskedLM = ElectraForMaskedLM::new(&vs.root(), &config);
    /// let (batch_size, sequence_length) = (64, 128);
    /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
    /// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
    /// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
    /// let position_ids = Tensor::arange(sequence_length, (Int64, device))
    ///     .expand(&[batch_size, sequence_length], true);
    ///
    /// let (output, all_hidden_states, all_attentions) = no_grad(|| {
    ///     electra_model.forward_t(
    ///         Some(input_tensor),
    ///         Some(mask),
    ///         Some(token_type_ids),
    ///         Some(position_ids),
    ///         None,
    ///         false,
    ///     )
    /// });
    /// ```
    pub fn forward_t(
        &self,
        input_ids: Option<Tensor>,
        mask: Option<Tensor>,
        token_type_ids: Option<Tensor>,
        position_ids: Option<Tensor>,
        input_embeds: Option<Tensor>,
        train: bool,
    ) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
        let (hidden_states, all_hidden_states, all_attentions) = self
            .electra
            .forward_t(
                input_ids,
                mask,
                token_type_ids,
                position_ids,
                input_embeds,
                train,
            )
            .unwrap();
        let hidden_states = self.generator_head.forward(&hidden_states);
        let hidden_states = hidden_states.apply(&self.lm_head);
@@ -547,10 +629,10 @@ impl ElectraDiscriminator {
    /// # Example
    ///
    /// ```no_run
    /// use rust_bert::electra::{ElectraConfig, ElectraDiscriminator};
    /// use rust_bert::Config;
    /// use std::path::Path;
    /// use tch::{nn, Device};
    ///
    /// let config_path = Path::new("path/to/config.json");
    /// let device = Device::Cpu;
@@ -558,12 +640,15 @@ impl ElectraDiscriminator {
    /// let config = ElectraConfig::from_file(config_path);
    /// let electra_model: ElectraDiscriminator = ElectraDiscriminator::new(&p.root(), &config);
    /// ```
    pub fn new(p: &nn::Path, config: &ElectraConfig) -> ElectraDiscriminator {
        let electra = ElectraModel::new(&(p / "electra"), config);
        let discriminator_head =
            ElectraDiscriminatorHead::new(&(p / "discriminator_predictions"), config);

        ElectraDiscriminator {
            electra,
            discriminator_head,
        }
    }

    /// Forward pass through the model
@ -586,16 +671,16 @@ impl ElectraDiscriminator {
/// # Example /// # Example
/// ///
/// ```no_run /// ```no_run
///# use rust_bert::electra::{ElectraDiscriminator, ElectraConfig}; /// # use rust_bert::electra::{ElectraDiscriminator, ElectraConfig};
///# use tch::{nn, Device, Tensor, no_grad}; /// # use tch::{nn, Device, Tensor, no_grad};
///# use rust_bert::Config; /// # use rust_bert::Config;
///# use std::path::Path; /// # use std::path::Path;
///# use tch::kind::Kind::Int64; /// # use tch::kind::Kind::Int64;
///# let config_path = Path::new("path/to/config.json"); /// # let config_path = Path::new("path/to/config.json");
///# let device = Device::Cpu; /// # let device = Device::Cpu;
///# let vs = nn::VarStore::new(device); /// # let vs = nn::VarStore::new(device);
///# let config = ElectraConfig::from_file(config_path); /// # let config = ElectraConfig::from_file(config_path);
///# let electra_model: ElectraDiscriminator = ElectraDiscriminator::new(&vs.root(), &config); /// # let electra_model: ElectraDiscriminator = ElectraDiscriminator::new(&vs.root(), &config);
/// let (batch_size, sequence_length) = (64, 128); /// let (batch_size, sequence_length) = (64, 128);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device)); /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device)); /// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
@ -611,21 +696,26 @@ impl ElectraDiscriminator {
/// None, /// None,
/// false) /// false)
/// }); /// });
///
/// ``` /// ```
/// pub fn forward_t(
pub fn forward_t(&self, &self,
input_ids: Option<Tensor>, input_ids: Option<Tensor>,
mask: Option<Tensor>, mask: Option<Tensor>,
token_type_ids: Option<Tensor>, token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>, position_ids: Option<Tensor>,
input_embeds: Option<Tensor>, input_embeds: Option<Tensor>,
train: bool) train: bool,
-> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) { ) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
let (hidden_states, let (hidden_states, all_hidden_states, all_attentions) = self
all_hidden_states, .electra
all_attentions) = self.electra .forward_t(
.forward_t(input_ids, mask, token_type_ids, position_ids, input_embeds, train) input_ids,
mask,
token_type_ids,
position_ids,
input_embeds,
train,
)
.unwrap(); .unwrap();
let probabilities = self.discriminator_head.forward(&hidden_states).sigmoid(); let probabilities = self.discriminator_head.forward(&hidden_states).sigmoid();
(probabilities, all_hidden_states, all_attentions) (probabilities, all_hidden_states, all_attentions)
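Because `forward_t` applies the sigmoid itself, the first element of the returned tuple is already a per-token probability that the token was replaced. A hedged sketch of flagging suspected replacements (the 0.5 threshold is an assumption for illustration):

    // `probabilities` has shape (batch_size, sequence_length); values near 1.0
    // mean the discriminator believes the generator substituted that token.
    let flagged_tokens = probabilities.gt(0.5);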
@ -656,24 +746,37 @@ impl ElectraForTokenClassification {
/// # Example /// # Example
/// ///
/// ```no_run /// ```no_run
/// use rust_bert::electra::{ElectraForTokenClassification, ElectraConfig}; /// use rust_bert::electra::{ElectraConfig, ElectraForTokenClassification};
/// use tch::{nn, Device};
/// use rust_bert::Config; /// use rust_bert::Config;
/// use std::path::Path; /// use std::path::Path;
/// use tch::{nn, Device};
/// let config_path = Path::new("path/to/config.json"); /// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu; /// let device = Device::Cpu;
/// let p = nn::VarStore::new(device); /// let p = nn::VarStore::new(device);
/// let config = ElectraConfig::from_file(config_path); /// let config = ElectraConfig::from_file(config_path);
/// let electra_model: ElectraForTokenClassification = ElectraForTokenClassification::new(&p.root(), &config); /// let electra_model: ElectraForTokenClassification =
/// ElectraForTokenClassification::new(&p.root(), &config);
/// ``` /// ```
///
pub fn new(p: &nn::Path, config: &ElectraConfig) -> ElectraForTokenClassification { pub fn new(p: &nn::Path, config: &ElectraConfig) -> ElectraForTokenClassification {
let electra = ElectraModel::new(&(p / "electra"), config); let electra = ElectraModel::new(&(p / "electra"), config);
let dropout = Dropout::new(config.hidden_dropout_prob); let dropout = Dropout::new(config.hidden_dropout_prob);
let num_labels = config.id2label.as_ref().expect("id2label must be provided for classifiers").len() as i64; let num_labels = config
let classifier = nn::linear(&(p / "classifier"), config.hidden_size, num_labels, Default::default()); .id2label
.as_ref()
.expect("id2label must be provided for classifiers")
.len() as i64;
let classifier = nn::linear(
&(p / "classifier"),
config.hidden_size,
num_labels,
Default::default(),
);
ElectraForTokenClassification { electra, dropout, classifier } ElectraForTokenClassification {
electra,
dropout,
classifier,
}
} }
/// Forward pass through the model /// Forward pass through the model
@ -696,16 +799,16 @@ impl ElectraForTokenClassification {
/// # Example /// # Example
/// ///
/// ```no_run /// ```no_run
///# use rust_bert::electra::{ElectraForTokenClassification, ElectraConfig}; /// # use rust_bert::electra::{ElectraForTokenClassification, ElectraConfig};
///# use tch::{nn, Device, Tensor, no_grad}; /// # use tch::{nn, Device, Tensor, no_grad};
///# use rust_bert::Config; /// # use rust_bert::Config;
///# use std::path::Path; /// # use std::path::Path;
///# use tch::kind::Kind::Int64; /// # use tch::kind::Kind::Int64;
///# let config_path = Path::new("path/to/config.json"); /// # let config_path = Path::new("path/to/config.json");
///# let device = Device::Cpu; /// # let device = Device::Cpu;
///# let vs = nn::VarStore::new(device); /// # let vs = nn::VarStore::new(device);
///# let config = ElectraConfig::from_file(config_path); /// # let config = ElectraConfig::from_file(config_path);
///# let electra_model: ElectraForTokenClassification = ElectraForTokenClassification::new(&vs.root(), &config); /// # let electra_model: ElectraForTokenClassification = ElectraForTokenClassification::new(&vs.root(), &config);
/// let (batch_size, sequence_length) = (64, 128); /// let (batch_size, sequence_length) = (64, 128);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device)); /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device)); /// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
@ -721,21 +824,26 @@ impl ElectraForTokenClassification {
/// None, /// None,
/// false) /// false)
/// }); /// });
///
/// ``` /// ```
/// pub fn forward_t(
pub fn forward_t(&self, &self,
input_ids: Option<Tensor>, input_ids: Option<Tensor>,
mask: Option<Tensor>, mask: Option<Tensor>,
token_type_ids: Option<Tensor>, token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>, position_ids: Option<Tensor>,
input_embeds: Option<Tensor>, input_embeds: Option<Tensor>,
train: bool) train: bool,
-> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) { ) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
let (hidden_states, let (hidden_states, all_hidden_states, all_attentions) = self
all_hidden_states, .electra
all_attentions) = self.electra .forward_t(
.forward_t(input_ids, mask, token_type_ids, position_ids, input_embeds, train) input_ids,
mask,
token_type_ids,
position_ids,
input_embeds,
train,
)
.unwrap(); .unwrap();
let output = hidden_states let output = hidden_states
.apply_t(&self.dropout, train) .apply_t(&self.dropout, train)
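The dropout output above is then projected by `classifier`, yielding logits of shape (batch_size, sequence_length, num_labels). A short sketch of recovering per-token label indices (mapping indices back to names via the config's `id2label`, the field consulted in `new` above, is assumed):

    // `token_classification_logits` is a placeholder for the first element of
    // forward_t's output; one predicted label index per token position.
    let label_ids = token_classification_logits.argmax(-1, false);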
View File
@ -12,10 +12,10 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use tch::{nn, Tensor, Kind};
use crate::common::dropout::Dropout; use crate::common::dropout::Dropout;
use crate::electra::electra::ElectraConfig; use crate::electra::electra::ElectraConfig;
use tch::nn::{EmbeddingConfig, embedding}; use tch::nn::{embedding, EmbeddingConfig};
use tch::{nn, Kind, Tensor};
#[derive(Debug)] #[derive(Debug)]
/// # Embeddings implementation for Electra model /// # Embeddings implementation for Electra model
@ -34,49 +34,77 @@ impl ElectraEmbeddings {
..Default::default() ..Default::default()
}; };
let word_embeddings: nn::Embedding = embedding(p / "word_embeddings", let word_embeddings: nn::Embedding = embedding(
config.vocab_size, p / "word_embeddings",
config.embedding_size, config.vocab_size,
embedding_config); config.embedding_size,
embedding_config,
);
let position_embeddings: nn::Embedding = embedding(p / "position_embeddings", let position_embeddings: nn::Embedding = embedding(
config.max_position_embeddings, p / "position_embeddings",
config.embedding_size, config.max_position_embeddings,
Default::default()); config.embedding_size,
Default::default(),
);
let token_type_embeddings: nn::Embedding = embedding(p / "token_type_embeddings", let token_type_embeddings: nn::Embedding = embedding(
config.type_vocab_size, p / "token_type_embeddings",
config.embedding_size, config.type_vocab_size,
Default::default()); config.embedding_size,
Default::default(),
);
let layer_norm_eps = match config.layer_norm_eps { let layer_norm_eps = match config.layer_norm_eps {
Some(value) => value, Some(value) => value,
None => 1e-12 None => 1e-12,
}; };
let layer_norm_config = nn::LayerNormConfig { eps: layer_norm_eps, ..Default::default() }; let layer_norm_config = nn::LayerNormConfig {
let layer_norm: nn::LayerNorm = nn::layer_norm(p / "LayerNorm", vec![config.embedding_size], layer_norm_config); eps: layer_norm_eps,
..Default::default()
};
let layer_norm: nn::LayerNorm = nn::layer_norm(
p / "LayerNorm",
vec![config.embedding_size],
layer_norm_config,
);
let dropout: Dropout = Dropout::new(config.hidden_dropout_prob); let dropout: Dropout = Dropout::new(config.hidden_dropout_prob);
ElectraEmbeddings { word_embeddings, position_embeddings, token_type_embeddings, layer_norm, dropout} ElectraEmbeddings {
word_embeddings,
position_embeddings,
token_type_embeddings,
layer_norm,
dropout,
}
} }
pub fn forward_t(&self, pub fn forward_t(
input_ids: Option<Tensor>, &self,
token_type_ids: Option<Tensor>, input_ids: Option<Tensor>,
position_ids: Option<Tensor>, token_type_ids: Option<Tensor>,
input_embeds: Option<Tensor>, position_ids: Option<Tensor>,
train: bool) -> Result<Tensor, &'static str> { input_embeds: Option<Tensor>,
train: bool,
) -> Result<Tensor, &'static str> {
let (input_embeddings, input_shape) = match input_ids { let (input_embeddings, input_shape) = match input_ids {
Some(input_value) => match input_embeds { Some(input_value) => match input_embeds {
Some(_) => { return Err("Only one of input ids or input embeddings may be set"); } Some(_) => {
None => (input_value.apply_t(&self.word_embeddings, train), input_value.size()) return Err("Only one of input ids or input embeddings may be set");
} }
None => (
input_value.apply_t(&self.word_embeddings, train),
input_value.size(),
),
},
None => match input_embeds { None => match input_embeds {
Some(embeds) => { Some(embeds) => {
let size = vec!(embeds.size()[0], embeds.size()[1]); let size = vec![embeds.size()[0], embeds.size()[1]];
(embeds, size) (embeds, size)
}, }
None => { return Err("Only one of input ids or input embeddings may be set"); } None => {
} return Err("Only one of input ids or input embeddings may be set");
}
},
}; };
let seq_length = input_embeddings.as_ref().size()[1].to_owned(); let seq_length = input_embeddings.as_ref().size()[1].to_owned();
@ -84,19 +112,22 @@ impl ElectraEmbeddings {
let position_ids = match position_ids { let position_ids = match position_ids {
Some(value) => value, Some(value) => value,
None => Tensor::arange(seq_length, (Kind::Int64, input_embeddings.device())) None => Tensor::arange(seq_length, (Kind::Int64, input_embeddings.device()))
.unsqueeze(0). .unsqueeze(0)
expand(&input_shape, true) .expand(&input_shape, true),
}; };
let token_type_ids = match token_type_ids { let token_type_ids = match token_type_ids {
Some(value) => value, Some(value) => value,
None => Tensor::zeros(&input_shape, (Kind::Int64, input_embeddings.device())) None => Tensor::zeros(&input_shape, (Kind::Int64, input_embeddings.device())),
}; };
let position_embeddings = position_ids.apply(&self.position_embeddings); let position_embeddings = position_ids.apply(&self.position_embeddings);
let token_type_embeddings = token_type_ids.apply(&self.token_type_embeddings); let token_type_embeddings = token_type_ids.apply(&self.token_type_embeddings);
let input_embeddings: Tensor = input_embeddings + position_embeddings + token_type_embeddings; let input_embeddings: Tensor =
Ok(input_embeddings.apply(&self.layer_norm).apply_t(&self.dropout, train)) input_embeddings + position_embeddings + token_type_embeddings;
Ok(input_embeddings
.apply(&self.layer_norm)
.apply_t(&self.dropout, train))
} }
} }
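The match above encodes the input contract: exactly one of `input_ids` or `input_embeds` may be provided, and both-set or neither-set returns an `Err`. A compact sketch of the contract (variable names are placeholders):

    let ok = embeddings.forward_t(Some(ids), None, None, None, false); // Ok(embedded tensor)
    let both = embeddings.forward_t(Some(ids), None, None, Some(precomputed), false); // Err(...)
    let neither = embeddings.forward_t(None, None, None, None, false); // Err(...)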
View File
@ -23,18 +23,24 @@
//! Pretrained models are available and can be downloaded using RemoteResources. //! Pretrained models are available and can be downloaded using RemoteResources.
//! //!
//! ```no_run //! ```no_run
//!# fn main() -> failure::Fallible<()> { //! # fn main() -> failure::Fallible<()> {
//!# //! #
//! use rust_tokenizers::BertTokenizer; //! use rust_tokenizers::BertTokenizer;
//! use tch::{nn, Device}; //! use tch::{nn, Device};
//!# use std::path::PathBuf; //! # use std::path::PathBuf;
//! use rust_bert::electra::{ElectraForMaskedLM, ElectraConfig}; //! use rust_bert::electra::{ElectraConfig, ElectraForMaskedLM};
//! use rust_bert::resources::{download_resource, LocalResource, Resource};
//! use rust_bert::Config; //! use rust_bert::Config;
//! use rust_bert::resources::{Resource, download_resource, LocalResource};
//! //!
//! let config_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/config.json")}); //! let config_resource = Resource::Local(LocalResource {
//! let vocab_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/vocab.txt")}); //! local_path: PathBuf::from("path/to/config.json"),
//! let weights_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/model.ot")}); //! });
//! let vocab_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/vocab.txt"),
//! });
//! let weights_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/model.ot"),
//! });
//! let config_path = download_resource(&config_resource)?; //! let config_path = download_resource(&config_resource)?;
//! let vocab_path = download_resource(&vocab_resource)?; //! let vocab_path = download_resource(&vocab_resource)?;
//! let weights_path = download_resource(&weights_resource)?; //! let weights_path = download_resource(&weights_resource)?;
@ -45,13 +51,15 @@
//! let electra_model = ElectraForMaskedLM::new(&vs.root(), &config); //! let electra_model = ElectraForMaskedLM::new(&vs.root(), &config);
//! vs.load(weights_path)?; //! vs.load(weights_path)?;
//! //!
//!# Ok(()) //! # Ok(())
//!# } //! # }
//! ``` //! ```
mod embeddings;
mod electra; mod electra;
mod embeddings;
pub use electra::{ElectraModelResources, ElectraVocabResources, ElectraConfigResources, ElectraConfig, pub use electra::{
ElectraModel, ElectraDiscriminator, ElectraForMaskedLM, ElectraDiscriminatorHead, ElectraGeneratorHead, ElectraForTokenClassification}; ElectraConfig, ElectraConfigResources, ElectraDiscriminator, ElectraDiscriminatorHead,
ElectraForMaskedLM, ElectraForTokenClassification, ElectraGeneratorHead, ElectraModel,
ElectraModelResources, ElectraVocabResources,
};
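With the re-exports grouped alphabetically as above, downstream code pulls everything from the module path directly, e.g.:

    use rust_bert::electra::{ElectraConfig, ElectraDiscriminator, ElectraForMaskedLM};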
View File
@ -12,11 +12,11 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use tch::{Tensor, nn};
use crate::common::dropout::Dropout; use crate::common::dropout::Dropout;
use crate::gpt2::gpt2::Gpt2Config; use crate::gpt2::gpt2::Gpt2Config;
use tch::kind::Kind::Float; use tch::kind::Kind::Float;
use tch::nn::{Init, Module}; use tch::nn::{Init, Module};
use tch::{nn, Tensor};
#[derive(Debug)] #[derive(Debug)]
pub struct GPTConv1D { pub struct GPTConv1D {
@ -26,7 +26,14 @@ pub struct GPTConv1D {
impl GPTConv1D { impl GPTConv1D {
pub fn new(p: &nn::Path, nf: i64, nx: i64) -> GPTConv1D { pub fn new(p: &nn::Path, nf: i64, nx: i64) -> GPTConv1D {
let weight = p.var("weight", &[nx, nf], Init::Randn { mean: 0., stdev: 0.02 }); let weight = p.var(
"weight",
&[nx, nf],
Init::Randn {
mean: 0.,
stdev: 0.02,
},
);
let bias = p.var("bias", &[nf], Init::Const(0.)); let bias = p.var("bias", &[nf], Init::Const(0.));
GPTConv1D { weight, bias } GPTConv1D { weight, bias }
} }
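The `Module` implementation whose body falls outside this hunk reduces to a linear map with the weight stored transposed (hence the `[nx, nf]` shape initialized above). A sketch of the computation it performs:

    // y = x . W + b, with W shaped [nx, nf] rather than nn::Linear's [nf, nx].
    fn conv1d_sketch(x: &tch::Tensor, weight: &tch::Tensor, bias: &tch::Tensor) -> tch::Tensor {
        x.matmul(weight) + bias
    }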
@ -38,7 +45,6 @@ impl Module for GPTConv1D {
} }
} }
pub struct Attention { pub struct Attention {
bias: Tensor, bias: Tensor,
c_attn: GPTConv1D, c_attn: GPTConv1D,
@ -62,23 +68,27 @@ impl Attention {
let attn_pdrop = match config.attn_pdrop { let attn_pdrop = match config.attn_pdrop {
Some(value) => value, Some(value) => value,
None => 0.1 None => 0.1,
}; };
let resid_pdrop = match config.resid_pdrop { let resid_pdrop = match config.resid_pdrop {
Some(value) => value, Some(value) => value,
None => 0.1 None => 0.1,
}; };
let output_attentions = match config.output_attentions { let output_attentions = match config.output_attentions {
Some(value) => value, Some(value) => value,
None => false None => false,
}; };
let attn_dropout = Dropout::new(attn_pdrop); let attn_dropout = Dropout::new(attn_pdrop);
let resid_dropout = Dropout::new(resid_pdrop); let resid_dropout = Dropout::new(resid_pdrop);
assert_eq!(config.n_embd % config.n_head, 0, "Attention hidden states not a multiple of the number of heads"); assert_eq!(
config.n_embd % config.n_head,
0,
"Attention hidden states not a multiple of the number of heads"
);
let dim_per_head = config.n_embd / config.n_head; let dim_per_head = config.n_embd / config.n_head;
Attention { Attention {
@ -105,19 +115,31 @@ impl Attention {
} }
fn flatten(&self, x: Tensor) -> Tensor { fn flatten(&self, x: Tensor) -> Tensor {
x.transpose(1, 2).contiguous().view((x.size()[0], -1, &self.n_head * self.dim_per_head)) x.transpose(1, 2)
.contiguous()
.view((x.size()[0], -1, &self.n_head * self.dim_per_head))
} }
fn attention(&self, q: &Tensor, k: &Tensor, v: &Tensor, attention_mask: &Option<Tensor>, train: bool) fn attention(
-> (Tensor, Option<Tensor>) { &self,
q: &Tensor,
k: &Tensor,
v: &Tensor,
attention_mask: &Option<Tensor>,
train: bool,
) -> (Tensor, Option<Tensor>) {
let mut w = q.matmul(&k); let mut w = q.matmul(&k);
if self.scale { w = w / (*v.size().last().unwrap() as f64).sqrt(); } if self.scale {
w = w / (*v.size().last().unwrap() as f64).sqrt();
}
let (nd, ns) = (w.size()[2], w.size()[3]); let (nd, ns) = (w.size()[2], w.size()[3]);
let b = self.bias.narrow(2, ns - nd, nd).narrow(3, 0, ns); let b = self.bias.narrow(2, ns - nd, nd).narrow(3, 0, ns);
let mut w: Tensor = w * &b + 1e4 * (&b - 1); let mut w: Tensor = w * &b + 1e4 * (&b - 1);
if let Some(mask) = attention_mask { w = w + mask; } if let Some(mask) = attention_mask {
w = w + mask;
}
w = w.softmax(-1, Float).apply_t(&self.attn_dropout, train); w = w.softmax(-1, Float).apply_t(&self.attn_dropout, train);
let output = w.matmul(&v); let output = w.matmul(&v);
@ -128,29 +150,36 @@ impl Attention {
} }
} }
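The masking line `w * &b + 1e4 * (&b - 1)` above is the causal mask written arithmetically: `b` holds 1 where a query may attend and 0 elsewhere, so allowed scores pass through while disallowed ones collapse to roughly -1e4 and vanish after the softmax. Worked out per position:

    // b == 1:  w * 1 + 1e4 * (1 - 1) = w        (score kept)
    // b == 0:  w * 0 + 1e4 * (0 - 1) = -1e4     (effectively zero after softmax)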
pub fn forward_t(&self, x: &Tensor, layer_past: &Option<Tensor>, attention_mask: &Option<Tensor>, train: bool) pub fn forward_t(
-> (Tensor, Tensor, Option<Tensor>) { &self,
x: &Tensor,
layer_past: &Option<Tensor>,
attention_mask: &Option<Tensor>,
train: bool,
) -> (Tensor, Tensor, Option<Tensor>) {
let x = x.apply(&self.c_attn).split(self.n_state, 2); let x = x.apply(&self.c_attn).split(self.n_state, 2);
let (query, key, value) = let (query, key, value) = (
( self.split_heads(&x[0], false),
self.split_heads(&x[0], false), self.split_heads(&x[1], true),
self.split_heads(&x[1], true), self.split_heads(&x[2], false),
self.split_heads(&x[2], false) );
);
let (key, value) = match layer_past { let (key, value) = match layer_past {
Some(past) => { Some(past) => {
let key = Tensor::cat(&[past.get(0).transpose(-2, -1), key], -1); let key = Tensor::cat(&[past.get(0).transpose(-2, -1), key], -1);
let value = Tensor::cat(&[past.get(1), value], -2); let value = Tensor::cat(&[past.get(1), value], -2);
(key, value) (key, value)
} }
None => (key, value) None => (key, value),
}; };
let present = Tensor::stack(&[key.transpose(-2, -1), value.copy()], 0); let present = Tensor::stack(&[key.transpose(-2, -1), value.copy()], 0);
let (a, attentions) = self.attention(&query, &key, &value, &attention_mask, train); let (a, attentions) = self.attention(&query, &key, &value, &attention_mask, train);
let a = self.flatten(a).apply(&self.c_proj).apply_t(&self.resid_dropout, train); let a = self
.flatten(a)
.apply(&self.c_proj)
.apply_t(&self.resid_dropout, train);
(a, present, attentions) (a, present, attentions)
} }
} }
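`present` stacks the transposed key and the value so callers can feed it back as `layer_past` on the next decoding step, at which point only the new position's projections are computed. A hedged single-layer sketch (variable names are placeholders):

    // First call: process the whole prompt, no past.
    let (output, present, _attn) = attention.forward_t(&prompt_states, &None, &None, false);
    // Next call: a single new position, reusing the cached keys/values.
    let (output2, present2, _attn2) = attention.forward_t(&step_states, &Some(present), &None, false);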
View File
@ -12,16 +12,16 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use serde::{Deserialize, Serialize};
use tch::{nn, Tensor};
use crate::common::dropout::Dropout; use crate::common::dropout::Dropout;
use tch::nn::embedding; use crate::common::linear::{linear_no_bias, LinearNoBias};
use crate::gpt2::transformer::Block; use crate::gpt2::transformer::Block;
use tch::kind::Kind::Int64; use crate::pipelines::generation::{Cache, LMHeadModel};
use std::borrow::BorrowMut;
use crate::common::linear::{LinearNoBias, linear_no_bias};
use crate::Config; use crate::Config;
use crate::pipelines::generation::{LMHeadModel, Cache}; use serde::{Deserialize, Serialize};
use std::borrow::BorrowMut;
use tch::kind::Kind::Int64;
use tch::nn::embedding;
use tch::{nn, Tensor};
/// # GPT2 Pretrained model weight files /// # GPT2 Pretrained model weight files
pub struct Gpt2ModelResources; pub struct Gpt2ModelResources;
@ -37,54 +37,114 @@ pub struct Gpt2MergesResources;
impl Gpt2ModelResources { impl Gpt2ModelResources {
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format. /// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
pub const GPT2: (&'static str, &'static str) = ("gpt2/model.ot", "https://cdn.huggingface.co/gpt2-rust_model.ot"); pub const GPT2: (&'static str, &'static str) = (
"gpt2/model.ot",
"https://cdn.huggingface.co/gpt2-rust_model.ot",
);
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format. /// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
pub const GPT2_MEDIUM: (&'static str, &'static str) = ("gpt2-medium/model.ot", "https://cdn.huggingface.co/gpt2-medium-rust_model.ot"); pub const GPT2_MEDIUM: (&'static str, &'static str) = (
"gpt2-medium/model.ot",
"https://cdn.huggingface.co/gpt2-medium-rust_model.ot",
);
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format. /// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
pub const GPT2_LARGE: (&'static str, &'static str) = ("gpt2-large/model.ot", "https://cdn.huggingface.co/gpt2-large-rust_model.ot"); pub const GPT2_LARGE: (&'static str, &'static str) = (
"gpt2-large/model.ot",
"https://cdn.huggingface.co/gpt2-large-rust_model.ot",
);
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format. /// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
pub const GPT2_XL: (&'static str, &'static str) = ("gpt2-xl/model.ot", "https://cdn.huggingface.co/gpt2-xl-rust_model.ot"); pub const GPT2_XL: (&'static str, &'static str) = (
"gpt2-xl/model.ot",
"https://cdn.huggingface.co/gpt2-xl-rust_model.ot",
);
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format. /// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
pub const DISTIL_GPT2: (&'static str, &'static str) = ("distilgpt2/model.ot", "https://cdn.huggingface.co/distilgpt2-rust_model.ot"); pub const DISTIL_GPT2: (&'static str, &'static str) = (
"distilgpt2/model.ot",
"https://cdn.huggingface.co/distilgpt2-rust_model.ot",
);
} }
impl Gpt2ConfigResources { impl Gpt2ConfigResources {
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format. /// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
pub const GPT2: (&'static str, &'static str) = ("gpt2/config.json", "https://cdn.huggingface.co/gpt2-config.json"); pub const GPT2: (&'static str, &'static str) = (
"gpt2/config.json",
"https://cdn.huggingface.co/gpt2-config.json",
);
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format. /// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
pub const GPT2_MEDIUM: (&'static str, &'static str) = ("gpt2-medium/config.json", "https://cdn.huggingface.co/gpt2-medium-config.json"); pub const GPT2_MEDIUM: (&'static str, &'static str) = (
"gpt2-medium/config.json",
"https://cdn.huggingface.co/gpt2-medium-config.json",
);
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format. /// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
pub const GPT2_LARGE: (&'static str, &'static str) = ("gpt2-large/config.json", "https://cdn.huggingface.co/gpt2-large-config.json"); pub const GPT2_LARGE: (&'static str, &'static str) = (
"gpt2-large/config.json",
"https://cdn.huggingface.co/gpt2-large-config.json",
);
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format. /// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
pub const GPT2_XL: (&'static str, &'static str) = ("gpt2-xl/config.json", "https://cdn.huggingface.co/gpt2-xl-config.json"); pub const GPT2_XL: (&'static str, &'static str) = (
"gpt2-xl/config.json",
"https://cdn.huggingface.co/gpt2-xl-config.json",
);
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format. /// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
pub const DISTIL_GPT2: (&'static str, &'static str) = ("distilgpt2/config.json", "https://cdn.huggingface.co/distilgpt2-config.json"); pub const DISTIL_GPT2: (&'static str, &'static str) = (
"distilgpt2/config.json",
"https://cdn.huggingface.co/distilgpt2-config.json",
);
} }
impl Gpt2VocabResources { impl Gpt2VocabResources {
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format. /// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
pub const GPT2: (&'static str, &'static str) = ("gpt2/vocab.txt", "https://cdn.huggingface.co/gpt2-vocab.json"); pub const GPT2: (&'static str, &'static str) = (
"gpt2/vocab.txt",
"https://cdn.huggingface.co/gpt2-vocab.json",
);
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format. /// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
pub const GPT2_MEDIUM: (&'static str, &'static str) = ("gpt2-medium/vocab.txt", "https://cdn.huggingface.co/gpt2-medium-vocab.json"); pub const GPT2_MEDIUM: (&'static str, &'static str) = (
"gpt2-medium/vocab.txt",
"https://cdn.huggingface.co/gpt2-medium-vocab.json",
);
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format. /// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
pub const GPT2_LARGE: (&'static str, &'static str) = ("gpt2-large/vocab.txt", "https://cdn.huggingface.co/gpt2-large-vocab.json"); pub const GPT2_LARGE: (&'static str, &'static str) = (
"gpt2-large/vocab.txt",
"https://cdn.huggingface.co/gpt2-large-vocab.json",
);
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format. /// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
pub const GPT2_XL: (&'static str, &'static str) = ("gpt2-xl/vocab.txt", "https://cdn.huggingface.co/gpt2-xl-vocab.json"); pub const GPT2_XL: (&'static str, &'static str) = (
"gpt2-xl/vocab.txt",
"https://cdn.huggingface.co/gpt2-xl-vocab.json",
);
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format. /// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
pub const DISTIL_GPT2: (&'static str, &'static str) = ("distilgpt2/vocab.txt", "https://cdn.huggingface.co/distilgpt2-vocab.json"); pub const DISTIL_GPT2: (&'static str, &'static str) = (
"distilgpt2/vocab.txt",
"https://cdn.huggingface.co/distilgpt2-vocab.json",
);
} }
impl Gpt2MergesResources { impl Gpt2MergesResources {
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format. /// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
pub const GPT2: (&'static str, &'static str) = ("gpt2/merges.txt", "https://cdn.huggingface.co/gpt2-merges.txt"); pub const GPT2: (&'static str, &'static str) = (
"gpt2/merges.txt",
"https://cdn.huggingface.co/gpt2-merges.txt",
);
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format. /// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
pub const GPT2_MEDIUM: (&'static str, &'static str) = ("gpt2-medium/merges.txt", "https://cdn.huggingface.co/gpt2-medium-merges.txt"); pub const GPT2_MEDIUM: (&'static str, &'static str) = (
"gpt2-medium/merges.txt",
"https://cdn.huggingface.co/gpt2-medium-merges.txt",
);
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format. /// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
pub const GPT2_LARGE: (&'static str, &'static str) = ("gpt2-large/merges.txt", "https://cdn.huggingface.co/gpt2-large-merges.txt"); pub const GPT2_LARGE: (&'static str, &'static str) = (
"gpt2-large/merges.txt",
"https://cdn.huggingface.co/gpt2-large-merges.txt",
);
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format. /// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
pub const GPT2_XL: (&'static str, &'static str) = ("gpt2-xl/merges.txt", "https://cdn.huggingface.co/gpt2-xl-merges.txt"); pub const GPT2_XL: (&'static str, &'static str) = (
"gpt2-xl/merges.txt",
"https://cdn.huggingface.co/gpt2-xl-merges.txt",
);
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format. /// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
pub const DISTIL_GPT2: (&'static str, &'static str) = ("distilgpt2/merges.txt", "https://cdn.huggingface.co/distilgpt2-merges.txt"); pub const DISTIL_GPT2: (&'static str, &'static str) = (
"distilgpt2/merges.txt",
"https://cdn.huggingface.co/distilgpt2-merges.txt",
);
} }
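Each reformatted constant is a `(cache_path, url)` pair consumed by the resource machinery used throughout this commit; a sketch of wiring one up (caching under the first tuple element is assumed from that pairing; the `?` presumes a `failure::Fallible` context as in the module docs):

    use rust_bert::gpt2::Gpt2ModelResources;
    use rust_bert::resources::{download_resource, RemoteResource, Resource};

    let weights_resource =
        Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
    let weights_path = download_resource(&weights_resource)?;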
#[allow(non_camel_case_types)] #[allow(non_camel_case_types)]
@ -156,10 +216,10 @@ impl Gpt2Model {
/// # Example /// # Example
/// ///
/// ```no_run /// ```no_run
/// use tch::{nn, Device}; /// use rust_bert::gpt2::{Gpt2Config, Gpt2Model};
/// use rust_bert::Config; /// use rust_bert::Config;
/// use std::path::Path; /// use std::path::Path;
/// use rust_bert::gpt2::{Gpt2Config, Gpt2Model}; /// use tch::{nn, Device};
/// ///
/// let config_path = Path::new("path/to/config.json"); /// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu; /// let device = Device::Cpu;
@ -167,37 +227,58 @@ impl Gpt2Model {
/// let config = Gpt2Config::from_file(config_path); /// let config = Gpt2Config::from_file(config_path);
/// let gpt2: Gpt2Model = Gpt2Model::new(&(&p.root() / "gpt2"), &config); /// let gpt2: Gpt2Model = Gpt2Model::new(&(&p.root() / "gpt2"), &config);
/// ``` /// ```
///
pub fn new(p: &nn::Path, config: &Gpt2Config) -> Gpt2Model { pub fn new(p: &nn::Path, config: &Gpt2Config) -> Gpt2Model {
let p = &(p / "transformer"); let p = &(p / "transformer");
let wte = embedding(&(p / "wte"), config.vocab_size, config.n_embd, Default::default()); let wte = embedding(
let wpe = embedding(&(p / "wpe"), config.n_positions, config.n_embd, Default::default()); &(p / "wte"),
config.vocab_size,
config.n_embd,
Default::default(),
);
let wpe = embedding(
&(p / "wpe"),
config.n_positions,
config.n_embd,
Default::default(),
);
let embd_pdrop = match config.embd_pdrop { let embd_pdrop = match config.embd_pdrop {
Some(value) => value, Some(value) => value,
None => 0.1 None => 0.1,
}; };
let drop = Dropout::new(embd_pdrop); let drop = Dropout::new(embd_pdrop);
let layer_norm_config = nn::LayerNormConfig { eps: config.layer_norm_epsilon, ..Default::default() }; let layer_norm_config = nn::LayerNormConfig {
eps: config.layer_norm_epsilon,
..Default::default()
};
let ln_f = nn::layer_norm(p / "ln_f", vec![config.n_embd], layer_norm_config); let ln_f = nn::layer_norm(p / "ln_f", vec![config.n_embd], layer_norm_config);
let mut h: Vec<Block> = vec!(); let mut h: Vec<Block> = vec![];
let h_path = &(p / "h"); let h_path = &(p / "h");
for layer_index in 0..config.n_layer { for layer_index in 0..config.n_layer {
h.push(Block::new(&(h_path / layer_index), config, true)); h.push(Block::new(&(h_path / layer_index), config, true));
}; }
let output_attentions = match config.output_attentions { let output_attentions = match config.output_attentions {
Some(value) => value, Some(value) => value,
None => false None => false,
}; };
let output_past = match config.output_past { let output_past = match config.output_past {
Some(value) => value, Some(value) => value,
None => true None => true,
}; };
let output_hidden_states = match config.output_hidden_states { let output_hidden_states = match config.output_hidden_states {
Some(value) => value, Some(value) => value,
None => false None => false,
}; };
Gpt2Model { wte, wpe, drop, ln_f, h, output_past, output_hidden_states, output_attentions } Gpt2Model {
wte,
wpe,
drop,
ln_f,
h,
output_past,
output_hidden_states,
output_attentions,
}
} }
/// Forward pass through the model /// Forward pass through the model
@ -222,63 +303,101 @@ impl Gpt2Model {
/// # Example /// # Example
/// ///
/// ```no_run /// ```no_run
///# use tch::{nn, Device, Tensor, no_grad}; /// # use tch::{nn, Device, Tensor, no_grad};
///# use rust_bert::Config; /// # use rust_bert::Config;
///# use std::path::Path; /// # use std::path::Path;
///# use tch::kind::Kind::{Int64, Double}; /// # use tch::kind::Kind::{Int64, Double};
/// use rust_bert::gpt2::{Gpt2Model, Gpt2Config}; /// use rust_bert::gpt2::{Gpt2Config, Gpt2Model};
///# let config_path = Path::new("path/to/config.json"); /// # let config_path = Path::new("path/to/config.json");
///# let vocab_path = Path::new("path/to/vocab.txt"); /// # let vocab_path = Path::new("path/to/vocab.txt");
///# let device = Device::Cpu; /// # let device = Device::Cpu;
///# let vs = nn::VarStore::new(device); /// # let vs = nn::VarStore::new(device);
///# let config = Gpt2Config::from_file(config_path); /// # let config = Gpt2Config::from_file(config_path);
///# let gpt2_model: Gpt2Model = Gpt2Model::new(&vs.root(), &config); /// # let gpt2_model: Gpt2Model = Gpt2Model::new(&vs.root(), &config);
/// let (batch_size, sequence_length, past_sequence_length) = (64, 128, 56); /// let (batch_size, sequence_length, past_sequence_length) = (64, 128, 56);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device)); /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let mut past: Vec<Tensor> = Vec::with_capacity(config.n_layer as usize); /// let mut past: Vec<Tensor> = Vec::with_capacity(config.n_layer as usize);
/// for _ in 0..config.n_layer as usize { /// for _ in 0..config.n_layer as usize {
/// past.push(Tensor::rand(&[2, batch_size, config.n_head, past_sequence_length, config.n_embd / config.n_head], (Double, device))) /// past.push(Tensor::rand(
/// &[
/// 2,
/// batch_size,
/// config.n_head,
/// past_sequence_length,
/// config.n_embd / config.n_head,
/// ],
/// (Double, device),
/// ))
/// } /// }
/// let attention_mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device)); /// let attention_mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let token_type_ids = Tensor::ones(&[batch_size, sequence_length], (Int64, device)); /// let token_type_ids = Tensor::ones(&[batch_size, sequence_length], (Int64, device));
/// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true); /// let position_ids = Tensor::arange(sequence_length, (Int64, device))
/// /// .expand(&[batch_size, sequence_length], true);
/// let (output, past, hidden_states, attentions) = no_grad(|| {
/// gpt2_model
/// .forward_t(&Some(input_tensor),
/// &Some(past),
/// &Some(attention_mask),
/// &Some(token_type_ids),
/// &Some(position_ids),
/// &None,
/// false).unwrap()
/// });
/// ///
/// let (output, past, hidden_states, attentions) = no_grad(|| {
/// gpt2_model
/// .forward_t(
/// &Some(input_tensor),
/// &Some(past),
/// &Some(attention_mask),
/// &Some(token_type_ids),
/// &Some(position_ids),
/// &None,
/// false,
/// )
/// .unwrap()
/// });
/// ``` /// ```
/// pub fn forward_t(
pub fn forward_t(&self, &self,
input_ids: &Option<Tensor>, input_ids: &Option<Tensor>,
layer_past: &Option<Vec<Tensor>>, layer_past: &Option<Vec<Tensor>>,
attention_mask: &Option<Tensor>, attention_mask: &Option<Tensor>,
token_type_ids: &Option<Tensor>, token_type_ids: &Option<Tensor>,
position_ids: &Option<Tensor>, position_ids: &Option<Tensor>,
input_embeds: &Option<Tensor>, input_embeds: &Option<Tensor>,
train: bool) -> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> { train: bool,
) -> Result<
(
Tensor,
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
),
&'static str,
> {
let (input_embeddings, seq_length) = match input_ids { let (input_embeddings, seq_length) = match input_ids {
Some(input_value) => match input_embeds { Some(input_value) => match input_embeds {
Some(_) => { return Err("Only one of input ids or input embeddings may be set"); } Some(_) => {
None => (input_value.apply(&self.wte), *input_value.size().last().unwrap()) return Err("Only one of input ids or input embeddings may be set");
} }
None => (
input_value.apply(&self.wte),
*input_value.size().last().unwrap(),
),
},
None => match input_embeds { None => match input_embeds {
Some(embeds) => (embeds.copy(), embeds.size()[1]), Some(embeds) => (embeds.copy(), embeds.size()[1]),
None => { return Err("At least one of input ids or input embeddings must be set"); } None => {
} return Err("At least one of input ids or input embeddings must be set");
}
},
}; };
let (layer_past, layer_past_length) = match layer_past { let (layer_past, layer_past_length) = match layer_past {
Some(value) => { Some(value) => {
assert_eq!(value.len(), self.h.len(), "Past activations vector must be of length equal to the number of layers"); assert_eq!(
(value.iter().map(|v| Some(v.copy())).collect::<Vec<Option<Tensor>>>(), value[0].size()[3]) value.len(),
self.h.len(),
"Past activations vector must be of length equal to the number of layers"
);
(
value
.iter()
.map(|v| Some(v.copy()))
.collect::<Vec<Option<Tensor>>>(),
value[0].size()[3],
)
} }
None => { None => {
let mut out = Vec::with_capacity(self.h.len()); let mut out = Vec::with_capacity(self.h.len());
@ -289,31 +408,45 @@ impl Gpt2Model {
let position_ids = match position_ids { let position_ids = match position_ids {
Some(value) => value.copy(), Some(value) => value.copy(),
None => Tensor::arange1(layer_past_length, seq_length + layer_past_length, (Int64, input_embeddings.device())).unsqueeze(0) None => Tensor::arange1(
layer_past_length,
seq_length + layer_past_length,
(Int64, input_embeddings.device()),
)
.unsqueeze(0),
}; };
let attention_mask: Option<Tensor> = match attention_mask { let attention_mask: Option<Tensor> = match attention_mask {
Some(value) => { Some(value) => Some(
Some( (value
(value .view((input_embeddings.size()[0], -1))
.view((input_embeddings.size()[0], -1)) .unsqueeze(1)
.unsqueeze(1) .unsqueeze(2)
.unsqueeze(2) - 1.0)
- 1.0 * 10000.0,
) * 10000.0) ),
} None => None,
None => None
}; };
let position_embeds = position_ids.apply(&self.wpe); let position_embeds = position_ids.apply(&self.wpe);
let token_type_embeds = match token_type_ids { let token_type_embeds = match token_type_ids {
Some(value) => value.apply(&self.wte), Some(value) => value.apply(&self.wte),
None => Tensor::zeros_like(&position_embeds) None => Tensor::zeros_like(&position_embeds),
};
let mut hidden_state: Tensor =
(input_embeddings + position_embeds + token_type_embeds).apply_t(&self.drop, train);
let mut all_presents: Option<Vec<Tensor>> =
if self.output_past { Some(vec![]) } else { None };
let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states {
Some(vec![])
} else {
None
};
let mut all_attentions: Option<Vec<Tensor>> = if self.output_attentions {
Some(vec![])
} else {
None
}; };
let mut hidden_state: Tensor = (input_embeddings + position_embeds + token_type_embeds).apply_t(&self.drop, train);
let mut all_presents: Option<Vec<Tensor>> = if self.output_past { Some(vec!()) } else { None };
let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states { Some(vec!()) } else { None };
let mut all_attentions: Option<Vec<Tensor>> = if self.output_attentions { Some(vec!()) } else { None };
let mut layer_iter = self.h.iter().zip(layer_past); let mut layer_iter = self.h.iter().zip(layer_past);
loop { loop {
@ -333,11 +466,16 @@ impl Gpt2Model {
attentions.push(temp.2.as_ref().unwrap().copy()); attentions.push(temp.2.as_ref().unwrap().copy());
}; };
} }
None => break None => break,
}; };
}; }
Ok((hidden_state.apply(&self.ln_f), all_presents, all_hidden_states, all_attentions)) Ok((
hidden_state.apply(&self.ln_f),
all_presents,
all_hidden_states,
all_attentions,
))
} }
} }
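One detail worth noting in the body above: the attention mask is turned from {0, 1} values into an additive bias by `(mask.view(...).unsqueeze(1).unsqueeze(2) - 1.0) * 10000.0`, matching the additive masking inside `Attention::attention`. Worked out:

    // mask == 1 (visible): (1.0 - 1.0) * 10000.0 =      0.0  -> score unchanged
    // mask == 0 (padding): (0.0 - 1.0) * 10000.0 = -10000.0  -> score suppressed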
@ -362,10 +500,10 @@ impl GPT2LMHeadModel {
/// # Example /// # Example
/// ///
/// ```no_run /// ```no_run
/// use tch::{nn, Device}; /// use rust_bert::gpt2::{GPT2LMHeadModel, Gpt2Config};
/// use rust_bert::Config; /// use rust_bert::Config;
/// use std::path::Path; /// use std::path::Path;
/// use rust_bert::gpt2::{Gpt2Config, GPT2LMHeadModel}; /// use tch::{nn, Device};
/// ///
/// let config_path = Path::new("path/to/config.json"); /// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu; /// let device = Device::Cpu;
@ -373,11 +511,18 @@ impl GPT2LMHeadModel {
/// let config = Gpt2Config::from_file(config_path); /// let config = Gpt2Config::from_file(config_path);
/// let gpt2: GPT2LMHeadModel = GPT2LMHeadModel::new(&(&p.root() / "gpt2"), &config); /// let gpt2: GPT2LMHeadModel = GPT2LMHeadModel::new(&(&p.root() / "gpt2"), &config);
/// ``` /// ```
///
pub fn new(p: &nn::Path, config: &Gpt2Config) -> GPT2LMHeadModel { pub fn new(p: &nn::Path, config: &Gpt2Config) -> GPT2LMHeadModel {
let transformer = Gpt2Model::new(&p, config); let transformer = Gpt2Model::new(&p, config);
let lm_head = linear_no_bias(&(p / "lm_head"), config.n_embd, config.vocab_size, Default::default()); let lm_head = linear_no_bias(
GPT2LMHeadModel { transformer, lm_head } &(p / "lm_head"),
config.n_embd,
config.vocab_size,
Default::default(),
);
GPT2LMHeadModel {
transformer,
lm_head,
}
} }
} }
@ -408,75 +553,104 @@ impl LMHeadModel for GPT2LMHeadModel {
/// # Example /// # Example
/// ///
/// ```no_run /// ```no_run
///# use tch::{nn, Device, Tensor, no_grad}; /// # use tch::{nn, Device, Tensor, no_grad};
///# use rust_bert::Config; /// # use rust_bert::Config;
///# use std::path::Path; /// # use std::path::Path;
///# use tch::kind::Kind::{Int64, Double}; /// # use tch::kind::Kind::{Int64, Double};
/// use rust_bert::gpt2::{Gpt2Config, GPT2LMHeadModel}; /// use rust_bert::gpt2::{GPT2LMHeadModel, Gpt2Config};
/// use rust_bert::pipelines::generation::{LMHeadModel, Cache}; /// use rust_bert::pipelines::generation::{Cache, LMHeadModel};
///# let config_path = Path::new("path/to/config.json"); /// # let config_path = Path::new("path/to/config.json");
///# let vocab_path = Path::new("path/to/vocab.txt"); /// # let vocab_path = Path::new("path/to/vocab.txt");
///# let device = Device::Cpu; /// # let device = Device::Cpu;
///# let vs = nn::VarStore::new(device); /// # let vs = nn::VarStore::new(device);
///# let config = Gpt2Config::from_file(config_path); /// # let config = Gpt2Config::from_file(config_path);
///# let mut gpt2_model: GPT2LMHeadModel = GPT2LMHeadModel::new(&vs.root(), &config); /// # let mut gpt2_model: GPT2LMHeadModel = GPT2LMHeadModel::new(&vs.root(), &config);
/// let (batch_size, sequence_length, past_sequence_length) = (64, 128, 56); /// let (batch_size, sequence_length, past_sequence_length) = (64, 128, 56);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device)); /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let mut past: Vec<Tensor> = Vec::with_capacity(config.n_layer as usize); /// let mut past: Vec<Tensor> = Vec::with_capacity(config.n_layer as usize);
/// for _ in 0..config.n_layer as usize { /// for _ in 0..config.n_layer as usize {
/// past.push(Tensor::rand(&[2, batch_size, config.n_head, past_sequence_length, config.n_embd / config.n_head], (Double, device))) /// past.push(Tensor::rand(
/// &[
/// 2,
/// batch_size,
/// config.n_head,
/// past_sequence_length,
/// config.n_embd / config.n_head,
/// ],
/// (Double, device),
/// ))
/// } /// }
/// let attention_mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device)); /// let attention_mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let token_type_ids = Tensor::ones(&[batch_size, sequence_length], (Int64, device)); /// let token_type_ids = Tensor::ones(&[batch_size, sequence_length], (Int64, device));
/// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true); /// let position_ids = Tensor::arange(sequence_length, (Int64, device))
/// /// .expand(&[batch_size, sequence_length], true);
/// let (output, _, past, hidden_states, attentions) = no_grad(|| {
/// gpt2_model
/// .forward_t(&Some(input_tensor),
/// Cache::GPT2Cache(Some(past)),
/// &Some(attention_mask),
/// &Some(token_type_ids),
/// &Some(position_ids),
/// &None,
/// None,
/// &None,
/// false).unwrap()
/// });
/// ///
/// let (output, _, past, hidden_states, attentions) = no_grad(|| {
/// gpt2_model
/// .forward_t(
/// &Some(input_tensor),
/// Cache::GPT2Cache(Some(past)),
/// &Some(attention_mask),
/// &Some(token_type_ids),
/// &Some(position_ids),
/// &None,
/// None,
/// &None,
/// false,
/// )
/// .unwrap()
/// });
/// ``` /// ```
/// fn forward_t(
fn forward_t(&self, &self,
input_ids: &Option<Tensor>, input_ids: &Option<Tensor>,
layer_past: Cache, layer_past: Cache,
attention_mask: &Option<Tensor>, attention_mask: &Option<Tensor>,
token_type_ids: &Option<Tensor>, token_type_ids: &Option<Tensor>,
position_ids: &Option<Tensor>, position_ids: &Option<Tensor>,
input_embeds: &Option<Tensor>, input_embeds: &Option<Tensor>,
_encoder_outputs: Option<&Tensor>, _encoder_outputs: Option<&Tensor>,
_decoder_input_ids: &Option<Tensor>, _decoder_input_ids: &Option<Tensor>,
train: bool) -> Result<(Tensor, Option<Tensor>, Cache, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> { train: bool,
let (output, ) -> Result<
past, (
all_hidden_states, Tensor,
all_attentions) = match layer_past { Option<Tensor>,
Cache::GPT2Cache(layer_past) => Ok(self.transformer.forward_t(input_ids, Cache,
&layer_past, Option<Vec<Tensor>>,
attention_mask, Option<Vec<Tensor>>,
token_type_ids, ),
position_ids, &'static str,
input_embeds, > {
train)?), let (output, past, all_hidden_states, all_attentions) = match layer_past {
Cache::None => Ok(self.transformer.forward_t(input_ids, Cache::GPT2Cache(layer_past) => Ok(self.transformer.forward_t(
&None, input_ids,
attention_mask, &layer_past,
token_type_ids, attention_mask,
position_ids, token_type_ids,
input_embeds, position_ids,
train)?), input_embeds,
_ => Err("Cache not compatible with GPT2 model") train,
)?),
Cache::None => Ok(self.transformer.forward_t(
input_ids,
&None,
attention_mask,
token_type_ids,
position_ids,
input_embeds,
train,
)?),
_ => Err("Cache not compatible with GPT2 model"),
}?; }?;
let lm_logits = output.apply(&self.lm_head); let lm_logits = output.apply(&self.lm_head);
Ok((lm_logits, None, Cache::GPT2Cache(past), all_hidden_states, all_attentions)) Ok((
lm_logits,
None,
Cache::GPT2Cache(past),
all_hidden_states,
all_attentions,
))
} }
} }
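Putting the trait implementation above to work: a hedged greedy-decoding loop (prompt handling, the stopping rule and error handling are illustrative assumptions; `prompt_ids` and `max_new_tokens` are placeholders):

    let mut cache = Cache::None;
    let mut next_input = Some(prompt_ids); // shape (batch, prompt_len)
    for _ in 0..max_new_tokens {
        let (lm_logits, _, new_cache, _, _) = gpt2_model
            .forward_t(&next_input, cache, &None, &None, &None, &None, None, &None, false)
            .unwrap();
        cache = new_cache;
        // Greedily pick the most likely token at the last position.
        let last = lm_logits.size()[1] - 1;
        let next_token = lm_logits.select(1, last).argmax(-1, true);
        next_input = Some(next_token); // with a GPT2Cache, only the new token is fed back
    }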
View File
@ -14,19 +14,27 @@
//! Pretrained models are available and can be downloaded using RemoteResources. //! Pretrained models are available and can be downloaded using RemoteResources.
//! //!
//! ```no_run //! ```no_run
//!# fn main() -> failure::Fallible<()> { //! # fn main() -> failure::Fallible<()> {
//!# //! #
//! use rust_tokenizers::Gpt2Tokenizer; //! use rust_tokenizers::Gpt2Tokenizer;
//! use tch::{nn, Device}; //! use tch::{nn, Device};
//!# use std::path::PathBuf; //! # use std::path::PathBuf;
//! use rust_bert::gpt2::{GPT2LMHeadModel, Gpt2Config};
//! use rust_bert::resources::{download_resource, LocalResource, Resource};
//! use rust_bert::Config; //! use rust_bert::Config;
//! use rust_bert::gpt2::{Gpt2Config, GPT2LMHeadModel};
//! use rust_bert::resources::{Resource, download_resource, LocalResource};
//! //!
//! let config_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/config.json")}); //! let config_resource = Resource::Local(LocalResource {
//! let vocab_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/vocab.txt")}); //! local_path: PathBuf::from("path/to/config.json"),
//! let merges_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/merges.txt")}); //! });
//! let weights_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/model.ot")}); //! let vocab_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/vocab.txt"),
//! });
//! let merges_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/merges.txt"),
//! });
//! let weights_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/model.ot"),
//! });
//! let config_path = download_resource(&config_resource)?; //! let config_path = download_resource(&config_resource)?;
//! let vocab_path = download_resource(&vocab_resource)?; //! let vocab_path = download_resource(&vocab_resource)?;
//! let merges_path = download_resource(&merges_resource)?; //! let merges_path = download_resource(&merges_resource)?;
@ -34,18 +42,24 @@
//! //!
//! let device = Device::cuda_if_available(); //! let device = Device::cuda_if_available();
//! let mut vs = nn::VarStore::new(device); //! let mut vs = nn::VarStore::new(device);
//! let tokenizer: Gpt2Tokenizer = Gpt2Tokenizer::from_file(vocab_path.to_str().unwrap(), merges_path.to_str().unwrap(), true); //! let tokenizer: Gpt2Tokenizer = Gpt2Tokenizer::from_file(
//! vocab_path.to_str().unwrap(),
//! merges_path.to_str().unwrap(),
//! true,
//! );
//! let config = Gpt2Config::from_file(config_path); //! let config = Gpt2Config::from_file(config_path);
//! let gpt2_model = GPT2LMHeadModel::new(&vs.root(), &config); //! let gpt2_model = GPT2LMHeadModel::new(&vs.root(), &config);
//! vs.load(weights_path)?; //! vs.load(weights_path)?;
//! //!
//!# Ok(()) //! # Ok(())
//!# } //! # }
//! ``` //! ```
mod gpt2;
pub(crate) mod attention; pub(crate) mod attention;
mod gpt2;
pub(crate) mod transformer; pub(crate) mod transformer;
pub use gpt2::{Gpt2ModelResources, Gpt2ConfigResources, Gpt2VocabResources, Gpt2MergesResources, pub use gpt2::{
Gpt2Config, Gpt2Model, GptActivation, GPT2LMHeadModel}; GPT2LMHeadModel, Gpt2Config, Gpt2ConfigResources, Gpt2MergesResources, Gpt2Model,
Gpt2ModelResources, Gpt2VocabResources, GptActivation,
};
View File
@ -12,11 +12,11 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use crate::gpt2::attention::{GPTConv1D, Attention};
use tch::{Tensor, nn};
use crate::common::dropout::Dropout;
use crate::gpt2::gpt2::{Gpt2Config, GptActivation};
use crate::common::activations::{_gelu_new, _relu, _swish}; use crate::common::activations::{_gelu_new, _relu, _swish};
use crate::common::dropout::Dropout;
use crate::gpt2::attention::{Attention, GPTConv1D};
use crate::gpt2::gpt2::{Gpt2Config, GptActivation};
use tch::{nn, Tensor};
pub struct MLP { pub struct MLP {
c_fc: GPTConv1D, c_fc: GPTConv1D,
@@ -35,14 +35,19 @@ impl MLP {
GptActivation::relu => _relu, GptActivation::relu => _relu,
GptActivation::swish => _swish, GptActivation::swish => _swish,
}, },
None => _gelu_new None => _gelu_new,
}); });
let resid_pdrop = match config.resid_pdrop { let resid_pdrop = match config.resid_pdrop {
Some(value) => value, Some(value) => value,
None => 0.1 None => 0.1,
}; };
let dropout = Dropout::new(resid_pdrop); let dropout = Dropout::new(resid_pdrop);
MLP { c_fc, c_proj, activation, dropout } MLP {
c_fc,
c_proj,
activation,
dropout,
}
} }
pub fn forward_t(&self, x: &Tensor, train: bool) -> Tensor { pub fn forward_t(&self, x: &Tensor, train: bool) -> Tensor {
@@ -60,21 +65,36 @@ pub struct Block {
impl Block { impl Block {
pub fn new(p: &nn::Path, config: &Gpt2Config, scale: bool) -> Block { pub fn new(p: &nn::Path, config: &Gpt2Config, scale: bool) -> Block {
let layer_norm_config = nn::LayerNormConfig { eps: config.layer_norm_epsilon, ..Default::default() }; let layer_norm_config = nn::LayerNormConfig {
eps: config.layer_norm_epsilon,
..Default::default()
};
let ln_1 = nn::layer_norm(p / "ln_1", vec![config.n_embd], layer_norm_config); let ln_1 = nn::layer_norm(p / "ln_1", vec![config.n_embd], layer_norm_config);
let ln_2 = nn::layer_norm(p / "ln_2", vec![config.n_embd], layer_norm_config); let ln_2 = nn::layer_norm(p / "ln_2", vec![config.n_embd], layer_norm_config);
let attn = Attention::new(&(p / "attn"), config, scale); let attn = Attention::new(&(p / "attn"), config, scale);
let mlp = MLP::new(&(p / "mlp"), config); let mlp = MLP::new(&(p / "mlp"), config);
Block { ln_1, attn, ln_2, mlp } Block {
ln_1,
attn,
ln_2,
mlp,
}
} }
pub fn forward_t(&self, x: &Tensor, layer_past: &Option<Tensor>, attention_mask: &Option<Tensor>, train: bool) pub fn forward_t(
-> (Tensor, Tensor, Option<Tensor>) { &self,
let (output, present, attentions) = self.attn.forward_t(&x.apply(&self.ln_1), layer_past, attention_mask, train); x: &Tensor,
layer_past: &Option<Tensor>,
attention_mask: &Option<Tensor>,
train: bool,
) -> (Tensor, Tensor, Option<Tensor>) {
let (output, present, attentions) =
self.attn
.forward_t(&x.apply(&self.ln_1), layer_past, attention_mask, train);
let x = x + output; let x = x + output;
let m = self.mlp.forward_t(&x.apply(&self.ln_2), train); let m = self.mlp.forward_t(&x.apply(&self.ln_2), train);
let x = x + m; let x = x + m;
(x, present, attentions) (x, present, attentions)
} }
} }
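The `Block` above is the standard pre-norm residual pattern: each sub-layer reads a layer-normalised copy of the stream and adds its output back onto it. A stand-alone sketch of that data flow with raw `tch` tensors; the three closures are hypothetical stand-ins for the learned `ln_1`/`ln_2`, `Attention`, and `MLP` sub-modules:

```no_run
use tch::kind::Kind::Float;
use tch::{Device, Tensor};

fn main() {
    // Hypothetical stand-ins for the learned sub-modules used in Block::forward_t.
    let ln = |x: &Tensor| x.copy(); // real code: x.apply(&self.ln_1) / ln_2
    let attn = |x: &Tensor| x * 0.5; // real code: self.attn.forward_t(...)
    let mlp = |x: &Tensor| x * 2.0; // real code: self.mlp.forward_t(...)

    let x = Tensor::rand(&[2, 8, 16], (Float, Device::Cpu));
    // Mirrors Block::forward_t: x = x + attn(ln_1(x)); x = x + mlp(ln_2(x)).
    let x = &x + attn(&ln(&x));
    let x = &x + mlp(&ln(&x));
    assert_eq!(x.size(), vec![2, 8, 16]); // residual additions preserve the shape
}
```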
@@ -16,14 +16,14 @@
//! //!
//! More information on these can be found in the [`pipelines` module](./pipelines/index.html) //! More information on these can be found in the [`pipelines` module](./pipelines/index.html)
//! ```no_run //! ```no_run
//! use rust_bert::pipelines::question_answering::{QuestionAnsweringModel, QaInput}; //! use rust_bert::pipelines::question_answering::{QaInput, QuestionAnsweringModel};
//! //!
//!# fn main() -> failure::Fallible<()> { //! # fn main() -> failure::Fallible<()> {
//! let qa_model = QuestionAnsweringModel::new(Default::default())?; //! let qa_model = QuestionAnsweringModel::new(Default::default())?;
//! //!
//! let question = String::from("Where does Amy live ?"); //! let question = String::from("Where does Amy live ?");
//! let context = String::from("Amy lives in Amsterdam"); //! let context = String::from("Amy lives in Amsterdam");
//! let answers = qa_model.predict(&vec!(QaInput { question, context }), 1, 32); //! let answers = qa_model.predict(&vec![QaInput { question, context }], 1, 32);
//! # Ok(()) //! # Ok(())
//! # } //! # }
//! ``` //! ```
@@ -55,19 +55,18 @@
//! - Set-up a virtual environment and install dependencies //! - Set-up a virtual environment and install dependencies
//! - run the conversion script python /utils/download-dependencies_{MODEL_TO_DOWNLOAD}.py. The dependencies will be downloaded to the user's home directory, under ~/rustbert/{} //! - run the conversion script python /utils/download-dependencies_{MODEL_TO_DOWNLOAD}.py. The dependencies will be downloaded to the user's home directory, under ~/rustbert/{}
//! 3. Run the example cargo run --release //! 3. Run the example cargo run --release
//!
pub mod distilbert;
pub mod bert;
pub mod roberta;
pub mod openai_gpt;
pub mod gpt2;
pub mod bart;
pub mod electra;
pub mod marian;
pub mod albert; pub mod albert;
pub mod bart;
pub mod bert;
mod common; mod common;
pub mod distilbert;
pub mod electra;
pub mod gpt2;
pub mod marian;
pub mod openai_gpt;
pub mod pipelines; pub mod pipelines;
pub mod roberta;
pub use common::Config;
pub use common::resources; pub use common::resources;
pub use common::Config;
@@ -11,10 +11,10 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use crate::bart::{BartModel, BartConfig, LayerState}; use crate::bart::{BartConfig, BartModel, LayerState};
use tch::{Tensor, nn}; use crate::pipelines::generation::{Cache, LMHeadModel};
use crate::pipelines::generation::{LMHeadModel, Cache};
use tch::nn::Init; use tch::nn::Init;
use tch::{nn, Tensor};
/// # Marian Pretrained model weight files /// # Marian Pretrained model weight files
pub struct MarianModelResources; pub struct MarianModelResources;
@@ -33,78 +33,174 @@ pub struct MarianPrefix;
impl MarianModelResources { impl MarianModelResources {
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format. /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format.
pub const ENGLISH2ROMANCE: (&'static str, &'static str) = ("marian-mt-en-ROMANCE/model.ot", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-ROMANCE/rust_model.ot"); pub const ENGLISH2ROMANCE: (&'static str, &'static str) = (
"marian-mt-en-ROMANCE/model.ot",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-ROMANCE/rust_model.ot",
);
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format. /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format.
pub const ROMANCE2ENGLISH: (&'static str, &'static str) = ("marian-mt-ROMANCE-en/model.ot", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-ROMANCE-en/rust_model.ot"); pub const ROMANCE2ENGLISH: (&'static str, &'static str) = (
"marian-mt-ROMANCE-en/model.ot",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-ROMANCE-en/rust_model.ot",
);
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format. /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format.
pub const ENGLISH2GERMAN: (&'static str, &'static str) = ("marian-mt-en-de/model.ot", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-de/rust_model.ot"); pub const ENGLISH2GERMAN: (&'static str, &'static str) = (
"marian-mt-en-de/model.ot",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-de/rust_model.ot",
);
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format. /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format.
pub const GERMAN2ENGLISH: (&'static str, &'static str) = ("marian-mt-de-en/model.ot", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-de-en/rust_model.ot"); pub const GERMAN2ENGLISH: (&'static str, &'static str) = (
"marian-mt-de-en/model.ot",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-de-en/rust_model.ot",
);
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format. /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format.
pub const ENGLISH2RUSSIAN: (&'static str, &'static str) = ("marian-mt-en-ru/model.ot", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-ru/rust_model.ot"); pub const ENGLISH2RUSSIAN: (&'static str, &'static str) = (
"marian-mt-en-ru/model.ot",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-ru/rust_model.ot",
);
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format. /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format.
pub const RUSSIAN2ENGLISH: (&'static str, &'static str) = ("marian-mt-ru-en/model.ot", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-ru-en/rust_model.ot"); pub const RUSSIAN2ENGLISH: (&'static str, &'static str) = (
"marian-mt-ru-en/model.ot",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-ru-en/rust_model.ot",
);
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format. /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format.
pub const FRENCH2GERMAN: (&'static str, &'static str) = ("marian-mt-fr-de/model.ot", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-fr-de/rust_model.ot"); pub const FRENCH2GERMAN: (&'static str, &'static str) = (
"marian-mt-fr-de/model.ot",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-fr-de/rust_model.ot",
);
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format. /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format.
pub const GERMAN2FRENCH: (&'static str, &'static str) = ("marian-mt-de-fr/model.ot", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-de-fr/rust_model.ot"); pub const GERMAN2FRENCH: (&'static str, &'static str) = (
"marian-mt-de-fr/model.ot",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-de-fr/rust_model.ot",
);
} }
impl MarianConfigResources { impl MarianConfigResources {
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const ENGLISH2ROMANCE: (&'static str, &'static str) = ("marian-mt-en-ROMANCE/config.json", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-ROMANCE/config.json"); pub const ENGLISH2ROMANCE: (&'static str, &'static str) = (
"marian-mt-en-ROMANCE/config.json",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-ROMANCE/config.json",
);
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const ROMANCE2ENGLISH: (&'static str, &'static str) = ("marian-mt-ROMANCE-en/config.json", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-ROMANCE-en/config.json"); pub const ROMANCE2ENGLISH: (&'static str, &'static str) = (
"marian-mt-ROMANCE-en/config.json",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-ROMANCE-en/config.json",
);
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const ENGLISH2GERMAN: (&'static str, &'static str) = ("marian-mt-en-de/config.json", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-de/config.json"); pub const ENGLISH2GERMAN: (&'static str, &'static str) = (
"marian-mt-en-de/config.json",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-de/config.json",
);
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const GERMAN2ENGLISH: (&'static str, &'static str) = ("marian-mt-de-en/config.json", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-de-en/config.json"); pub const GERMAN2ENGLISH: (&'static str, &'static str) = (
"marian-mt-de-en/config.json",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-de-en/config.json",
);
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const ENGLISH2RUSSIAN: (&'static str, &'static str) = ("marian-mt-en-ru/config.json", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-ru/config.json"); pub const ENGLISH2RUSSIAN: (&'static str, &'static str) = (
"marian-mt-en-ru/config.json",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-ru/config.json",
);
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const RUSSIAN2ENGLISH: (&'static str, &'static str) = ("marian-mt-ru-en/config.json", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-ru-en/config.json"); pub const RUSSIAN2ENGLISH: (&'static str, &'static str) = (
"marian-mt-ru-en/config.json",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-ru-en/config.json",
);
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const FRENCH2GERMAN: (&'static str, &'static str) = ("marian-mt-fr-de/config.json", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-fr-de/config.json"); pub const FRENCH2GERMAN: (&'static str, &'static str) = (
"marian-mt-fr-de/config.json",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-fr-de/config.json",
);
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const GERMAN2FRENCH: (&'static str, &'static str) = ("marian-mt-de-fr/config.json", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-de-fr/config.json"); pub const GERMAN2FRENCH: (&'static str, &'static str) = (
"marian-mt-de-fr/config.json",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-de-fr/config.json",
);
} }
impl MarianVocabResources { impl MarianVocabResources {
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const ENGLISH2ROMANCE: (&'static str, &'static str) = ("marian-mt-en-ROMANCE/vocab.json", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-ROMANCE/vocab.json"); pub const ENGLISH2ROMANCE: (&'static str, &'static str) = (
"marian-mt-en-ROMANCE/vocab.json",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-ROMANCE/vocab.json",
);
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const ROMANCE2ENGLISH: (&'static str, &'static str) = ("marian-mt-ROMANCE-en/vocab.json", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-ROMANCE-en/vocab.json"); pub const ROMANCE2ENGLISH: (&'static str, &'static str) = (
"marian-mt-ROMANCE-en/vocab.json",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-ROMANCE-en/vocab.json",
);
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const ENGLISH2GERMAN: (&'static str, &'static str) = ("marian-mt-en-de/vocab.json", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-de/vocab.json"); pub const ENGLISH2GERMAN: (&'static str, &'static str) = (
"marian-mt-en-de/vocab.json",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-de/vocab.json",
);
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const GERMAN2ENGLISH: (&'static str, &'static str) = ("marian-mt-de-en/vocab.json", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-de-en/vocab.json"); pub const GERMAN2ENGLISH: (&'static str, &'static str) = (
"marian-mt-de-en/vocab.json",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-de-en/vocab.json",
);
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const ENGLISH2RUSSIAN: (&'static str, &'static str) = ("marian-mt-en-ru/vocab.json", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-ru/vocab.json"); pub const ENGLISH2RUSSIAN: (&'static str, &'static str) = (
"marian-mt-en-ru/vocab.json",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-ru/vocab.json",
);
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const RUSSIAN2ENGLISH: (&'static str, &'static str) = ("marian-mt-ru-en/vocab.json", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-ru-en/vocab.json"); pub const RUSSIAN2ENGLISH: (&'static str, &'static str) = (
"marian-mt-ru-en/vocab.json",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-ru-en/vocab.json",
);
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const FRENCH2GERMAN: (&'static str, &'static str) = ("marian-mt-fr-de/vocab.json", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-fr-de/vocab.json"); pub const FRENCH2GERMAN: (&'static str, &'static str) = (
"marian-mt-fr-de/vocab.json",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-fr-de/vocab.json",
);
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const GERMAN2FRENCH: (&'static str, &'static str) = ("marian-mt-de-fr/vocab.json", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-de-fr/vocab.json"); pub const GERMAN2FRENCH: (&'static str, &'static str) = (
"marian-mt-de-fr/vocab.json",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-de-fr/vocab.json",
);
} }
impl MarianSpmResources { impl MarianSpmResources {
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const ENGLISH2ROMANCE: (&'static str, &'static str) = ("marian-mt-en-ROMANCE/spiece.model", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-ROMANCE/source.spm"); pub const ENGLISH2ROMANCE: (&'static str, &'static str) = (
"marian-mt-en-ROMANCE/spiece.model",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-ROMANCE/source.spm",
);
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const ROMANCE2ENGLISH: (&'static str, &'static str) = ("marian-mt-ROMANCE-en/spiece.model", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-ROMANCE-en/source.spm"); pub const ROMANCE2ENGLISH: (&'static str, &'static str) = (
"marian-mt-ROMANCE-en/spiece.model",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-ROMANCE-en/source.spm",
);
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const ENGLISH2GERMAN: (&'static str, &'static str) = ("marian-mt-en-de/spiece.model", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-de/source.spm"); pub const ENGLISH2GERMAN: (&'static str, &'static str) = (
"marian-mt-en-de/spiece.model",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-de/source.spm",
);
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const GERMAN2ENGLISH: (&'static str, &'static str) = ("marian-mt-de-en/spiece.model", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-de-en/source.spm"); pub const GERMAN2ENGLISH: (&'static str, &'static str) = (
"marian-mt-de-en/spiece.model",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-de-en/source.spm",
);
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const ENGLISH2RUSSIAN: (&'static str, &'static str) = ("marian-mt-en-ru/spiece.model", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-ru/source.spm"); pub const ENGLISH2RUSSIAN: (&'static str, &'static str) = (
"marian-mt-en-ru/spiece.model",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-ru/source.spm",
);
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const RUSSIAN2ENGLISH: (&'static str, &'static str) = ("marian-mt-ru-en/spiece.model", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-ru-en/source.spm"); pub const RUSSIAN2ENGLISH: (&'static str, &'static str) = (
"marian-mt-ru-en/spiece.model",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-ru-en/source.spm",
);
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const FRENCH2GERMAN: (&'static str, &'static str) = ("marian-mt-fr-de/spiece.model", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-fr-de/source.spm"); pub const FRENCH2GERMAN: (&'static str, &'static str) = (
"marian-mt-fr-de/spiece.model",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-fr-de/source.spm",
);
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. /// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
pub const GERMAN2FRENCH: (&'static str, &'static str) = ("marian-mt-de-fr/spiece.model", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-de-fr/source.spm"); pub const GERMAN2FRENCH: (&'static str, &'static str) = (
"marian-mt-de-fr/spiece.model",
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-de-fr/source.spm",
);
} }
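Each constant above is a `(cache_sub_path, url)` pair; it is materialised on disk through the same `Resource`/`download_resource` machinery used in the module examples. A sketch for the English-German pair (network access assumed; files are fetched on first use and cached afterwards):

```no_run
use rust_bert::marian::{
    MarianConfigResources, MarianModelResources, MarianSpmResources, MarianVocabResources,
};
use rust_bert::resources::{download_resource, RemoteResource, Resource};

fn main() -> failure::Fallible<()> {
    let model = Resource::Remote(RemoteResource::from_pretrained(
        MarianModelResources::ENGLISH2GERMAN,
    ));
    let config = Resource::Remote(RemoteResource::from_pretrained(
        MarianConfigResources::ENGLISH2GERMAN,
    ));
    let vocab = Resource::Remote(RemoteResource::from_pretrained(
        MarianVocabResources::ENGLISH2GERMAN,
    ));
    let spiece = Resource::Remote(RemoteResource::from_pretrained(
        MarianSpmResources::ENGLISH2GERMAN,
    ));
    // Each call resolves to a local path, downloading the file if not cached yet.
    let _paths = (
        download_resource(&model)?,
        download_resource(&config)?,
        download_resource(&vocab)?,
        download_resource(&spiece)?,
    );
    Ok(())
}
```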
impl MarianPrefix { impl MarianPrefix {
@@ -150,23 +246,34 @@ impl MarianForConditionalGeneration {
/// # Example /// # Example
/// ///
/// ```no_run /// ```no_run
/// use tch::{nn, Device}; /// use rust_bert::bart::{BartConfig, BartForConditionalGeneration};
/// use rust_bert::Config; /// use rust_bert::Config;
/// use std::path::Path; /// use std::path::Path;
/// use rust_bert::bart::{BartConfig, BartForConditionalGeneration}; /// use tch::{nn, Device};
/// ///
/// let config_path = Path::new("path/to/config.json"); /// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu; /// let device = Device::Cpu;
/// let p = nn::VarStore::new(device); /// let p = nn::VarStore::new(device);
/// let config = BartConfig::from_file(config_path); /// let config = BartConfig::from_file(config_path);
/// let generation_mode = true; /// let generation_mode = true;
/// let bart: BartForConditionalGeneration = BartForConditionalGeneration::new(&(&p.root() / "bart"), &config, generation_mode); /// let bart: BartForConditionalGeneration =
/// BartForConditionalGeneration::new(&(&p.root() / "bart"), &config, generation_mode);
/// ``` /// ```
/// pub fn new(
pub fn new(p: &nn::Path, config: &BartConfig, generation_mode: bool) -> MarianForConditionalGeneration { p: &nn::Path,
config: &BartConfig,
generation_mode: bool,
) -> MarianForConditionalGeneration {
let base_model = BartModel::new(&(p / "model"), config, generation_mode); let base_model = BartModel::new(&(p / "model"), config, generation_mode);
let final_logits_bias = p.var("final_logits_bias", &[1, config.vocab_size], Init::Const(0.)); let final_logits_bias = p.var(
MarianForConditionalGeneration { base_model, final_logits_bias } "final_logits_bias",
&[1, config.vocab_size],
Init::Const(0.),
);
MarianForConditionalGeneration {
base_model,
final_logits_bias,
}
} }
/// Forward pass through the model /// Forward pass through the model
@@ -193,64 +300,101 @@ impl MarianForConditionalGeneration {
/// # Example /// # Example
/// ///
/// ```no_run /// ```no_run
///# use tch::{nn, Device, Tensor, no_grad}; /// # use tch::{nn, Device, Tensor, no_grad};
///# use rust_bert::Config; /// # use rust_bert::Config;
///# use std::path::Path; /// # use std::path::Path;
///# use tch::kind::Kind::{Int64, Double}; /// # use tch::kind::Kind::{Int64, Double};
/// use rust_bert::bart::{BartConfig}; /// use rust_bert::bart::BartConfig;
/// use rust_bert::marian::MarianForConditionalGeneration; /// use rust_bert::marian::MarianForConditionalGeneration;
///# let config_path = Path::new("path/to/config.json"); /// # let config_path = Path::new("path/to/config.json");
///# let vocab_path = Path::new("path/to/vocab.txt"); /// # let vocab_path = Path::new("path/to/vocab.txt");
///# let device = Device::Cpu; /// # let device = Device::Cpu;
///# let vs = nn::VarStore::new(device); /// # let vs = nn::VarStore::new(device);
///# let config = BartConfig::from_file(config_path); /// # let config = BartConfig::from_file(config_path);
///# let mut marian_model = MarianForConditionalGeneration::new(&vs.root(), &config, false); /// # let mut marian_model = MarianForConditionalGeneration::new(&vs.root(), &config, false);
/// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56); /// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56);
/// let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device)); /// let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device));
/// let target_tensor = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device)); /// let target_tensor = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device));
/// let encoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device)); /// let encoder_attention_mask =
/// let decoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device)); /// Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
/// /// let decoder_attention_mask =
/// let (decoder_output, encoder_hidden_states, cache, /// Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
/// all_encoder_hidden_states, all_encoder_attentions,
/// all_decoder_hidden_states, all_decoder_attentions) = no_grad(|| {
/// marian_model
/// .forward_t(Some(&input_tensor),
/// Some(&encoder_attention_mask),
/// None,
/// Some(&target_tensor),
/// Some(&decoder_attention_mask),
/// None,
/// false)
/// });
/// ///
/// let (
/// decoder_output,
/// encoder_hidden_states,
/// cache,
/// all_encoder_hidden_states,
/// all_encoder_attentions,
/// all_decoder_hidden_states,
/// all_decoder_attentions,
/// ) = no_grad(|| {
/// marian_model.forward_t(
/// Some(&input_tensor),
/// Some(&encoder_attention_mask),
/// None,
/// Some(&target_tensor),
/// Some(&decoder_attention_mask),
/// None,
/// false,
/// )
/// });
/// ``` /// ```
/// pub fn forward_t(
pub fn forward_t(&self, &self,
input_ids: Option<&Tensor>, input_ids: Option<&Tensor>,
attention_mask: Option<&Tensor>, attention_mask: Option<&Tensor>,
encoder_outputs: Option<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>)>, encoder_outputs: Option<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>)>,
decoder_input_ids: Option<&Tensor>, decoder_input_ids: Option<&Tensor>,
decoder_attention_mask: Option<&Tensor>, decoder_attention_mask: Option<&Tensor>,
old_layer_states: Option<Vec<(Option<LayerState>, Option<LayerState>)>>, old_layer_states: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
train: bool) train: bool,
-> (Tensor, Tensor, Option<Vec<(Option<LayerState>, Option<LayerState>)>>, ) -> (
Option<Vec<Tensor>>, Option<Vec<Tensor>>, Tensor,
Option<Vec<Tensor>>, Option<Vec<Tensor>>) Tensor,
{ Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
let (decoder_outputs, encoder_hidden_states, decoder_cache, Option<Vec<Tensor>>,
all_decoder_hidden_states, all_decoder_attentions, Option<Vec<Tensor>>,
all_encoder_hidden_states, all_encoder_attentions) = Option<Vec<Tensor>>,
self.base_model.forward_t(input_ids, attention_mask, decoder_input_ids, encoder_outputs, decoder_attention_mask, old_layer_states, train); Option<Vec<Tensor>>,
) {
let (
decoder_outputs,
encoder_hidden_states,
decoder_cache,
all_decoder_hidden_states,
all_decoder_attentions,
all_encoder_hidden_states,
all_encoder_attentions,
) = self.base_model.forward_t(
input_ids,
attention_mask,
decoder_input_ids,
encoder_outputs,
decoder_attention_mask,
old_layer_states,
train,
);
let lm_logits = decoder_outputs.linear::<Tensor>(&self.base_model.embeddings.ws, None); let lm_logits = decoder_outputs.linear::<Tensor>(&self.base_model.embeddings.ws, None);
(lm_logits, encoder_hidden_states, decoder_cache, (
all_decoder_hidden_states, all_decoder_attentions, lm_logits,
all_encoder_hidden_states, all_encoder_attentions) encoder_hidden_states,
decoder_cache,
all_decoder_hidden_states,
all_decoder_attentions,
all_encoder_hidden_states,
all_encoder_attentions,
)
} }
pub fn encode(&self, input_ids: &Tensor, attention_mask: Option<&Tensor>) -> Tensor { pub fn encode(&self, input_ids: &Tensor, attention_mask: Option<&Tensor>) -> Tensor {
let (encoder_hidden_states, _, _) = self.base_model.encoder.forward_t(input_ids, attention_mask, &self.base_model.embeddings, false); let (encoder_hidden_states, _, _) = self.base_model.encoder.forward_t(
input_ids,
attention_mask,
&self.base_model.embeddings,
false,
);
encoder_hidden_states encoder_hidden_states
} }
} }
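Unlike `forward_t`, the `encode` method above runs only the encoder stack and returns its hidden states, which callers can pass back as `encoder_outputs` during incremental decoding. A sketch with construction reduced to the placeholder pattern of the doc examples:

```no_run
use rust_bert::bart::BartConfig;
use rust_bert::marian::MarianForConditionalGeneration;
use rust_bert::Config;
use std::path::Path;
use tch::kind::Kind::Int64;
use tch::{nn, Device, Tensor};

fn main() {
    let config = BartConfig::from_file(Path::new("path/to/config.json"));
    let vs = nn::VarStore::new(Device::Cpu);
    let marian = MarianForConditionalGeneration::new(&vs.root(), &config, false);

    let input_ids = Tensor::ones(&[1, 12], (Int64, Device::Cpu));
    // Encoder-only pass; no attention mask supplied in this sketch.
    let encoder_hidden_states = marian.encode(&input_ids, None);
    println!("{:?}", encoder_hidden_states.size());
}
```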
@@ -283,68 +427,97 @@ impl LMHeadModel for MarianForConditionalGeneration {
/// # Example /// # Example
/// ///
/// ```no_run /// ```no_run
///# use tch::{nn, Device, Tensor, no_grad}; /// # use tch::{nn, Device, Tensor, no_grad};
///# use rust_bert::Config; /// # use rust_bert::Config;
///# use std::path::Path; /// # use std::path::Path;
///# use tch::kind::Kind::{Int64, Double}; /// # use tch::kind::Kind::{Int64, Double};
/// use rust_bert::bart::{BartConfig}; /// use rust_bert::bart::BartConfig;
/// use rust_bert::marian::MarianForConditionalGeneration; /// use rust_bert::marian::MarianForConditionalGeneration;
///# let config_path = Path::new("path/to/config.json"); /// # let config_path = Path::new("path/to/config.json");
///# let vocab_path = Path::new("path/to/vocab.txt"); /// # let vocab_path = Path::new("path/to/vocab.txt");
///# let device = Device::Cpu; /// # let device = Device::Cpu;
///# let vs = nn::VarStore::new(device); /// # let vs = nn::VarStore::new(device);
///# let config = BartConfig::from_file(config_path); /// # let config = BartConfig::from_file(config_path);
///# let marian_model = MarianForConditionalGeneration::new(&vs.root(), &config, false); /// # let marian_model = MarianForConditionalGeneration::new(&vs.root(), &config, false);
/// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56); /// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56);
/// let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device)); /// let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device));
/// let target_tensor = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device)); /// let target_tensor = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device));
/// let encoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device)); /// let encoder_attention_mask =
/// let decoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device)); /// Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
/// /// let decoder_attention_mask =
/// let (decoder_output, encoder_hidden_states, cache, /// Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
/// all_encoder_hidden_states, all_encoder_attentions,
/// all_decoder_hidden_states, all_decoder_attentions) = no_grad(|| {
/// marian_model
/// .forward_t(Some(&input_tensor),
/// Some(&encoder_attention_mask),
/// None,
/// Some(&target_tensor),
/// Some(&decoder_attention_mask),
/// None,
/// false)
/// });
/// ///
/// let (
/// decoder_output,
/// encoder_hidden_states,
/// cache,
/// all_encoder_hidden_states,
/// all_encoder_attentions,
/// all_decoder_hidden_states,
/// all_decoder_attentions,
/// ) = no_grad(|| {
/// marian_model.forward_t(
/// Some(&input_tensor),
/// Some(&encoder_attention_mask),
/// None,
/// Some(&target_tensor),
/// Some(&decoder_attention_mask),
/// None,
/// false,
/// )
/// });
/// ``` /// ```
/// fn forward_t(
fn forward_t(&self, &self,
input_ids: &Option<Tensor>, input_ids: &Option<Tensor>,
cache: Cache, cache: Cache,
attention_mask: &Option<Tensor>, attention_mask: &Option<Tensor>,
_token_type_ids: &Option<Tensor>, _token_type_ids: &Option<Tensor>,
_position_ids: &Option<Tensor>, _position_ids: &Option<Tensor>,
_input_embeds: &Option<Tensor>, _input_embeds: &Option<Tensor>,
encoder_outputs: Option<&Tensor>, encoder_outputs: Option<&Tensor>,
decoder_input_ids: &Option<Tensor>, decoder_input_ids: &Option<Tensor>,
train: bool) -> Result<(Tensor, Option<Tensor>, Cache, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> { train: bool,
) -> Result<
(
Tensor,
Option<Tensor>,
Cache,
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
),
&'static str,
> {
let (decoder_output, encoder_hidden_states, new_cache, _, _, _, _) = match cache { let (decoder_output, encoder_hidden_states, new_cache, _, _, _, _) = match cache {
Cache::BARTCache(cached_layer_states) => self.base_model.forward_t(input_ids.as_ref(), Cache::BARTCache(cached_layer_states) => self.base_model.forward_t(
attention_mask.as_ref(), input_ids.as_ref(),
decoder_input_ids.as_ref(), attention_mask.as_ref(),
Some((encoder_outputs.as_ref().unwrap().copy(), None, None)), decoder_input_ids.as_ref(),
None, Some((encoder_outputs.as_ref().unwrap().copy(), None, None)),
cached_layer_states, None,
train), cached_layer_states,
Cache::None => self.base_model.forward_t(input_ids.as_ref(), train,
attention_mask.as_ref(), ),
decoder_input_ids.as_ref(), Cache::None => self.base_model.forward_t(
Some((encoder_outputs.as_ref().unwrap().copy(), None, None)), input_ids.as_ref(),
None, attention_mask.as_ref(),
None, decoder_input_ids.as_ref(),
train), Some((encoder_outputs.as_ref().unwrap().copy(), None, None)),
_ => Err("Cache not compatible with Marian Model")? None,
None,
train,
),
_ => Err("Cache not compatible with Marian Model")?,
}; };
let lm_logits = decoder_output.linear::<Tensor>(&self.base_model.embeddings.ws, None) + &self.final_logits_bias; let lm_logits = decoder_output.linear::<Tensor>(&self.base_model.embeddings.ws, None)
Ok((lm_logits, Some(encoder_hidden_states), Cache::BARTCache(new_cache), None, None)) + &self.final_logits_bias;
Ok((
lm_logits,
Some(encoder_hidden_states),
Cache::BARTCache(new_cache),
None,
None,
))
} }
} }
@@ -15,20 +15,28 @@
//! Pretrained models for a number of language pairs are available and can be downloaded using RemoteResources. These are shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. //! Pretrained models for a number of language pairs are available and can be downloaded using RemoteResources. These are shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
//! //!
//! ```no_run //! ```no_run
//!# fn main() -> failure::Fallible<()> { //! # fn main() -> failure::Fallible<()> {
//!# //! #
//! use tch::{nn, Device}; //! use tch::{nn, Device};
//!# use std::path::PathBuf; //! # use std::path::PathBuf;
//! use rust_bert::Config;
//! use rust_bert::bart::{BartConfig, BartModel}; //! use rust_bert::bart::{BartConfig, BartModel};
//! use rust_bert::resources::{Resource, download_resource, LocalResource};
//! use rust_tokenizers::preprocessing::tokenizer::marian_tokenizer::MarianTokenizer;
//! use rust_bert::marian::MarianForConditionalGeneration; //! use rust_bert::marian::MarianForConditionalGeneration;
//! use rust_bert::resources::{download_resource, LocalResource, Resource};
//! use rust_bert::Config;
//! use rust_tokenizers::preprocessing::tokenizer::marian_tokenizer::MarianTokenizer;
//! //!
//! let config_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/config.json")}); //! let config_resource = Resource::Local(LocalResource {
//! let vocab_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/vocab.json")}); //! local_path: PathBuf::from("path/to/config.json"),
//! let sentence_piece_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/spiece.model")}); //! });
//! let weights_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/model.ot")}); //! let vocab_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/vocab.json"),
//! });
//! let sentence_piece_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/spiece.model"),
//! });
//! let weights_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/model.ot"),
//! });
//! let config_path = download_resource(&config_resource)?; //! let config_path = download_resource(&config_resource)?;
//! let vocab_path = download_resource(&vocab_resource)?; //! let vocab_path = download_resource(&vocab_resource)?;
//! let spiece_path = download_resource(&sentence_piece_resource)?; //! let spiece_path = download_resource(&sentence_piece_resource)?;
@@ -36,15 +44,22 @@
//! //!
//! let device = Device::cuda_if_available(); //! let device = Device::cuda_if_available();
//! let mut vs = nn::VarStore::new(device); //! let mut vs = nn::VarStore::new(device);
//! let tokenizer = MarianTokenizer::from_files(vocab_path.to_str().unwrap(), spiece_path.to_str().unwrap(), true); //! let tokenizer = MarianTokenizer::from_files(
//! vocab_path.to_str().unwrap(),
//! spiece_path.to_str().unwrap(),
//! true,
//! );
//! let config = BartConfig::from_file(config_path); //! let config = BartConfig::from_file(config_path);
//! let marian_model = MarianForConditionalGeneration::new(&vs.root(), &config, false); //! let marian_model = MarianForConditionalGeneration::new(&vs.root(), &config, false);
//! vs.load(weights_path)?; //! vs.load(weights_path)?;
//! //!
//!# Ok(()) //! # Ok(())
//!# } //! # }
//! ``` //! ```
mod marian; mod marian;
pub use marian::{MarianForConditionalGeneration, MarianModelResources, MarianConfigResources, MarianVocabResources, MarianSpmResources, MarianPrefix}; pub use marian::{
MarianConfigResources, MarianForConditionalGeneration, MarianModelResources, MarianPrefix,
MarianSpmResources, MarianVocabResources,
};
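For end-to-end translation these pieces are normally driven through the crate's translation pipeline rather than wired up by hand. A sketch under the assumption that this version exposes `rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel}` with a `translate` method; the exact names should be checked against the pipelines module:

```no_run
use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
use tch::Device;

fn main() -> failure::Fallible<()> {
    // Assumed API: selects the ENGLISH2GERMAN Marian resources internally.
    let translation_config =
        TranslationConfig::new(Language::EnglishToGerman, Device::cuda_if_available());
    let model = TranslationModel::new(translation_config)?;

    let output = model.translate(&["The quick brown fox jumps over the lazy dog."]);
    println!("{:?}", output);
    Ok(())
}
```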
@@ -14,19 +14,27 @@
//! Pretrained models are available and can be downloaded using RemoteResources. //! Pretrained models are available and can be downloaded using RemoteResources.
//! //!
//! ```no_run //! ```no_run
//!# fn main() -> failure::Fallible<()> { //! # fn main() -> failure::Fallible<()> {
//! use rust_tokenizers::OpenAiGptTokenizer; //! use rust_tokenizers::OpenAiGptTokenizer;
//! use tch::{nn, Device}; //! use tch::{nn, Device};
//!# use std::path::PathBuf; //! # use std::path::PathBuf;
//! use rust_bert::Config;
//! use rust_bert::gpt2::Gpt2Config; //! use rust_bert::gpt2::Gpt2Config;
//! use rust_bert::openai_gpt::OpenAiGptModel; //! use rust_bert::openai_gpt::OpenAiGptModel;
//! use rust_bert::resources::{Resource, download_resource, LocalResource}; //! use rust_bert::resources::{download_resource, LocalResource, Resource};
//! use rust_bert::Config;
//! //!
//! let config_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/config.json")}); //! let config_resource = Resource::Local(LocalResource {
//! let vocab_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/vocab.txt")}); //! local_path: PathBuf::from("path/to/config.json"),
//! let merges_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/vocab.txt")}); //! });
//! let weights_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/model.ot")}); //! let vocab_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/vocab.txt"),
//! });
//! let merges_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/vocab.txt"),
//! });
//! let weights_resource = Resource::Local(LocalResource {
//! local_path: PathBuf::from("path/to/model.ot"),
//! });
//! let config_path = download_resource(&config_resource)?; //! let config_path = download_resource(&config_resource)?;
//! let vocab_path = download_resource(&vocab_resource)?; //! let vocab_path = download_resource(&vocab_resource)?;
//! let merges_path = download_resource(&merges_resource)?; //! let merges_path = download_resource(&merges_resource)?;
@@ -34,17 +42,23 @@
//! //!
//! let device = Device::cuda_if_available(); //! let device = Device::cuda_if_available();
//! let mut vs = nn::VarStore::new(device); //! let mut vs = nn::VarStore::new(device);
//! let tokenizer: OpenAiGptTokenizer = OpenAiGptTokenizer::from_file(vocab_path.to_str().unwrap(), merges_path.to_str().unwrap(), true); //! let tokenizer: OpenAiGptTokenizer = OpenAiGptTokenizer::from_file(
//! vocab_path.to_str().unwrap(),
//! merges_path.to_str().unwrap(),
//! true,
//! );
//! let config = Gpt2Config::from_file(config_path); //! let config = Gpt2Config::from_file(config_path);
//! let gpt_model = OpenAiGptModel::new(&vs.root(), &config); //! let gpt_model = OpenAiGptModel::new(&vs.root(), &config);
//! vs.load(weights_path)?; //! vs.load(weights_path)?;
//! //!
//!# Ok(()) //! # Ok(())
//!# } //! # }
//! ``` //! ```
mod openai_gpt; mod openai_gpt;
mod transformer; mod transformer;
pub use openai_gpt::{OpenAiGptModelResources, OpenAiGptConfigResources, OpenAiGptVocabResources, OpenAiGptMergesResources, pub use openai_gpt::{
OpenAiGptModel, OpenAIGPTLMHeadModel}; OpenAIGPTLMHeadModel, OpenAiGptConfigResources, OpenAiGptMergesResources, OpenAiGptModel,
OpenAiGptModelResources, OpenAiGptVocabResources,
};
@@ -12,15 +12,15 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use tch::{nn, Tensor};
use crate::common::dropout::Dropout; use crate::common::dropout::Dropout;
use tch::nn::embedding; use crate::common::linear::{linear_no_bias, LinearNoBias};
use tch::kind::Kind::Int64;
use std::borrow::BorrowMut;
use crate::common::linear::{LinearNoBias, linear_no_bias};
use crate::openai_gpt::transformer::Block;
use crate::gpt2::Gpt2Config; use crate::gpt2::Gpt2Config;
use crate::pipelines::generation::{LMHeadModel, Cache}; use crate::openai_gpt::transformer::Block;
use crate::pipelines::generation::{Cache, LMHeadModel};
use std::borrow::BorrowMut;
use tch::kind::Kind::Int64;
use tch::nn::embedding;
use tch::{nn, Tensor};
/// # GPT Pretrained model weight files /// # GPT Pretrained model weight files
pub struct OpenAiGptModelResources; pub struct OpenAiGptModelResources;
@@ -36,22 +36,34 @@ pub struct OpenAiGptMergesResources;
impl OpenAiGptModelResources { impl OpenAiGptModelResources {
/// Shared under MIT license by the OpenAI team at https://github.com/openai/finetune-transformer-lm. Modified with conversion to C-array format. /// Shared under MIT license by the OpenAI team at https://github.com/openai/finetune-transformer-lm. Modified with conversion to C-array format.
pub const GPT: (&'static str, &'static str) = ("openai-gpt/model.ot", "https://cdn.huggingface.co/openai-gpt-rust_model.ot"); pub const GPT: (&'static str, &'static str) = (
"openai-gpt/model.ot",
"https://cdn.huggingface.co/openai-gpt-rust_model.ot",
);
} }
impl OpenAiGptConfigResources { impl OpenAiGptConfigResources {
/// Shared under MIT license by the OpenAI team at https://github.com/openai/finetune-transformer-lm. Modified with conversion to C-array format. /// Shared under MIT license by the OpenAI team at https://github.com/openai/finetune-transformer-lm. Modified with conversion to C-array format.
pub const GPT: (&'static str, &'static str) = ("openai-gpt/config.json", "https://cdn.huggingface.co/openai-gpt-config.json"); pub const GPT: (&'static str, &'static str) = (
"openai-gpt/config.json",
"https://cdn.huggingface.co/openai-gpt-config.json",
);
} }
impl OpenAiGptVocabResources { impl OpenAiGptVocabResources {
/// Shared under MIT license by the OpenAI team at https://github.com/openai/finetune-transformer-lm. Modified with conversion to C-array format. /// Shared under MIT license by the OpenAI team at https://github.com/openai/finetune-transformer-lm. Modified with conversion to C-array format.
pub const GPT: (&'static str, &'static str) = ("openai-gpt/vocab.txt", "https://cdn.huggingface.co/openai-gpt-vocab.json"); pub const GPT: (&'static str, &'static str) = (
"openai-gpt/vocab.txt",
"https://cdn.huggingface.co/openai-gpt-vocab.json",
);
} }
impl OpenAiGptMergesResources { impl OpenAiGptMergesResources {
/// Shared under MIT license by the OpenAI team at https://github.com/openai/finetune-transformer-lm. Modified with conversion to C-array format. /// Shared under MIT license by the OpenAI team at https://github.com/openai/finetune-transformer-lm. Modified with conversion to C-array format.
pub const GPT: (&'static str, &'static str) = ("openai-gpt/merges.txt", "https://cdn.huggingface.co/openai-gpt-merges.txt"); pub const GPT: (&'static str, &'static str) = (
"openai-gpt/merges.txt",
"https://cdn.huggingface.co/openai-gpt-merges.txt",
);
} }
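The four `GPT` constants above plug directly into the remote-resource flow, giving a fully remote variant of the local-path doc example earlier in this file:

```no_run
use rust_bert::gpt2::Gpt2Config;
use rust_bert::openai_gpt::{OpenAiGptConfigResources, OpenAiGptModel, OpenAiGptModelResources};
use rust_bert::resources::{download_resource, RemoteResource, Resource};
use rust_bert::Config;
use tch::{nn, Device};

fn main() -> failure::Fallible<()> {
    let config_resource = Resource::Remote(RemoteResource::from_pretrained(
        OpenAiGptConfigResources::GPT,
    ));
    let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
        OpenAiGptModelResources::GPT,
    ));
    let config_path = download_resource(&config_resource)?;
    let weights_path = download_resource(&weights_resource)?;

    let device = Device::cuda_if_available();
    let mut vs = nn::VarStore::new(device);
    // GPT re-uses the GPT-2 configuration type, as in the doc example above.
    let config = Gpt2Config::from_file(config_path);
    let _gpt_model = OpenAiGptModel::new(&vs.root(), &config);
    vs.load(weights_path)?;
    Ok(())
}
```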
/// # GPT Base model /// # GPT Base model
@@ -82,11 +94,11 @@ impl OpenAiGptModel {
/// # Example /// # Example
/// ///
/// ```no_run /// ```no_run
/// use tch::{nn, Device};
/// use rust_bert::Config;
/// use std::path::Path;
/// use rust_bert::gpt2::Gpt2Config; /// use rust_bert::gpt2::Gpt2Config;
/// use rust_bert::openai_gpt::OpenAiGptModel; /// use rust_bert::openai_gpt::OpenAiGptModel;
/// use rust_bert::Config;
/// use std::path::Path;
/// use tch::{nn, Device};
/// ///
/// let config_path = Path::new("path/to/config.json"); /// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu; /// let device = Device::Cpu;
@@ -94,30 +106,46 @@ impl OpenAiGptModel {
/// let config = Gpt2Config::from_file(config_path); /// let config = Gpt2Config::from_file(config_path);
/// let gpt2: OpenAiGptModel = OpenAiGptModel::new(&(&p.root() / "gpt"), &config); /// let gpt2: OpenAiGptModel = OpenAiGptModel::new(&(&p.root() / "gpt"), &config);
/// ``` /// ```
///
pub fn new(p: &nn::Path, config: &Gpt2Config) -> OpenAiGptModel { pub fn new(p: &nn::Path, config: &Gpt2Config) -> OpenAiGptModel {
let tokens_embed = embedding(&(p / "tokens_embed"), config.vocab_size, config.n_embd, Default::default()); let tokens_embed = embedding(
let positions_embed = embedding(&(p / "positions_embed"), config.n_positions, config.n_embd, Default::default()); &(p / "tokens_embed"),
config.vocab_size,
config.n_embd,
Default::default(),
);
let positions_embed = embedding(
&(p / "positions_embed"),
config.n_positions,
config.n_embd,
Default::default(),
);
let embd_pdrop = match config.embd_pdrop { let embd_pdrop = match config.embd_pdrop {
Some(value) => value, Some(value) => value,
None => 0.1 None => 0.1,
}; };
let drop = Dropout::new(embd_pdrop); let drop = Dropout::new(embd_pdrop);
let mut h: Vec<Block> = vec!(); let mut h: Vec<Block> = vec![];
let h_path = &(p / "h"); let h_path = &(p / "h");
for layer_index in 0..config.n_layer { for layer_index in 0..config.n_layer {
h.push(Block::new(&(h_path / layer_index), config, true)); h.push(Block::new(&(h_path / layer_index), config, true));
}; }
let output_attentions = match config.output_attentions { let output_attentions = match config.output_attentions {
Some(value) => value, Some(value) => value,
None => false None => false,
}; };
let output_hidden_states = match config.output_hidden_states { let output_hidden_states = match config.output_hidden_states {
Some(value) => value, Some(value) => value,
None => false None => false,
}; };
OpenAiGptModel { tokens_embed, positions_embed, drop, h, output_hidden_states, output_attentions } OpenAiGptModel {
tokens_embed,
positions_embed,
drop,
h,
output_hidden_states,
output_attentions,
}
} }
/// Forward pass through the model /// Forward pass through the model
@@ -140,80 +168,99 @@ impl OpenAiGptModel {
/// # Example /// # Example
/// ///
/// ```no_run /// ```no_run
///# use tch::{nn, Device, Tensor, no_grad}; /// # use tch::{nn, Device, Tensor, no_grad};
///# use rust_bert::Config; /// # use rust_bert::Config;
///# use std::path::Path; /// # use std::path::Path;
///# use tch::kind::Kind::{Int64, Double}; /// # use tch::kind::Kind::{Int64, Double};
/// use rust_bert::gpt2::Gpt2Config; /// use rust_bert::gpt2::Gpt2Config;
/// use rust_bert::openai_gpt::OpenAiGptModel; /// use rust_bert::openai_gpt::OpenAiGptModel;
///# let config_path = Path::new("path/to/config.json"); /// # let config_path = Path::new("path/to/config.json");
///# let vocab_path = Path::new("path/to/vocab.txt"); /// # let vocab_path = Path::new("path/to/vocab.txt");
///# let device = Device::Cpu; /// # let device = Device::Cpu;
///# let vs = nn::VarStore::new(device); /// # let vs = nn::VarStore::new(device);
///# let config = Gpt2Config::from_file(config_path); /// # let config = Gpt2Config::from_file(config_path);
///# let gpt_model: OpenAiGptModel = OpenAiGptModel::new(&vs.root(), &config); /// # let gpt_model: OpenAiGptModel = OpenAiGptModel::new(&vs.root(), &config);
/// let (batch_size, sequence_length, past_sequence_length) = (64, 128, 56); /// let (batch_size, sequence_length, past_sequence_length) = (64, 128, 56);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device)); /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let attention_mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device)); /// let attention_mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let token_type_ids = Tensor::ones(&[batch_size, sequence_length], (Int64, device)); /// let token_type_ids = Tensor::ones(&[batch_size, sequence_length], (Int64, device));
/// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true); /// let position_ids = Tensor::arange(sequence_length, (Int64, device))
/// /// .expand(&[batch_size, sequence_length], true);
/// let (output, hidden_states, attentions) = no_grad(|| {
/// gpt_model
/// .forward_t(&Some(input_tensor),
/// &Some(attention_mask),
/// &Some(token_type_ids),
/// &Some(position_ids),
/// &None,
/// false).unwrap()
/// });
/// ///
/// let (output, hidden_states, attentions) = no_grad(|| {
/// gpt_model
/// .forward_t(
/// &Some(input_tensor),
/// &Some(attention_mask),
/// &Some(token_type_ids),
/// &Some(position_ids),
/// &None,
/// false,
/// )
/// .unwrap()
/// });
/// ``` /// ```
/// pub fn forward_t(
pub fn forward_t(&self, &self,
input_ids: &Option<Tensor>, input_ids: &Option<Tensor>,
attention_mask: &Option<Tensor>, attention_mask: &Option<Tensor>,
token_type_ids: &Option<Tensor>, token_type_ids: &Option<Tensor>,
position_ids: &Option<Tensor>, position_ids: &Option<Tensor>,
input_embeds: &Option<Tensor>, input_embeds: &Option<Tensor>,
train: bool) -> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> { train: bool,
) -> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
let (input_embeddings, seq_length) = match input_ids { let (input_embeddings, seq_length) = match input_ids {
Some(input_value) => match input_embeds { Some(input_value) => match input_embeds {
Some(_) => { return Err("Only one of input ids or input embeddings may be set"); } Some(_) => {
None => (input_value.apply(&self.tokens_embed), *input_value.size().last().unwrap()) return Err("Only one of input ids or input embeddings may be set");
} }
None => (
input_value.apply(&self.tokens_embed),
*input_value.size().last().unwrap(),
),
},
None => match input_embeds { None => match input_embeds {
Some(embeds) => (embeds.copy(), embeds.size()[1]), Some(embeds) => (embeds.copy(), embeds.size()[1]),
None => { return Err("At least one of input ids or input embeddings must be set"); } None => {
} return Err("At least one of input ids or input embeddings must be set");
}
},
}; };
let position_ids = match position_ids { let position_ids = match position_ids {
Some(value) => value.copy(), Some(value) => value.copy(),
None => Tensor::arange(seq_length, (Int64, input_embeddings.device())).unsqueeze(0) None => Tensor::arange(seq_length, (Int64, input_embeddings.device())).unsqueeze(0),
}; };
let attention_mask: Option<Tensor> = match attention_mask { let attention_mask: Option<Tensor> = match attention_mask {
Some(value) => { Some(value) => Some(
Some( (value
(value .view((input_embeddings.size()[0], -1))
.view((input_embeddings.size()[0], -1)) .unsqueeze(1)
.unsqueeze(1) .unsqueeze(2)
.unsqueeze(2) - 1.0)
- 1.0 * 10000.0,
) * 10000.0) ),
} None => None,
None => None
}; };
let position_embeds = position_ids.apply(&self.positions_embed); let position_embeds = position_ids.apply(&self.positions_embed);
let token_type_embeds = match token_type_ids { let token_type_embeds = match token_type_ids {
Some(value) => value.apply(&self.tokens_embed), Some(value) => value.apply(&self.tokens_embed),
None => Tensor::zeros_like(&position_embeds) None => Tensor::zeros_like(&position_embeds),
};
let mut hidden_state: Tensor =
(input_embeddings + position_embeds + token_type_embeds).apply_t(&self.drop, train);
let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states {
Some(vec![])
} else {
None
};
let mut all_attentions: Option<Vec<Tensor>> = if self.output_attentions {
Some(vec![])
} else {
None
}; };
let mut hidden_state: Tensor = (input_embeddings + position_embeds + token_type_embeds).apply_t(&self.drop, train);
let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states { Some(vec!()) } else { None };
let mut all_attentions: Option<Vec<Tensor>> = if self.output_attentions { Some(vec!()) } else { None };
let mut layers = self.h.iter(); let mut layers = self.h.iter();
loop { loop {
@@ -229,9 +276,9 @@ impl OpenAiGptModel {
attentions.push(temp.1.as_ref().unwrap().copy()); attentions.push(temp.1.as_ref().unwrap().copy());
}; };
} }
None => break None => break,
}; };
}; }
Ok((hidden_state, all_hidden_states, all_attentions)) Ok((hidden_state, all_hidden_states, all_attentions))
} }
@@ -258,11 +305,11 @@ impl OpenAIGPTLMHeadModel {
/// # Example /// # Example
/// ///
/// ```no_run /// ```no_run
/// use tch::{nn, Device};
/// use rust_bert::Config;
/// use std::path::Path;
/// use rust_bert::gpt2::Gpt2Config; /// use rust_bert::gpt2::Gpt2Config;
/// use rust_bert::openai_gpt::OpenAIGPTLMHeadModel; /// use rust_bert::openai_gpt::OpenAIGPTLMHeadModel;
/// use rust_bert::Config;
/// use std::path::Path;
/// use tch::{nn, Device};
/// ///
/// let config_path = Path::new("path/to/config.json"); /// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu; /// let device = Device::Cpu;
@@ -270,11 +317,18 @@ impl OpenAIGPTLMHeadModel {
/// let config = Gpt2Config::from_file(config_path); /// let config = Gpt2Config::from_file(config_path);
/// let gpt2: OpenAIGPTLMHeadModel = OpenAIGPTLMHeadModel::new(&(&p.root() / "gpt"), &config); /// let gpt2: OpenAIGPTLMHeadModel = OpenAIGPTLMHeadModel::new(&(&p.root() / "gpt"), &config);
/// ``` /// ```
///
pub fn new(p: &nn::Path, config: &Gpt2Config) -> OpenAIGPTLMHeadModel { pub fn new(p: &nn::Path, config: &Gpt2Config) -> OpenAIGPTLMHeadModel {
let transformer = OpenAiGptModel::new(&p, config); let transformer = OpenAiGptModel::new(&p, config);
let lm_head = linear_no_bias(&(p / "lm_head"), config.n_embd, config.vocab_size, Default::default()); let lm_head = linear_no_bias(
OpenAIGPTLMHeadModel { transformer, lm_head } &(p / "lm_head"),
config.n_embd,
config.vocab_size,
Default::default(),
);
OpenAIGPTLMHeadModel {
transformer,
lm_head,
}
} }
} }
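For context on the head added here: `linear_no_bias` projects each hidden state of size `n_embd` onto `vocab_size` logits, so the language-model head output has shape `[batch, sequence, vocab]`. A toy shape check in plain Rust (illustrative sizes, not real weights):

```rust
fn main() {
    let (seq_len, n_embd, vocab_size) = (3, 4, 6);
    let hidden = vec![vec![0.1_f32; n_embd]; seq_len];
    let weight = vec![vec![0.01_f32; vocab_size]; n_embd];
    // Bias-free projection: logits[t][v] = sum over e of hidden[t][e] * weight[e][v].
    let logits: Vec<Vec<f32>> = hidden
        .iter()
        .map(|h| {
            (0..vocab_size)
                .map(|v| (0..n_embd).map(|e| h[e] * weight[e][v]).sum())
                .collect()
        })
        .collect();
    assert_eq!((logits.len(), logits[0].len()), (seq_len, vocab_size));
}
```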
@@ -305,19 +359,19 @@ impl LMHeadModel for OpenAIGPTLMHeadModel {
/// # Example /// # Example
/// ///
/// ```no_run /// ```no_run
///# use tch::{nn, Device, Tensor, no_grad}; /// # use tch::{nn, Device, Tensor, no_grad};
///# use rust_bert::Config; /// # use rust_bert::Config;
///# use std::path::Path; /// # use std::path::Path;
///# use tch::kind::Kind::{Int64, Double}; /// # use tch::kind::Kind::{Int64, Double};
/// use rust_bert::gpt2::Gpt2Config; /// use rust_bert::gpt2::Gpt2Config;
/// use rust_bert::openai_gpt::OpenAIGPTLMHeadModel; /// use rust_bert::openai_gpt::OpenAIGPTLMHeadModel;
/// use rust_bert::pipelines::generation::{LMHeadModel, Cache}; /// use rust_bert::pipelines::generation::{LMHeadModel, Cache};
///# let config_path = Path::new("path/to/config.json"); /// # let config_path = Path::new("path/to/config.json");
///# let vocab_path = Path::new("path/to/vocab.txt"); /// # let vocab_path = Path::new("path/to/vocab.txt");
///# let device = Device::Cpu; /// # let device = Device::Cpu;
///# let vs = nn::VarStore::new(device); /// # let vs = nn::VarStore::new(device);
///# let config = Gpt2Config::from_file(config_path); /// # let config = Gpt2Config::from_file(config_path);
///# let mut gpt_model: OpenAIGPTLMHeadModel = OpenAIGPTLMHeadModel::new(&vs.root(), &config); /// # let mut gpt_model: OpenAIGPTLMHeadModel = OpenAIGPTLMHeadModel::new(&vs.root(), &config);
/// let (batch_size, sequence_length, past_sequence_length) = (64, 128, 56); /// let (batch_size, sequence_length, past_sequence_length) = (64, 128, 56);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device)); /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let attention_mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device)); /// let attention_mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
@@ -336,29 +390,44 @@ impl LMHeadModel for OpenAIGPTLMHeadModel {
/// &None, /// &None,
/// false).unwrap() /// false).unwrap()
/// }); /// });
///
/// ``` /// ```
/// fn forward_t(
fn forward_t(&self, &self,
input_ids: &Option<Tensor>, input_ids: &Option<Tensor>,
_layer_past: Cache, _layer_past: Cache,
attention_mask: &Option<Tensor>, attention_mask: &Option<Tensor>,
token_type_ids: &Option<Tensor>, token_type_ids: &Option<Tensor>,
position_ids: &Option<Tensor>, position_ids: &Option<Tensor>,
input_embeds: &Option<Tensor>, input_embeds: &Option<Tensor>,
_encoder_outputs: Option<&Tensor>, _encoder_outputs: Option<&Tensor>,
_decoder_input_ids: &Option<Tensor>, _decoder_input_ids: &Option<Tensor>,
train: bool) -> Result<(Tensor, Option<Tensor>, Cache, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> { train: bool,
let (output, ) -> Result<
all_hidden_states, (
all_attentions) = self.transformer.forward_t(input_ids, Tensor,
attention_mask, Option<Tensor>,
token_type_ids, Cache,
position_ids, Option<Vec<Tensor>>,
input_embeds, Option<Vec<Tensor>>,
train)?; ),
&'static str,
> {
let (output, all_hidden_states, all_attentions) = self.transformer.forward_t(
input_ids,
attention_mask,
token_type_ids,
position_ids,
input_embeds,
train,
)?;
let lm_logits = output.apply(&self.lm_head); let lm_logits = output.apply(&self.lm_head);
Ok((lm_logits, None, Cache::None, all_hidden_states, all_attentions)) Ok((
lm_logits,
None,
Cache::None,
all_hidden_states,
all_attentions,
))
} }
} }
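To see where the returned `lm_logits` go: in greedy generation, the logits of the final position are reduced to a single token id by argmax. A minimal sketch of that selection step with toy values (the real pipeline in `pipelines::generation` adds sampling, beam search, and penalties on top):

```rust
fn main() {
    // Toy logits for the last sequence position over a 5-token vocabulary.
    let last_logits = [0.2_f32, 1.7, -0.3, 0.9, 0.0];
    let next_token = last_logits
        .iter()
        .enumerate()
        .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
        .map(|(index, _)| index)
        .unwrap();
    assert_eq!(next_token, 1); // greedy decoding picks the highest logit
}
```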
View File
@@ -13,9 +13,9 @@
// limitations under the License. // limitations under the License.
use crate::gpt2::attention::Attention; use crate::gpt2::attention::Attention;
use tch::{Tensor, nn};
use crate::gpt2::transformer::MLP; use crate::gpt2::transformer::MLP;
use crate::gpt2::Gpt2Config; use crate::gpt2::Gpt2Config;
use tch::{nn, Tensor};
pub struct Block { pub struct Block {
ln_1: nn::LayerNorm, ln_1: nn::LayerNorm,
@@ -26,21 +26,33 @@ pub struct Block {
impl Block { impl Block {
pub fn new(p: &nn::Path, config: &Gpt2Config, scale: bool) -> Block { pub fn new(p: &nn::Path, config: &Gpt2Config, scale: bool) -> Block {
let layer_norm_config = nn::LayerNormConfig { eps: config.layer_norm_epsilon, ..Default::default() }; let layer_norm_config = nn::LayerNormConfig {
eps: config.layer_norm_epsilon,
..Default::default()
};
let ln_1 = nn::layer_norm(p / "ln_1", vec![config.n_embd], layer_norm_config); let ln_1 = nn::layer_norm(p / "ln_1", vec![config.n_embd], layer_norm_config);
let ln_2 = nn::layer_norm(p / "ln_2", vec![config.n_embd], layer_norm_config); let ln_2 = nn::layer_norm(p / "ln_2", vec![config.n_embd], layer_norm_config);
let attn = Attention::new(&(p / "attn"), config, scale); let attn = Attention::new(&(p / "attn"), config, scale);
let mlp = MLP::new(&(p / "mlp"), config); let mlp = MLP::new(&(p / "mlp"), config);
Block { ln_1, attn, ln_2, mlp } Block {
ln_1,
attn,
ln_2,
mlp,
}
} }
pub fn forward_t(&self, x: &Tensor, attention_mask: &Option<Tensor>, train: bool) pub fn forward_t(
-> (Tensor, Option<Tensor>) { &self,
x: &Tensor,
attention_mask: &Option<Tensor>,
train: bool,
) -> (Tensor, Option<Tensor>) {
let (output, _, attentions) = self.attn.forward_t(x, &None, attention_mask, train); let (output, _, attentions) = self.attn.forward_t(x, &None, attention_mask, train);
let x = (x + output).apply(&self.ln_1); let x = (x + output).apply(&self.ln_1);
let m = self.mlp.forward_t(&x, train); let m = self.mlp.forward_t(&x, train);
let x = (x + m).apply(&self.ln_2); let x = (x + m).apply(&self.ln_2);
(x, attentions) (x, attentions)
} }
} }
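`Block::forward_t` is a post-LN transformer block: each sub-layer's output is added back to its input first, and layer normalization is applied to the sum (`ln_1` after attention, `ln_2` after the MLP). A scalar sketch of that control flow, with stand-in functions in place of the real `Attention`, `MLP`, and `nn::LayerNorm`:

```rust
// Stand-ins for illustration only; the real sub-layers operate on tensors.
fn attn(x: f32) -> f32 { 0.5 * x }
fn mlp(x: f32) -> f32 { x * x }
fn layer_norm(x: f32) -> f32 { x / 2.0 }

fn main() {
    let x = 1.0_f32;
    let x = layer_norm(x + attn(x)); // residual, then ln_1
    let x = layer_norm(x + mlp(x)); // residual, then ln_2
    println!("block output: {x}");
}
```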
View File
@@ -19,13 +19,13 @@
//! //!
use crate::bert::BertConfig; use crate::bert::BertConfig;
use crate::distilbert::DistilBertConfig; use crate::distilbert::DistilBertConfig;
use rust_tokenizers::{BertTokenizer, RobertaTokenizer, TokenizedInput, TruncationStrategy};
use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::Tokenizer;
use std::path::Path;
use crate::Config;
use std::collections::HashMap;
use serde::{Serialize, Deserialize};
use crate::electra::ElectraConfig; use crate::electra::ElectraConfig;
use crate::Config;
use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::Tokenizer;
use rust_tokenizers::{BertTokenizer, RobertaTokenizer, TokenizedInput, TruncationStrategy};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::Path;
#[derive(Clone, Copy, Serialize, Deserialize)] #[derive(Clone, Copy, Serialize, Deserialize)]
/// # Identifies the type of model /// # Identifies the type of model
@@ -60,25 +60,42 @@ impl ConfigOption {
match model_type { match model_type {
ModelType::Bert | ModelType::Roberta => ConfigOption::Bert(BertConfig::from_file(path)), ModelType::Bert | ModelType::Roberta => ConfigOption::Bert(BertConfig::from_file(path)),
ModelType::DistilBert => ConfigOption::DistilBert(DistilBertConfig::from_file(path)), ModelType::DistilBert => ConfigOption::DistilBert(DistilBertConfig::from_file(path)),
ModelType::Electra => ConfigOption::Electra(ElectraConfig::from_file(path)) ModelType::Electra => ConfigOption::Electra(ElectraConfig::from_file(path)),
} }
} }
pub fn get_label_mapping(self) -> HashMap<i64, String> { pub fn get_label_mapping(self) -> HashMap<i64, String> {
match self { match self {
Self::Bert(config) => config.id2label.expect("No label dictionary (id2label) provided in configuration file"), Self::Bert(config) => config
Self::DistilBert(config) => config.id2label.expect("No label dictionary (id2label) provided in configuration file"), .id2label
Self::Electra(config) => config.id2label.expect("No label dictionary (id2label) provided in configuration file"), .expect("No label dictionary (id2label) provided in configuration file"),
Self::DistilBert(config) => config
.id2label
.expect("No label dictionary (id2label) provided in configuration file"),
Self::Electra(config) => config
.id2label
.expect("No label dictionary (id2label) provided in configuration file"),
} }
} }
} }
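The `id2label` mapping that `get_label_mapping` extracts is an index-to-label dictionary read from the model's JSON configuration; when a config lacks it, the `expect` above panics. A small sketch of the expected shape (the labels here are illustrative NER tags):

```rust
use std::collections::HashMap;

fn main() {
    let mut id2label: HashMap<i64, String> = HashMap::new();
    id2label.insert(0, String::from("O"));
    id2label.insert(1, String::from("I-PER"));
    id2label.insert(2, String::from("I-LOC"));
    // Classification pipelines use this to turn predicted indices into labels.
    assert_eq!(id2label.get(&1).map(String::as_str), Some("I-PER"));
}
```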
impl TokenizerOption { impl TokenizerOption {
/// Interface method to load a tokenizer from file /// Interface method to load a tokenizer from file
pub fn from_file(model_type: ModelType, vocab_path: &str, merges_path: Option<&str>, lower_case: bool) -> Self { pub fn from_file(
model_type: ModelType,
vocab_path: &str,
merges_path: Option<&str>,
lower_case: bool,
) -> Self {
match model_type { match model_type {
ModelType::Bert | ModelType::DistilBert | ModelType::Electra => TokenizerOption::Bert(BertTokenizer::from_file(vocab_path, lower_case)), ModelType::Bert | ModelType::DistilBert | ModelType::Electra => {
ModelType::Roberta => TokenizerOption::Roberta(RobertaTokenizer::from_file(vocab_path, merges_path.expect("No merges specified!"), lower_case)), TokenizerOption::Bert(BertTokenizer::from_file(vocab_path, lower_case))
}
ModelType::Roberta => TokenizerOption::Roberta(RobertaTokenizer::from_file(
vocab_path,
merges_path.expect("No merges specified!"),
lower_case,
)),
} }
} }
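Practically, the dispatch above means BERT-family tokenizers need only a vocabulary file, while RoBERTa also requires a merges file; passing `None` for `merges_path` with `ModelType::Roberta` panics on the `expect`. A hedged usage sketch, assuming `ModelType` and `TokenizerOption` are exported from `rust_bert::pipelines::common` (paths are placeholders):

```rust
use rust_bert::pipelines::common::{ModelType, TokenizerOption};

fn main() {
    // BERT-family models: merges_path may be None.
    let _bert = TokenizerOption::from_file(ModelType::Bert, "path/to/vocab.txt", None, true);
    // RoBERTa: merges_path is required, or the expect above panics.
    let _roberta = TokenizerOption::from_file(
        ModelType::Roberta,
        "path/to/vocab.json",
        Some("path/to/merges.txt"),
        false,
    );
}
```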
@@ -86,15 +103,25 @@ impl TokenizerOption {
pub fn model_type(&self) -> ModelType { pub fn model_type(&self) -> ModelType {
match *self { match *self {
Self::Bert(_) => ModelType::Bert, Self::Bert(_) => ModelType::Bert,
Self::Roberta(_) => ModelType::Roberta Self::Roberta(_) => ModelType::Roberta,
} }
} }
/// Interface method /// Interface method
pub fn encode_list(&self, text_list: Vec<&str>, max_len: usize, truncation_strategy: &TruncationStrategy, stride: usize) -> Vec<TokenizedInput> { pub fn encode_list(
&self,
text_list: Vec<&str>,
max_len: usize,
truncation_strategy: &TruncationStrategy,
stride: usize,
) -> Vec<TokenizedInput> {
match *self { match *self {
Self::Bert(ref tokenizer) => tokenizer.encode_list(text_list, max_len, truncation_strategy, stride), Self::Bert(ref tokenizer) => {
Self::Roberta(ref tokenizer) => tokenizer.encode_list(text_list, max_len, truncation_strategy, stride) tokenizer.encode_list(text_list, max_len, truncation_strategy, stride)
}
Self::Roberta(ref tokenizer) => {
tokenizer.encode_list(text_list, max_len, truncation_strategy, stride)
}
} }
} }
} }
File diff suppressed because it is too large
View File
@@ -6,32 +6,29 @@
//! Extractive question answering from a given question and context. DistilBERT model finetuned on SQuAD (Stanford Question Answering Dataset) //! Extractive question answering from a given question and context. DistilBERT model finetuned on SQuAD (Stanford Question Answering Dataset)
//! //!
//! ```no_run //! ```no_run
//! use rust_bert::pipelines::question_answering::{QuestionAnsweringModel, QaInput}; //! use rust_bert::pipelines::question_answering::{QaInput, QuestionAnsweringModel};
//!# fn main() -> failure::Fallible<()> { //! # fn main() -> failure::Fallible<()> {
//! let qa_model = QuestionAnsweringModel::new(Default::default())?; //! let qa_model = QuestionAnsweringModel::new(Default::default())?;
//! //!
//! let question = String::from("Where does Amy live ?"); //! let question = String::from("Where does Amy live ?");
//! let context = String::from("Amy lives in Amsterdam"); //! let context = String::from("Amy lives in Amsterdam");
//! //!
//! let answers = qa_model.predict(&vec!(QaInput { question, context }), 1, 32); //! let answers = qa_model.predict(&vec![QaInput { question, context }], 1, 32);
//!# Ok(()) //! # Ok(())
//!# } //! # }
//! ``` //! ```
//! //!
//! Output: \ //! Output: \
//! ```no_run //! ```no_run
//!# use rust_bert::pipelines::question_answering::Answer; //! # use rust_bert::pipelines::question_answering::Answer;
//!# let output = //! # let output =
//! [ //! [Answer {
//! Answer { //! score: 0.9976,
//! score: 0.9976, //! start: 13,
//! start: 13, //! end: 21,
//! end: 21, //! answer: "Amsterdam", //! # .to_owned()
//! answer: "Amsterdam" //! }]
//!# .to_owned() //! # ;
//! }
//! ]
//!# ;
//! ``` //! ```
//! //!
//! #### 2. Translation //! #### 2. Translation
@@ -46,74 +43,75 @@
//! - English <-> Russian //! - English <-> Russian
//! - French <-> German //! - French <-> German
//! ```no_run //! ```no_run
//!# fn main() -> failure::Fallible<()> { //! # fn main() -> failure::Fallible<()> {
//!# use rust_bert::pipelines::generation::LanguageGenerator; //! # use rust_bert::pipelines::generation::LanguageGenerator;
//! use rust_bert::pipelines::translation::{TranslationModel, TranslationConfig, Language}; //! use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
//! use tch::Device; //! use tch::Device;
//! let translation_config = TranslationConfig::new(Language::EnglishToFrench, Device::cuda_if_available()); //! let translation_config =
//! TranslationConfig::new(Language::EnglishToFrench, Device::cuda_if_available());
//! let mut model = TranslationModel::new(translation_config)?; //! let mut model = TranslationModel::new(translation_config)?;
//! //!
//! let input = ["This is a sentence to be translated"]; //! let input = ["This is a sentence to be translated"];
//! //!
//! let output = model.translate(&input); //! let output = model.translate(&input);
//!# Ok(()) //! # Ok(())
//!# } //! # }
//! ``` //! ```
//! //!
//! Output: \ //! Output: \
//! ```no_run //! ```no_run
//!# let output = //! # let output =
//! "Il s'agit d'une phrase à traduire" //! "Il s'agit d'une phrase à traduire"
//!# ; //! # ;
//!``` //! ```
//! //!
//! #### 3. Summarization //! #### 3. Summarization
//! Abstractive summarization of texts based on the BART encoder-decoder architecture //! Abstractive summarization of texts based on the BART encoder-decoder architecture
//! Includes techniques such as beam search, top-k and nucleus sampling, temperature setting and repetition penalty. //! Includes techniques such as beam search, top-k and nucleus sampling, temperature setting and repetition penalty.
//! //!
//! ```no_run //! ```no_run
//!# fn main() -> failure::Fallible<()> { //! # fn main() -> failure::Fallible<()> {
//!# use rust_bert::pipelines::generation::LanguageGenerator; //! # use rust_bert::pipelines::generation::LanguageGenerator;
//! use rust_bert::pipelines::summarization::SummarizationModel; //! use rust_bert::pipelines::summarization::SummarizationModel;
//! //!
//! let mut model = SummarizationModel::new(Default::default())?; //! let mut model = SummarizationModel::new(Default::default())?;
//! //!
//! let input = ["In findings published Tuesday in Cornell University's arXiv by a team of scientists //! let input = ["In findings published Tuesday in Cornell University's arXiv by a team of scientists
//!from the University of Montreal and a separate report published Wednesday in Nature Astronomy by a team //! from the University of Montreal and a separate report published Wednesday in Nature Astronomy by a team
//!from University College London (UCL), the presence of water vapour was confirmed in the atmosphere of K2-18b, //! from University College London (UCL), the presence of water vapour was confirmed in the atmosphere of K2-18b,
//!a planet circling a star in the constellation Leo. This is the first such discovery in a planet in its star's //! a planet circling a star in the constellation Leo. This is the first such discovery in a planet in its star's
//!habitable zone — not too hot and not too cold for liquid water to exist. The Montreal team, led by Björn Benneke, //! habitable zone — not too hot and not too cold for liquid water to exist. The Montreal team, led by Björn Benneke,
//!used data from the NASA's Hubble telescope to assess changes in the light coming from K2-18b's star as the planet //! used data from the NASA's Hubble telescope to assess changes in the light coming from K2-18b's star as the planet
//!passed between it and Earth. They found that certain wavelengths of light, which are usually absorbed by water, //! passed between it and Earth. They found that certain wavelengths of light, which are usually absorbed by water,
//!weakened when the planet was in the way, indicating not only does K2-18b have an atmosphere, but the atmosphere //! weakened when the planet was in the way, indicating not only does K2-18b have an atmosphere, but the atmosphere
//!contains water in vapour form. The team from UCL then analyzed the Montreal team's data using their own software //! contains water in vapour form. The team from UCL then analyzed the Montreal team's data using their own software
//!and confirmed their conclusion. This was not the first time scientists have found signs of water on an exoplanet, //! and confirmed their conclusion. This was not the first time scientists have found signs of water on an exoplanet,
//!but previous discoveries were made on planets with high temperatures or other pronounced differences from Earth. //! but previous discoveries were made on planets with high temperatures or other pronounced differences from Earth.
//!\"This is the first potentially habitable planet where the temperature is right and where we now know there is water,\" //! \"This is the first potentially habitable planet where the temperature is right and where we now know there is water,\"
//!said UCL astronomer Angelos Tsiaras. \"It's the best candidate for habitability right now.\" \"It's a good sign\", //! said UCL astronomer Angelos Tsiaras. \"It's the best candidate for habitability right now.\" \"It's a good sign\",
//!said Ryan Cloutier of the HarvardSmithsonian Center for Astrophysics, who was not one of either study's authors. //! said Ryan Cloutier of the HarvardSmithsonian Center for Astrophysics, who was not one of either study's authors.
//!\"Overall,\" he continued, \"the presence of water in its atmosphere certainly improves the prospect of K2-18b being //! \"Overall,\" he continued, \"the presence of water in its atmosphere certainly improves the prospect of K2-18b being
//!a potentially habitable planet, but further observations will be required to say for sure. \" //! a potentially habitable planet, but further observations will be required to say for sure. \"
//!K2-18b was first identified in 2015 by the Kepler space telescope. It is about 110 light-years from Earth and larger //! K2-18b was first identified in 2015 by the Kepler space telescope. It is about 110 light-years from Earth and larger
//!but less dense. Its star, a red dwarf, is cooler than the Sun, but the planet's orbit is much closer, such that a year //! but less dense. Its star, a red dwarf, is cooler than the Sun, but the planet's orbit is much closer, such that a year
//!on K2-18b lasts 33 Earth days. According to The Guardian, astronomers were optimistic that NASA's James Webb space //! on K2-18b lasts 33 Earth days. According to The Guardian, astronomers were optimistic that NASA's James Webb space
//!telescope — scheduled for launch in 2021 — and the European Space Agency's 2028 ARIEL program, could reveal more //! telescope — scheduled for launch in 2021 — and the European Space Agency's 2028 ARIEL program, could reveal more
//!about exoplanets like K2-18b."]; //! about exoplanets like K2-18b."];
//! //!
//! let output = model.summarize(&input); //! let output = model.summarize(&input);
//!# Ok(()) //! # Ok(())
//!# } //! # }
//! ``` //! ```
//! (example from: [WikiNews](https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b)) //! (example from: [WikiNews](https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b))
//! //!
//! Example output: \ //! Example output: \
//! ```no_run //! ```no_run
//!# let output = //! # let output =
//! "Scientists have found water vapour on K2-18b, a planet 110 light-years from Earth. //! "Scientists have found water vapour on K2-18b, a planet 110 light-years from Earth.
//! This is the first such discovery in a planet in its star's habitable zone. //! This is the first such discovery in a planet in its star's habitable zone.
//! The planet is not too hot and not too cold for liquid water to exist." //! The planet is not too hot and not too cold for liquid water to exist."
//!# ; //! # ;
//!``` //! ```
//! //!
//! //!
//! #### 4. Natural Language Generation //! #### 4. Natural Language Generation
@@ -124,18 +122,18 @@
//! //!
//! ```no_run //! ```no_run
//! use rust_bert::pipelines::generation::GPT2Generator; //! use rust_bert::pipelines::generation::GPT2Generator;
//!# fn main() -> failure::Fallible<()> { //! # fn main() -> failure::Fallible<()> {
//!# use rust_bert::pipelines::generation::LanguageGenerator; //! # use rust_bert::pipelines::generation::LanguageGenerator;
//! let mut model = GPT2Generator::new(Default::default())?; //! let mut model = GPT2Generator::new(Default::default())?;
//! let input_context_1 = "The dog"; //! let input_context_1 = "The dog";
//! let input_context_2 = "The cat was"; //! let input_context_2 = "The cat was";
//! let output = model.generate(Some(vec!(input_context_1, input_context_2)), None); //! let output = model.generate(Some(vec![input_context_1, input_context_2]), None);
//!# Ok(()) //! # Ok(())
//!# } //! # }
//! ``` //! ```
//! Example output: \ //! Example output: \
//! ```no_run //! ```no_run
//!# let output = //! # let output =
//! [ //! [
//! "The dog's owners, however, did not want to be named. According to the lawsuit, the animal's owner, a 29-year", //! "The dog's owners, however, did not want to be named. According to the lawsuit, the animal's owner, a 29-year",
//! "The dog has always been part of the family. \"He was always going to be my dog and he was always looking out for me", //! "The dog has always been part of the family. \"He was always going to be my dog and he was always looking out for me",
@@ -144,14 +142,14 @@
//! "The cat was pulled from the street by two-year-old Jazmine.\"I didn't know what to do,\" she said", //! "The cat was pulled from the street by two-year-old Jazmine.\"I didn't know what to do,\" she said",
//! "The cat was attacked by two stray dogs and was taken to a hospital. Two other cats were also injured in the attack and are being treated." //! "The cat was attacked by two stray dogs and was taken to a hospital. Two other cats were also injured in the attack and are being treated."
//! ] //! ]
//!# ; //! # ;
//!``` //! ```
//! //!
//! #### 5. Sentiment analysis //! #### 5. Sentiment analysis
//! Predicts the binary sentiment for a sentence. DistilBERT model finetuned on SST-2. //! Predicts the binary sentiment for a sentence. DistilBERT model finetuned on SST-2.
//! ```no_run //! ```no_run
//! use rust_bert::pipelines::sentiment::SentimentModel; //! use rust_bert::pipelines::sentiment::SentimentModel;
//!# fn main() -> failure::Fallible<()> { //! # fn main() -> failure::Fallible<()> {
//! let sentiment_model = SentimentModel::new(Default::default())?; //! let sentiment_model = SentimentModel::new(Default::default())?;
//! let input = [ //! let input = [
//! "Probably my all-time favorite movie, a story of selflessness, sacrifice and dedication to a noble cause, but it's not preachy or boring.", //! "Probably my all-time favorite movie, a story of selflessness, sacrifice and dedication to a noble cause, but it's not preachy or boring.",
@@ -159,59 +157,84 @@
//! "If you like original gut wrenching laughter you will like this movie. If you are young or old then you will love this movie, hell even my mom liked it.", //! "If you like original gut wrenching laughter you will like this movie. If you are young or old then you will love this movie, hell even my mom liked it.",
//! ]; //! ];
//! let output = sentiment_model.predict(&input); //! let output = sentiment_model.predict(&input);
//!# Ok(()) //! # Ok(())
//!# } //! # }
//! ``` //! ```
//! (Example courtesy of [IMDb](http://www.imdb.com)) //! (Example courtesy of [IMDb](http://www.imdb.com))
//! //!
//! Output: \ //! Output: \
//! ```no_run //! ```no_run
//!# use rust_bert::pipelines::sentiment::Sentiment; //! # use rust_bert::pipelines::sentiment::Sentiment;
//!# use rust_bert::pipelines::sentiment::SentimentPolarity::{Positive, Negative}; //! # use rust_bert::pipelines::sentiment::SentimentPolarity::{Positive, Negative};
//!# let output = //! # let output =
//! [ //! [
//! Sentiment { polarity: Positive, score: 0.998 }, //! Sentiment {
//! Sentiment { polarity: Negative, score: 0.992 }, //! polarity: Positive,
//! Sentiment { polarity: Positive, score: 0.999 } //! score: 0.998,
//! },
//! Sentiment {
//! polarity: Negative,
//! score: 0.992,
//! },
//! Sentiment {
//! polarity: Positive,
//! score: 0.999,
//! },
//! ] //! ]
//!# ; //! # ;
//! ``` //! ```
//! //!
//! #### 6. Named Entity Recognition //! #### 6. Named Entity Recognition
//! Extracts entities (Person, Location, Organization, Miscellaneous) from text. BERT cased large model finetuned on CoNNL03, contributed by the [MDZ Digital Library team at the Bavarian State Library](https://github.com/dbmdz) //! Extracts entities (Person, Location, Organization, Miscellaneous) from text. BERT cased large model finetuned on CoNNL03, contributed by the [MDZ Digital Library team at the Bavarian State Library](https://github.com/dbmdz)
//! ```no_run //! ```no_run
//! use rust_bert::pipelines::ner::NERModel; //! use rust_bert::pipelines::ner::NERModel;
//!# fn main() -> failure::Fallible<()> { //! # fn main() -> failure::Fallible<()> {
//! let ner_model = NERModel::new(Default::default())?; //! let ner_model = NERModel::new(Default::default())?;
//! let input = [ //! let input = [
//! "My name is Amy. I live in Paris.", //! "My name is Amy. I live in Paris.",
//! "Paris is a city in France." //! "Paris is a city in France.",
//! ]; //! ];
//! let output = ner_model.predict(&input); //! let output = ner_model.predict(&input);
//!# Ok(()) //! # Ok(())
//!# } //! # }
//! ``` //! ```
//! Output: \ //! Output: \
//! ```no_run //! ```no_run
//!# use rust_bert::pipelines::question_answering::Answer; //! # use rust_bert::pipelines::question_answering::Answer;
//!# use rust_bert::pipelines::ner::Entity; //! # use rust_bert::pipelines::ner::Entity;
//!# let output = //! # let output =
//! [ //! [
//! Entity { word: String::from("Amy"), score: 0.9986, label: String::from("I-PER") }, //! Entity {
//! Entity { word: String::from("Paris"), score: 0.9985, label: String::from("I-LOC") }, //! word: String::from("Amy"),
//! Entity { word: String::from("Paris"), score: 0.9988, label: String::from("I-LOC") }, //! score: 0.9986,
//! Entity { word: String::from("France"), score: 0.9993, label: String::from("I-LOC") }, //! label: String::from("I-PER"),
//! },
//! Entity {
//! word: String::from("Paris"),
//! score: 0.9985,
//! label: String::from("I-LOC"),
//! },
//! Entity {
//! word: String::from("Paris"),
//! score: 0.9988,
//! label: String::from("I-LOC"),
//! },
//! Entity {
//! word: String::from("France"),
//! score: 0.9993,
//! label: String::from("I-LOC"),
//! },
//! ] //! ]
//!# ; //! # ;
//! ``` //! ```
//! //!
pub mod sentiment;
pub mod common; pub mod common;
pub mod token_classification; pub mod generation;
pub mod sequence_classification;
pub mod ner; pub mod ner;
pub mod question_answering; pub mod question_answering;
pub mod generation; pub mod sentiment;
pub mod sequence_classification;
pub mod summarization; pub mod summarization;
pub mod token_classification;
pub mod translation; pub mod translation;
View File
@@ -20,33 +20,48 @@
//! //!
//! ```no_run //! ```no_run
//! use rust_bert::pipelines::ner::NERModel; //! use rust_bert::pipelines::ner::NERModel;
//!# fn main() -> failure::Fallible<()> { //! # fn main() -> failure::Fallible<()> {
//! let ner_model = NERModel::new(Default::default())?; //! let ner_model = NERModel::new(Default::default())?;
//! //!
//! let input = [ //! let input = [
//! "My name is Amy. I live in Paris.", //! "My name is Amy. I live in Paris.",
//! "Paris is a city in France." //! "Paris is a city in France.",
//! ]; //! ];
//! let output = ner_model.predict(&input); //! let output = ner_model.predict(&input);
//!# Ok(()) //! # Ok(())
//!# } //! # }
//! ``` //! ```
//! Output: \ //! Output: \
//! ```no_run //! ```no_run
//!# use rust_bert::pipelines::question_answering::Answer; //! # use rust_bert::pipelines::question_answering::Answer;
//!# use rust_bert::pipelines::ner::Entity; //! # use rust_bert::pipelines::ner::Entity;
//!# let output = //! # let output =
//! [ //! [
//! Entity { word: String::from("Amy"), score: 0.9986, label: String::from("I-PER") }, //! Entity {
//! Entity { word: String::from("Paris"), score: 0.9985, label: String::from("I-LOC") }, //! word: String::from("Amy"),
//! Entity { word: String::from("Paris"), score: 0.9988, label: String::from("I-LOC") }, //! score: 0.9986,
//! Entity { word: String::from("France"), score: 0.9993, label: String::from("I-LOC") }, //! label: String::from("I-PER"),
//! },
//! Entity {
//! word: String::from("Paris"),
//! score: 0.9985,
//! label: String::from("I-LOC"),
//! },
//! Entity {
//! word: String::from("Paris"),
//! score: 0.9988,
//! label: String::from("I-LOC"),
//! },
//! Entity {
//! word: String::from("France"),
//! score: 0.9993,
//! label: String::from("I-LOC"),
//! },
//! ] //! ]
//!# ; //! # ;
//! ``` //! ```
use crate::pipelines::token_classification::{TokenClassificationModel, TokenClassificationConfig}; use crate::pipelines::token_classification::{TokenClassificationConfig, TokenClassificationModel};
#[derive(Debug)] #[derive(Debug)]
/// # Entity generated by a `NERModel` /// # Entity generated by a `NERModel`
@@ -64,7 +79,7 @@ type NERConfig = TokenClassificationConfig;
/// # NERModel to extract named entities /// # NERModel to extract named entities
pub struct NERModel { pub struct NERModel {
token_classification_model: TokenClassificationModel token_classification_model: TokenClassificationModel,
} }
impl NERModel { impl NERModel {
@@ -77,17 +92,18 @@ impl NERModel {
/// # Example /// # Example
/// ///
/// ```no_run /// ```no_run
///# fn main() -> failure::Fallible<()> { /// # fn main() -> failure::Fallible<()> {
/// use rust_bert::pipelines::ner::NERModel; /// use rust_bert::pipelines::ner::NERModel;
/// ///
/// let ner_model = NERModel::new(Default::default())?; /// let ner_model = NERModel::new(Default::default())?;
///# Ok(()) /// # Ok(())
///# } /// # }
/// ``` /// ```
///
pub fn new(ner_config: NERConfig) -> failure::Fallible<NERModel> { pub fn new(ner_config: NERConfig) -> failure::Fallible<NERModel> {
let model = TokenClassificationModel::new(ner_config)?; let model = TokenClassificationModel::new(ner_config)?;
Ok(NERModel { token_classification_model: model }) Ok(NERModel {
token_classification_model: model,
})
} }
/// Extract entities from a text /// Extract entities from a text
@@ -103,30 +119,28 @@ impl NERModel {
/// # Example /// # Example
/// ///
/// ```no_run /// ```no_run
///# fn main() -> failure::Fallible<()> { /// # fn main() -> failure::Fallible<()> {
///# use rust_bert::pipelines::ner::NERModel; /// # use rust_bert::pipelines::ner::NERModel;
/// ///
/// let ner_model = NERModel::new(Default::default())?; /// let ner_model = NERModel::new(Default::default())?;
/// let input = [ /// let input = [
/// "My name is Amy. I live in Paris.", /// "My name is Amy. I live in Paris.",
/// "Paris is a city in France." /// "Paris is a city in France.",
/// ]; /// ];
/// let output = ner_model.predict(&input); /// let output = ner_model.predict(&input);
///# Ok(()) /// # Ok(())
///# } /// # }
/// ``` /// ```
///
pub fn predict(&self, input: &[&str]) -> Vec<Entity> { pub fn predict(&self, input: &[&str]) -> Vec<Entity> {
self.token_classification_model self.token_classification_model
.predict(input, true, false) .predict(input, true, false)
.into_iter() .into_iter()
.filter(|token| token.label != "O") .filter(|token| token.label != "O")
.map(|token| { .map(|token| Entity {
Entity { word: token.text,
word: token.text, score: token.score,
score: token.score, label: token.label,
label: token.label, })
} .collect()
}).collect()
} }
} }
View File
@@ -17,48 +17,48 @@
//! The dependencies will be downloaded to the user's home directory, under ~/.cache/.rustbert/distilbert-qa //! The dependencies will be downloaded to the user's home directory, under ~/.cache/.rustbert/distilbert-qa
//! //!
//! ```no_run //! ```no_run
//! use rust_bert::pipelines::question_answering::{QuestionAnsweringModel, QaInput}; //! use rust_bert::pipelines::question_answering::{QaInput, QuestionAnsweringModel};
//! //!
//!# fn main() -> failure::Fallible<()> { //! # fn main() -> failure::Fallible<()> {
//! let qa_model = QuestionAnsweringModel::new(Default::default())?; //! let qa_model = QuestionAnsweringModel::new(Default::default())?;
//! //!
//! let question = String::from("Where does Amy live ?"); //! let question = String::from("Where does Amy live ?");
//! let context = String::from("Amy lives in Amsterdam"); //! let context = String::from("Amy lives in Amsterdam");
//! //!
//! let answers = qa_model.predict(&vec!(QaInput { question, context }), 1, 32); //! let answers = qa_model.predict(&vec![QaInput { question, context }], 1, 32);
//!# Ok(()) //! # Ok(())
//!# } //! # }
//! ``` //! ```
//! //!
//! Output: \ //! Output: \
//! ```no_run //! ```no_run
//!# use rust_bert::pipelines::question_answering::Answer; //! # use rust_bert::pipelines::question_answering::Answer;
//!# let output = //! # let output =
//! [ //! [Answer {
//! Answer { //! score: 0.9976,
//! score: 0.9976, //! start: 13,
//! start: 13, //! end: 21,
//! end: 21, //! answer: "Amsterdam", //! # .to_owned()
//! answer: "Amsterdam" //! }]
//!# .to_owned() //! # ;
//! }
//! ]
//!# ;
//! ``` //! ```
use rust_tokenizers::{BertTokenizer, Tokenizer, TruncationStrategy, TokenizedInput}; use crate::common::resources::{download_resource, RemoteResource, Resource};
use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::Mask; use crate::distilbert::{
use tch::{Device, Tensor, no_grad}; DistilBertConfig, DistilBertConfigResources, DistilBertForQuestionAnswering,
use std::path::PathBuf; DistilBertModelResources, DistilBertVocabResources,
use rust_tokenizers::tokenization_utils::truncate_sequences; };
use std::collections::HashMap;
use std::cmp::min;
use tch::nn::VarStore;
use tch::kind::Kind::Float;
use std::fs;
use crate::Config; use crate::Config;
use crate::distilbert::{DistilBertForQuestionAnswering, DistilBertConfig, DistilBertConfigResources, DistilBertModelResources, DistilBertVocabResources}; use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::Mask;
use crate::common::resources::{Resource, RemoteResource, download_resource}; use rust_tokenizers::tokenization_utils::truncate_sequences;
use rust_tokenizers::{BertTokenizer, TokenizedInput, Tokenizer, TruncationStrategy};
use std::cmp::min;
use std::collections::HashMap;
use std::fs;
use std::path::PathBuf;
use tch::kind::Kind::Float;
use tch::nn::VarStore;
use tch::{no_grad, Device, Tensor};
/// # Input for Question Answering /// # Input for Question Answering
/// Includes a context (containing the answer) and question strings /// Includes a context (containing the answer) and question strings
@@ -84,7 +84,6 @@ struct QaFeature {
pub token_to_orig_map: HashMap<i64, i64>, pub token_to_orig_map: HashMap<i64, i64>,
pub p_mask: Vec<i8>, pub p_mask: Vec<i8>,
pub example_index: i64, pub example_index: i64,
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
@@ -102,34 +101,38 @@ pub struct Answer {
impl PartialEq for Answer { impl PartialEq for Answer {
fn eq(&self, other: &Self) -> bool { fn eq(&self, other: &Self) -> bool {
(self.start == other.start) && (self.start == other.start) && (self.end == other.end) && (self.answer == other.answer)
(self.end == other.end) &&
(self.answer == other.answer)
} }
} }
fn remove_duplicates<T: PartialEq + Clone>(vector: &mut Vec<T>) -> &mut Vec<T> { fn remove_duplicates<T: PartialEq + Clone>(vector: &mut Vec<T>) -> &mut Vec<T> {
let mut potential_duplicates = vec!(); let mut potential_duplicates = vec![];
vector.retain(|item| if potential_duplicates.contains(item) { vector.retain(|item| {
false if potential_duplicates.contains(item) {
} else { false
potential_duplicates.push(item.clone()); } else {
true potential_duplicates.push(item.clone());
true
}
}); });
vector vector
} }
impl QaExample { impl QaExample {
pub fn new(question: &str, context: &str) -> QaExample { pub fn new(question: &str, context: &str) -> QaExample {
let question = question.to_owned(); let question = question.to_owned();
let (doc_tokens, char_to_word_offset) = QaExample::split_context(context); let (doc_tokens, char_to_word_offset) = QaExample::split_context(context);
QaExample { question, context: context.to_owned(), doc_tokens, char_to_word_offset } QaExample {
question,
context: context.to_owned(),
doc_tokens,
char_to_word_offset,
}
} }
fn split_context(context: &str) -> (Vec<String>, Vec<i64>) { fn split_context(context: &str) -> (Vec<String>, Vec<i64>) {
let mut doc_tokens: Vec<String> = vec!(); let mut doc_tokens: Vec<String> = vec![];
let mut char_to_word_offset: Vec<i64> = vec!(); let mut char_to_word_offset: Vec<i64> = vec![];
let max_length = context.len(); let max_length = context.len();
let mut current_word = String::with_capacity(max_length); let mut current_word = String::with_capacity(max_length);
let mut previous_whitespace = false; let mut previous_whitespace = false;
@@ -158,11 +161,11 @@ impl QaExample {
} }
fn is_whitespace(character: &char) -> bool { fn is_whitespace(character: &char) -> bool {
(character == &' ') | (character == &' ')
(character == &'\t') | | (character == &'\t')
(character == &'\r') | | (character == &'\r')
(character == &'\n') | | (character == &'\n')
(*character as u32 == 0x202F) | (*character as u32 == 0x202F)
} }
} }
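To make `split_context`'s two outputs concrete: `doc_tokens` is the whitespace-split word list, and `char_to_word_offset` holds one entry per character of the context giving the index of the word that character belongs to. The hand-worked values below are an assumption based on the visible parts of the function (the hunk elides its middle), in particular that separating whitespace maps to the preceding word's index:

```rust
fn main() {
    // Hypothetical walk-through of split_context("Amy lives here").
    let context = "Amy lives here";
    let doc_tokens = ["Amy", "lives", "here"];
    // 'A','m','y' -> 0, ' ' -> 0 (assumed), 'l'..'s' -> 1, ' ' -> 1, 'h'..'e' -> 2.
    let char_to_word_offset: [i64; 14] = [0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2];
    assert_eq!(context.len(), char_to_word_offset.len());
    assert_eq!(doc_tokens.len(), 3);
}
```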
@@ -182,9 +185,15 @@ pub struct QuestionAnsweringConfig {
impl Default for QuestionAnsweringConfig { impl Default for QuestionAnsweringConfig {
fn default() -> QuestionAnsweringConfig { fn default() -> QuestionAnsweringConfig {
QuestionAnsweringConfig { QuestionAnsweringConfig {
model_resource: Resource::Remote(RemoteResource::from_pretrained(DistilBertModelResources::DISTIL_BERT_SQUAD)), model_resource: Resource::Remote(RemoteResource::from_pretrained(
config_resource: Resource::Remote(RemoteResource::from_pretrained(DistilBertConfigResources::DISTIL_BERT_SQUAD)), DistilBertModelResources::DISTIL_BERT_SQUAD,
vocab_resource: Resource::Remote(RemoteResource::from_pretrained(DistilBertVocabResources::DISTIL_BERT_SQUAD)), )),
config_resource: Resource::Remote(RemoteResource::from_pretrained(
DistilBertConfigResources::DISTIL_BERT_SQUAD,
)),
vocab_resource: Resource::Remote(RemoteResource::from_pretrained(
DistilBertVocabResources::DISTIL_BERT_SQUAD,
)),
device: Device::cuda_if_available(), device: Device::cuda_if_available(),
} }
} }
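Because `QuestionAnsweringConfig` implements `Default` with remote SQuAD-finetuned DistilBERT resources and `Device::cuda_if_available()`, callers typically override only the fields they need. A hedged sketch using struct-update syntax, assuming the fields are public as the visible definition suggests:

```rust
use rust_bert::pipelines::question_answering::QuestionAnsweringConfig;
use tch::Device;

fn main() {
    // Force CPU inference, keep the default remote resources.
    let config = QuestionAnsweringConfig {
        device: Device::Cpu,
        ..Default::default()
    };
    let _qa_config = config;
}
```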
@@ -213,26 +222,33 @@ impl QuestionAnsweringModel {
/// # Example /// # Example
/// ///
/// ```no_run /// ```no_run
///# fn main() -> failure::Fallible<()> { /// # fn main() -> failure::Fallible<()> {
/// use rust_bert::pipelines::question_answering::QuestionAnsweringModel; /// use rust_bert::pipelines::question_answering::QuestionAnsweringModel;
/// ///
/// let qa_model = QuestionAnsweringModel::new(Default::default())?; /// let qa_model = QuestionAnsweringModel::new(Default::default())?;
///# Ok(()) /// # Ok(())
///# } /// # }
/// ``` /// ```
/// pub fn new(
pub fn new(question_answering_config: QuestionAnsweringConfig) -> failure::Fallible<QuestionAnsweringModel> { question_answering_config: QuestionAnsweringConfig,
) -> failure::Fallible<QuestionAnsweringModel> {
let config_path = download_resource(&question_answering_config.config_resource)?; let config_path = download_resource(&question_answering_config.config_resource)?;
let vocab_path = download_resource(&question_answering_config.vocab_resource)?; let vocab_path = download_resource(&question_answering_config.vocab_resource)?;
let weights_path = download_resource(&question_answering_config.model_resource)?; let weights_path = download_resource(&question_answering_config.model_resource)?;
let device = question_answering_config.device; let device = question_answering_config.device;
let tokenizer = BertTokenizer::from_file(vocab_path.to_str().unwrap(), false); let tokenizer = BertTokenizer::from_file(vocab_path.to_str().unwrap(), false);
let pad_idx = *Tokenizer::vocab(&tokenizer).special_values.get("[PAD]").expect("[PAD] token not found in vocabulary"); let pad_idx = *Tokenizer::vocab(&tokenizer)
let sep_idx = *Tokenizer::vocab(&tokenizer).special_values.get("[SEP]").expect("[SEP] token not found in vocabulary"); .special_values
.get("[PAD]")
.expect("[PAD] token not found in vocabulary");
let sep_idx = *Tokenizer::vocab(&tokenizer)
.special_values
.get("[SEP]")
.expect("[SEP] token not found in vocabulary");
let mut var_store = VarStore::new(device); let mut var_store = VarStore::new(device);
let mut config = DistilBertConfig::from_file(config_path); let mut config = DistilBertConfig::from_file(config_path);
// The config for the current pre-trained question answering model indicates position embeddings which does not seem accurate // The config for the current pre-trained question answering model indicates position embeddings which does not seem accurate
config.sinusoidal_pos_embds = false; config.sinusoidal_pos_embds = false;
let distilbert_qa = DistilBertForQuestionAnswering::new(&var_store.root(), &config); let distilbert_qa = DistilBertForQuestionAnswering::new(&var_store.root(), &config);
var_store.load(weights_path)?; var_store.load(weights_path)?;
@@ -249,7 +265,6 @@ impl QuestionAnsweringModel {
}) })
} }
/// Perform extractive question answering given a list of `QaInputs` /// Perform extractive question answering given a list of `QaInputs`
/// ///
/// # Arguments /// # Arguments
@@ -264,25 +279,35 @@ impl QuestionAnsweringModel {
/// # Example /// # Example
/// ///
/// ```no_run /// ```no_run
///# fn main() -> failure::Fallible<()> { /// # fn main() -> failure::Fallible<()> {
/// use rust_bert::pipelines::question_answering::{QuestionAnsweringModel, QaInput}; /// use rust_bert::pipelines::question_answering::{QaInput, QuestionAnsweringModel};
/// ///
/// let qa_model = QuestionAnsweringModel::new(Default::default())?; /// let qa_model = QuestionAnsweringModel::new(Default::default())?;
/// ///
/// let question_1 = String::from("Where does Amy live ?"); /// let question_1 = String::from("Where does Amy live ?");
/// let context_1 = String::from("Amy lives in Amsterdam"); /// let context_1 = String::from("Amy lives in Amsterdam");
/// let question_2 = String::from("Where does Eric live"); /// let question_2 = String::from("Where does Eric live");
/// let context_2 = String::from("While Amy lives in Amsterdam, Eric is in The Hague."); /// let context_2 = String::from("While Amy lives in Amsterdam, Eric is in The Hague.");
/// ///
/// let qa_input_1 = QaInput { question: question_1, context: context_1 }; /// let qa_input_1 = QaInput {
/// let qa_input_2 = QaInput { question: question_2, context: context_2 }; /// question: question_1,
/// context: context_1,
/// };
/// let qa_input_2 = QaInput {
/// question: question_2,
/// context: context_2,
/// };
/// let answers = qa_model.predict(&[qa_input_1, qa_input_2], 1, 32); /// let answers = qa_model.predict(&[qa_input_1, qa_input_2], 1, 32);
/// ///
///# Ok(()) /// # Ok(())
///# } /// # }
/// ``` /// ```
/// pub fn predict(
pub fn predict(&self, qa_inputs: &[QaInput], top_k: i64, batch_size: usize) -> Vec<Vec<Answer>> { &self,
qa_inputs: &[QaInput],
top_k: i64,
batch_size: usize,
) -> Vec<Vec<Answer>> {
let examples: Vec<QaExample> = qa_inputs let examples: Vec<QaExample> = qa_inputs
.iter() .iter()
.map(|qa_input| QaExample::new(&qa_input.question, &qa_input.context)) .map(|qa_input| QaExample::new(&qa_input.question, &qa_input.context))
@@ -290,7 +315,15 @@ impl QuestionAnsweringModel {
let features: Vec<QaFeature> = examples let features: Vec<QaFeature> = examples
.iter() .iter()
.enumerate() .enumerate()
.map(|(example_index, qa_example)| self.generate_features(&qa_example, self.max_seq_len, self.doc_stride, self.max_query_length, example_index as i64)) .map(|(example_index, qa_example)| {
self.generate_features(
&qa_example,
self.max_seq_len,
self.doc_stride,
self.max_query_length,
example_index as i64,
)
})
.flatten() .flatten()
.collect(); .collect();
@@ -310,28 +343,36 @@ impl QuestionAnsweringModel {
} }
let input_ids = Tensor::stack(&input_ids, 0).to(self.var_store.device()); let input_ids = Tensor::stack(&input_ids, 0).to(self.var_store.device());
let attention_masks = Tensor::stack(&attention_masks, 0).to(self.var_store.device()); let attention_masks =
Tensor::stack(&attention_masks, 0).to(self.var_store.device());
let (start_logits, end_logits, _, _) = self.distilbert_qa.forward_t(Some(input_ids), Some(attention_masks), None, false).unwrap(); let (start_logits, end_logits, _, _) = self
.distilbert_qa
.forward_t(Some(input_ids), Some(attention_masks), None, false)
.unwrap();
let start_logits = start_logits.detach(); let start_logits = start_logits.detach();
let end_logits = end_logits.detach(); let end_logits = end_logits.detach();
let example_index_to_feature_end_position: Vec<(usize, i64)> = batch_features let example_index_to_feature_end_position: Vec<(usize, i64)> = batch_features
.iter() .iter()
.enumerate() .enumerate()
.map(|(feature_index, feature)| (feature.example_index as usize, feature_index as i64 + 1)) .map(|(feature_index, feature)| {
(feature.example_index as usize, feature_index as i64 + 1)
})
.collect(); .collect();
let mut feature_id_start = 0; let mut feature_id_start = 0;
for (example_id, max_feature_id) in example_index_to_feature_end_position { for (example_id, max_feature_id) in example_index_to_feature_end_position {
let mut answers: Vec<Answer> = vec!(); let mut answers: Vec<Answer> = vec![];
let example = &examples[example_id]; let example = &examples[example_id];
for feature_idx in feature_id_start..max_feature_id { for feature_idx in feature_id_start..max_feature_id {
let feature = &batch_features[feature_idx as usize]; let feature = &batch_features[feature_idx as usize];
let start = start_logits.get(feature_idx); let start = start_logits.get(feature_idx);
let end = end_logits.get(feature_idx); let end = end_logits.get(feature_idx);
let p_mask = (Tensor::of_slice(&feature.p_mask) - 1).abs().to_device(start.device()); let p_mask = (Tensor::of_slice(&feature.p_mask) - 1)
.abs()
.to_device(start.device());
let start: Tensor = start.exp() / start.exp().sum(Float) * &p_mask; let start: Tensor = start.exp() / start.exp().sum(Float) * &p_mask;
let end: Tensor = end.exp() / end.exp().sum(Float) * &p_mask; let end: Tensor = end.exp() / end.exp().sum(Float) * &p_mask;
@@ -343,33 +384,42 @@ impl QuestionAnsweringModel {
let end_pos = feature.token_to_orig_map[&ends[idx]] as usize; let end_pos = feature.token_to_orig_map[&ends[idx]] as usize;
let answer = example.doc_tokens[start_pos..end_pos + 1].join(" "); let answer = example.doc_tokens[start_pos..end_pos + 1].join(" ");
let start = example.char_to_word_offset let start = example
.char_to_word_offset
.iter() .iter()
.position(|&v| v as usize == start_pos) .position(|&v| v as usize == start_pos)
.unwrap(); .unwrap();
let end = example.char_to_word_offset let end = example
.char_to_word_offset
.iter() .iter()
.rposition(|&v| v as usize == end_pos) .rposition(|&v| v as usize == end_pos)
.unwrap(); .unwrap();
answers.push(Answer { score: scores[idx], start, end, answer }); answers.push(Answer {
score: scores[idx],
start,
end,
answer,
});
} }
} }
feature_id_start = max_feature_id; feature_id_start = max_feature_id;
let example_answers = example_top_k_answers_map.entry(example_id).or_insert(vec!()); let example_answers = example_top_k_answers_map
.entry(example_id)
.or_insert(vec![]);
example_answers.extend(answers); example_answers.extend(answers);
} }
}); });
start = end; start = end;
} }
let mut all_answers = vec!(); let mut all_answers = vec![];
for example_id in 0..examples.len() { for example_id in 0..examples.len() {
if let Some(answers) = example_top_k_answers_map.get_mut(&example_id) { if let Some(answers) = example_top_k_answers_map.get_mut(&example_id) {
remove_duplicates(answers).sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap()); remove_duplicates(answers).sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
all_answers.push(answers[..min(answers.len(), top_k as usize)].to_vec()); all_answers.push(answers[..min(answers.len(), top_k as usize)].to_vec());
} else { } else {
all_answers.push(vec!()); all_answers.push(vec![]);
} }
} }
all_answers all_answers
@@ -379,7 +429,10 @@ impl QuestionAnsweringModel {
let outer = start.unsqueeze(-1).matmul(&end.unsqueeze(0)); let outer = start.unsqueeze(-1).matmul(&end.unsqueeze(0));
let start_dim = start.size()[0]; let start_dim = start.size()[0];
let end_dim = end.size()[0]; let end_dim = end.size()[0];
let candidates = outer.triu(0).tril(self.max_answer_len as i64 - 1).flatten(0, -1); let candidates = outer
.triu(0)
.tril(self.max_answer_len as i64 - 1)
.flatten(0, -1);
let idx_sort = if top_k == 1 { let idx_sort = if top_k == 1 {
candidates.argmax(0, true) candidates.argmax(0, true)
} else if candidates.size()[0] < top_k { } else if candidates.size()[0] < top_k {
@@ -387,9 +440,9 @@ impl QuestionAnsweringModel {
} else { } else {
candidates.argsort(0, true).slice(0, 0, top_k, 1) candidates.argsort(0, true).slice(0, 0, top_k, 1)
}; };
let mut start: Vec<i64> = vec!(); let mut start: Vec<i64> = vec![];
let mut end: Vec<i64> = vec!(); let mut end: Vec<i64> = vec![];
let mut scores: Vec<f64> = vec!(); let mut scores: Vec<f64> = vec![];
for flat_index_position in 0..idx_sort.size()[0] { for flat_index_position in 0..idx_sort.size()[0] {
let flat_index = idx_sort.int64_value(&[flat_index_position]); let flat_index = idx_sort.int64_value(&[flat_index_position]);
scores.push(candidates.double_value(&[flat_index])); scores.push(candidates.double_value(&[flat_index]));
@@ -399,10 +452,16 @@ impl QuestionAnsweringModel {
(start, end, scores) (start, end, scores)
} }
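As an aside, the span decoding above scores every candidate answer as the product of its start and end probabilities, then keeps only spans whose end does not precede their start and whose length stays under `max_answer_len`. A minimal standalone sketch of that trick (the function name and shapes are illustrative, not part of the crate):

```rust
use tch::Tensor;

// Sketch of the span-scoring trick used above: the outer product of start and
// end probabilities scores every (start, end) pair; triu(0) removes spans that
// end before they start, tril(max_answer_len - 1) caps the span length. The
// returned value is a flat index into the (start_dim x end_dim) grid.
fn best_span(start_probs: &Tensor, end_probs: &Tensor, max_answer_len: i64) -> i64 {
    let outer = start_probs.unsqueeze(-1).matmul(&end_probs.unsqueeze(0));
    let candidates = outer.triu(0).tril(max_answer_len - 1).flatten(0, -1);
    candidates.argmax(0, true).int64_value(&[0])
}
```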
fn generate_features(
fn generate_features(&self, qa_example: &QaExample, max_seq_length: usize, doc_stride: usize, max_query_length: usize, example_index: i64) -> Vec<QaFeature> { &self,
let mut tok_to_orig_index: Vec<i64> = vec!(); qa_example: &QaExample,
let mut all_doc_tokens: Vec<String> = vec!(); max_seq_length: usize,
doc_stride: usize,
max_query_length: usize,
example_index: i64,
) -> Vec<QaFeature> {
let mut tok_to_orig_index: Vec<i64> = vec![];
let mut all_doc_tokens: Vec<String> = vec![];
for (idx, token) in qa_example.doc_tokens.iter().enumerate() { for (idx, token) in qa_example.doc_tokens.iter().enumerate() {
let sub_tokens = self.tokenizer.tokenize(token); let sub_tokens = self.tokenizer.tokenize(token);
@ -414,28 +473,61 @@ impl QuestionAnsweringModel {
let truncated_query = self.prepare_query(&qa_example.question, max_query_length); let truncated_query = self.prepare_query(&qa_example.question, max_query_length);
let sequence_added_tokens = self.tokenizer.build_input_with_special_tokens(vec!(), None, vec!(), None, vec!(), None, vec!(), None).0.len(); let sequence_added_tokens = self
let sequence_pair_added_tokens = self.tokenizer.build_input_with_special_tokens(vec!(), Some(vec!()), vec!(), Some(vec!()), vec!(), Some(vec!()), vec!(), Some(vec!())).0.len(); .tokenizer
.build_input_with_special_tokens(vec![], None, vec![], None, vec![], None, vec![], None)
.0
.len();
let sequence_pair_added_tokens = self
.tokenizer
.build_input_with_special_tokens(
vec![],
Some(vec![]),
vec![],
Some(vec![]),
vec![],
Some(vec![]),
vec![],
Some(vec![]),
)
.0
.len();
let mut spans: Vec<QaFeature> = vec!(); let mut spans: Vec<QaFeature> = vec![];
let mut remaining_tokens = self.tokenizer.convert_tokens_to_ids(&all_doc_tokens); let mut remaining_tokens = self.tokenizer.convert_tokens_to_ids(&all_doc_tokens);
while (spans.len() * doc_stride as usize) < all_doc_tokens.len() { while (spans.len() * doc_stride as usize) < all_doc_tokens.len() {
let (encoded_span, attention_mask) = self.encode_qa_pair(&truncated_query, &remaining_tokens, max_seq_length, doc_stride, sequence_pair_added_tokens); let (encoded_span, attention_mask) = self.encode_qa_pair(
&truncated_query,
&remaining_tokens,
max_seq_length,
doc_stride,
sequence_pair_added_tokens,
);
let paragraph_len = min( let paragraph_len = min(
all_doc_tokens.len() - spans.len() * doc_stride, all_doc_tokens.len() - spans.len() * doc_stride,
max_seq_length - truncated_query.len() - sequence_pair_added_tokens); max_seq_length - truncated_query.len() - sequence_pair_added_tokens,
);
let mut token_to_orig_map = HashMap::new(); let mut token_to_orig_map = HashMap::new();
for i in 0..paragraph_len { for i in 0..paragraph_len {
let index = truncated_query.len() + sequence_added_tokens + i; let index = truncated_query.len() + sequence_added_tokens + i;
token_to_orig_map.insert(index as i64, tok_to_orig_index[spans.len() * doc_stride + i] as i64); token_to_orig_map.insert(
index as i64,
tok_to_orig_index[spans.len() * doc_stride + i] as i64,
);
} }
let p_mask = self.get_mask(&encoded_span); let p_mask = self.get_mask(&encoded_span);
let qa_feature = QaFeature { input_ids: encoded_span.token_ids, attention_mask, token_to_orig_map, p_mask, example_index }; let qa_feature = QaFeature {
input_ids: encoded_span.token_ids,
attention_mask,
token_to_orig_map,
p_mask,
example_index,
};
spans.push(qa_feature); spans.push(qa_feature);
if encoded_span.num_truncated_tokens == 0 { if encoded_span.num_truncated_tokens == 0 {
@ -447,59 +539,81 @@ impl QuestionAnsweringModel {
} }
fn prepare_query(&self, query: &str, max_query_length: usize) -> Vec<i64> { fn prepare_query(&self, query: &str, max_query_length: usize) -> Vec<i64> {
let truncated_query = self.tokenizer.convert_tokens_to_ids(&self.tokenizer.tokenize(&query)); let truncated_query = self
let num_query_tokens_to_remove = if truncated_query.len() > max_query_length as usize { truncated_query.len() - max_query_length } else { 0 }; .tokenizer
let (truncated_query, _, _, _, _, _, _, _, _, _) = truncate_sequences(truncated_query, .convert_tokens_to_ids(&self.tokenizer.tokenize(&query));
None, let num_query_tokens_to_remove = if truncated_query.len() > max_query_length as usize {
vec!(), truncated_query.len() - max_query_length
None, } else {
vec!(), 0
None, };
vec!(), let (truncated_query, _, _, _, _, _, _, _, _, _) = truncate_sequences(
None, truncated_query,
num_query_tokens_to_remove, None,
&TruncationStrategy::OnlyFirst, vec![],
0).unwrap(); None,
vec![],
None,
vec![],
None,
num_query_tokens_to_remove,
&TruncationStrategy::OnlyFirst,
0,
)
.unwrap();
truncated_query truncated_query
} }
fn encode_qa_pair(&self, fn encode_qa_pair(
truncated_query: &Vec<i64>, &self,
spans_token_ids: &Vec<i64>, truncated_query: &Vec<i64>,
max_seq_length: usize, spans_token_ids: &Vec<i64>,
doc_stride: usize, max_seq_length: usize,
sequence_pair_added_tokens: usize) -> (TokenizedInput, Vec<i64>) { doc_stride: usize,
sequence_pair_added_tokens: usize,
) -> (TokenizedInput, Vec<i64>) {
let len_1 = truncated_query.len(); let len_1 = truncated_query.len();
let len_2 = spans_token_ids.len(); let len_2 = spans_token_ids.len();
let total_len = len_1 + len_2 + sequence_pair_added_tokens; let total_len = len_1 + len_2 + sequence_pair_added_tokens;
let num_truncated_tokens = if total_len > max_seq_length { total_len - max_seq_length } else { 0 }; let num_truncated_tokens = if total_len > max_seq_length {
total_len - max_seq_length
} else {
0
};
let (truncated_query, truncated_context, _, _, _, _, _, _, overflowing_tokens, _) let (truncated_query, truncated_context, _, _, _, _, _, _, overflowing_tokens, _) =
= truncate_sequences(truncated_query.clone(), truncate_sequences(
Some(spans_token_ids.clone()), truncated_query.clone(),
vec!(), Some(spans_token_ids.clone()),
None, vec![],
vec!(), None,
None, vec![],
vec!(), None,
None, vec![],
num_truncated_tokens, None,
&TruncationStrategy::OnlySecond, num_truncated_tokens,
max_seq_length - doc_stride - len_1 - sequence_pair_added_tokens).unwrap(); &TruncationStrategy::OnlySecond,
max_seq_length - doc_stride - len_1 - sequence_pair_added_tokens,
)
.unwrap();
let (mut token_ids, let (
mut token_ids,
mut segment_ids, mut segment_ids,
special_tokens_mask, special_tokens_mask,
mut token_offsets, mut token_offsets,
mut reference_offsets, mut reference_offsets,
mut mask) = self.tokenizer.build_input_with_special_tokens(truncated_query, mut mask,
truncated_context, ) = self.tokenizer.build_input_with_special_tokens(
vec!(), truncated_query,
None, truncated_context,
vec!(), vec![],
None, None,
vec!(), vec![],
None); None,
vec![],
None,
);
let mut attention_mask = vec![1; token_ids.len()]; let mut attention_mask = vec![1; token_ids.len()];
if token_ids.len() < max_seq_length { if token_ids.len() < max_seq_length {
token_ids.append(&mut vec![self.pad_idx; max_seq_length - token_ids.len()]); token_ids.append(&mut vec![self.pad_idx; max_seq_length - token_ids.len()]);
@ -509,18 +623,32 @@ impl QuestionAnsweringModel {
reference_offsets.append(&mut vec![vec!(); max_seq_length - token_offsets.len()]); reference_offsets.append(&mut vec![vec!(); max_seq_length - token_offsets.len()]);
mask.append(&mut vec![Mask::Special; max_seq_length - mask.len()]); mask.append(&mut vec![Mask::Special; max_seq_length - mask.len()]);
} }
(TokenizedInput { token_ids, segment_ids, special_tokens_mask, overflowing_tokens, num_truncated_tokens, token_offsets, reference_offsets, mask }, attention_mask) (
TokenizedInput {
token_ids,
segment_ids,
special_tokens_mask,
overflowing_tokens,
num_truncated_tokens,
token_offsets,
reference_offsets,
mask,
},
attention_mask,
)
} }
fn get_mask(&self, encoded_span: &TokenizedInput) -> Vec<i8> { fn get_mask(&self, encoded_span: &TokenizedInput) -> Vec<i8> {
let sep_indices: Vec<usize> = encoded_span.token_ids let sep_indices: Vec<usize> = encoded_span
.token_ids
.iter() .iter()
.enumerate() .enumerate()
.filter(|(_, &value)| value == self.sep_idx) .filter(|(_, &value)| value == self.sep_idx)
.map(|(position, _)| position) .map(|(position, _)| position)
.collect(); .collect();
let mut p_mask: Vec<i8> = encoded_span.segment_ids let mut p_mask: Vec<i8> = encoded_span
.segment_ids
.iter() .iter()
.map(|v| min(v, &1i8)) .map(|v| min(v, &1i8))
.map(|&v| 1i8 - v) .map(|&v| 1i8 - v)
@ -534,10 +662,13 @@ impl QuestionAnsweringModel {
pub fn squad_processor(file_path: PathBuf) -> Vec<QaInput> { pub fn squad_processor(file_path: PathBuf) -> Vec<QaInput> {
let file = fs::File::open(file_path).expect("unable to open file"); let file = fs::File::open(file_path).expect("unable to open file");
let json: serde_json::Value = serde_json::from_reader(file).expect("JSON not properly formatted"); let json: serde_json::Value =
serde_json::from_reader(file).expect("JSON not properly formatted");
let data = json let data = json
.get("data").expect("SQuAD file does not contain data field") .get("data")
.as_array().expect("Data array not properly formatted"); .expect("SQuAD file does not contain data field")
.as_array()
.expect("Data array not properly formatted");
let mut qa_inputs: Vec<QaInput> = Vec::with_capacity(data.len()); let mut qa_inputs: Vec<QaInput> = Vec::with_capacity(data.len());
for qa_input in data.iter() { for qa_input in data.iter() {
@ -548,8 +679,17 @@ pub fn squad_processor(file_path: PathBuf) -> Vec<QaInput> {
let context = paragraph.get("context").unwrap().as_str().unwrap(); let context = paragraph.get("context").unwrap().as_str().unwrap();
let qas = paragraph.get("qas").unwrap().as_array().unwrap(); let qas = paragraph.get("qas").unwrap().as_array().unwrap();
for qa in qas.iter() { for qa in qas.iter() {
let question = qa.as_object().unwrap().get("question").unwrap().as_str().unwrap(); let question = qa
qa_inputs.push(QaInput { question: question.to_owned(), context: context.to_owned() }); .as_object()
.unwrap()
.get("question")
.unwrap()
.as_str()
.unwrap();
qa_inputs.push(QaInput {
question: question.to_owned(),
context: context.to_owned(),
});
} }
} }
} }
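As a usage note (not part of this diff): `squad_processor` can be pointed at a local SQuAD-format JSON file. The import path is assumed from the module layout, and the file name below is a placeholder.

```rust
use rust_bert::pipelines::question_answering::squad_processor;
use std::path::PathBuf;

fn main() {
    // Hypothetical local copy of a SQuAD-style dataset; the file name is illustrative.
    let inputs = squad_processor(PathBuf::from("dev-v2.0.json"));
    println!("Loaded {} question/context pairs", inputs.len());
}
```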
@ -19,7 +19,7 @@
//! ```no_run //! ```no_run
//! use rust_bert::pipelines::sentiment::SentimentModel; //! use rust_bert::pipelines::sentiment::SentimentModel;
//! //!
//!# fn main() -> failure::Fallible<()> { //! # fn main() -> failure::Fallible<()> {
//! let sentiment_classifier = SentimentModel::new(Default::default())?; //! let sentiment_classifier = SentimentModel::new(Default::default())?;
//! let input = [ //! let input = [
//! "Probably my all-time favorite movie, a story of selflessness, sacrifice and dedication to a noble cause, but it's not preachy or boring.", //! "Probably my all-time favorite movie, a story of selflessness, sacrifice and dedication to a noble cause, but it's not preachy or boring.",
@ -27,29 +27,40 @@
//! "If you like original gut wrenching laughter you will like this movie. If you are young or old then you will love this movie, hell even my mom liked it.", //! "If you like original gut wrenching laughter you will like this movie. If you are young or old then you will love this movie, hell even my mom liked it.",
//! ]; //! ];
//! let output = sentiment_classifier.predict(&input); //! let output = sentiment_classifier.predict(&input);
//!# Ok(()) //! # Ok(())
//!# } //! # }
//! ``` //! ```
//! (Example courtesy of [IMDb](http://www.imdb.com)) //! (Example courtesy of [IMDb](http://www.imdb.com))
//! //!
//! Output: \ //! Output: \
//! ```no_run //! ```no_run
//!# use rust_bert::pipelines::sentiment::Sentiment; //! # use rust_bert::pipelines::sentiment::Sentiment;
//!# use rust_bert::pipelines::sentiment::SentimentPolarity::{Positive, Negative}; //! # use rust_bert::pipelines::sentiment::SentimentPolarity::{Positive, Negative};
//!# let output = //! # let output =
//! [ //! [
//! Sentiment { polarity: Positive, score: 0.998 }, //! Sentiment {
//! Sentiment { polarity: Negative, score: 0.992 }, //! polarity: Positive,
//! Sentiment { polarity: Positive, score: 0.999 } //! score: 0.998,
//! },
//! Sentiment {
//! polarity: Negative,
//! score: 0.992,
//! },
//! Sentiment {
//! polarity: Positive,
//! score: 0.999,
//! },
//! ] //! ]
//!# ; //! # ;
//! ``` //! ```
use std::path::PathBuf; use crate::pipelines::sequence_classification::{
use std::fs; SequenceClassificationConfig, SequenceClassificationModel,
};
use serde::Deserialize; use serde::Deserialize;
use std::error::Error; use std::error::Error;
use crate::pipelines::sequence_classification::{SequenceClassificationConfig, SequenceClassificationModel}; use std::fs;
use std::path::PathBuf;
#[derive(Debug, PartialEq)] #[derive(Debug, PartialEq)]
/// Enum with the possible sentiment polarities. Note that the pre-trained SST2 model does not include neutral sentiment. /// Enum with the possible sentiment polarities. Note that the pre-trained SST2 model does not include neutral sentiment.
@ -71,7 +82,7 @@ type SentimentConfig = SequenceClassificationConfig;
/// # SentimentClassifier to perform sentiment analysis /// # SentimentClassifier to perform sentiment analysis
pub struct SentimentModel { pub struct SentimentModel {
sequence_classification_model: SequenceClassificationModel sequence_classification_model: SequenceClassificationModel,
} }
impl SentimentModel { impl SentimentModel {
@ -84,17 +95,18 @@ impl SentimentModel {
/// # Example /// # Example
/// ///
/// ```no_run /// ```no_run
///# fn main() -> failure::Fallible<()> { /// # fn main() -> failure::Fallible<()> {
/// use rust_bert::pipelines::sentiment::SentimentModel; /// use rust_bert::pipelines::sentiment::SentimentModel;
/// ///
/// let sentiment_model = SentimentModel::new(Default::default())?; /// let sentiment_model = SentimentModel::new(Default::default())?;
///# Ok(()) /// # Ok(())
///# } /// # }
/// ``` /// ```
///
pub fn new(sentiment_config: SentimentConfig) -> failure::Fallible<SentimentModel> { pub fn new(sentiment_config: SentimentConfig) -> failure::Fallible<SentimentModel> {
let sequence_classification_model = SequenceClassificationModel::new(sentiment_config)?; let sequence_classification_model = SequenceClassificationModel::new(sentiment_config)?;
Ok(SentimentModel { sequence_classification_model }) Ok(SentimentModel {
sequence_classification_model,
})
} }
/// Extract sentiment from an array of text inputs /// Extract sentiment from an array of text inputs
@ -109,7 +121,7 @@ impl SentimentModel {
/// # Example /// # Example
/// ///
/// ```no_run /// ```no_run
///# fn main() -> failure::Fallible<()> { /// # fn main() -> failure::Fallible<()> {
/// use rust_bert::pipelines::sentiment::SentimentModel; /// use rust_bert::pipelines::sentiment::SentimentModel;
/// ///
/// let sentiment_classifier = SentimentModel::new(Default::default())?; /// let sentiment_classifier = SentimentModel::new(Default::default())?;
@ -121,17 +133,23 @@ impl SentimentModel {
/// ]; /// ];
/// ///
/// let output = sentiment_classifier.predict(&input); /// let output = sentiment_classifier.predict(&input);
///# Ok(()) /// # Ok(())
///# } /// # }
/// ``` /// ```
///
pub fn predict(&self, input: &[&str]) -> Vec<Sentiment> { pub fn predict(&self, input: &[&str]) -> Vec<Sentiment> {
let labels = self.sequence_classification_model.predict(input); let labels = self.sequence_classification_model.predict(input);
let mut sentiments = Vec::with_capacity(labels.len()); let mut sentiments = Vec::with_capacity(labels.len());
for label in labels { for label in labels {
let polarity = if label.id == 1 { SentimentPolarity::Positive } else { SentimentPolarity::Negative }; let polarity = if label.id == 1 {
sentiments.push(Sentiment { polarity, score: label.score }) SentimentPolarity::Positive
}; } else {
SentimentPolarity::Negative
};
sentiments.push(Sentiment {
polarity,
score: label.score,
})
}
sentiments sentiments
} }
} }
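Since `predict` maps label id 1 to `Positive` and everything else to `Negative`, downstream code can match on the polarity directly. A small consumer sketch, grounded in the `Sentiment` fields shown in the module documentation above (the input sentence is illustrative):

```rust
use rust_bert::pipelines::sentiment::{SentimentModel, SentimentPolarity};

fn main() -> failure::Fallible<()> {
    let model = SentimentModel::new(Default::default())?;
    for sentiment in model.predict(&["A thoroughly enjoyable film."]) {
        let tag = match sentiment.polarity {
            SentimentPolarity::Positive => "positive",
            SentimentPolarity::Negative => "negative",
        };
        println!("{} (score {:.3})", tag, sentiment.score);
    }
    Ok(())
}
```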
@ -154,4 +172,4 @@ pub fn ss2_processor(file_path: PathBuf) -> Result<Vec<String>, Box<dyn Error>>
records.push(record.sentence); records.push(record.sentence);
} }
Ok(records) Ok(records)
} }
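`ss2_processor` reads an SST-2 style file and returns the raw sentences. A hedged sketch of calling it (the file name is a placeholder; the function's public path is assumed from the module layout):

```rust
use rust_bert::pipelines::sentiment::ss2_processor;
use std::path::PathBuf;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Hypothetical SST-2 TSV file with a `sentence` column.
    let sentences = ss2_processor(PathBuf::from("train.tsv"))?;
    println!("{} sentences loaded", sentences.len());
    Ok(())
}
```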
@ -19,7 +19,7 @@
//! use rust_bert::distilbert::{DistilBertModelResources, DistilBertVocabResources, DistilBertConfigResources}; //! use rust_bert::distilbert::{DistilBertModelResources, DistilBertVocabResources, DistilBertConfigResources};
//! use rust_bert::pipelines::sequence_classification::SequenceClassificationModel; //! use rust_bert::pipelines::sequence_classification::SequenceClassificationModel;
//! use rust_bert::pipelines::common::ModelType; //! use rust_bert::pipelines::common::ModelType;
//!# fn main() -> failure::Fallible<()> { //! # fn main() -> failure::Fallible<()> {
//! //!
//! //Load a configuration //! //Load a configuration
//! let config = SequenceClassificationConfig::new(ModelType::DistilBert, //! let config = SequenceClassificationConfig::new(ModelType::DistilBert,
@ -39,34 +39,38 @@
//! "If you like original gut wrenching laughter you will like this movie. If you are young or old then you will love this movie, hell even my mom liked it.", //! "If you like original gut wrenching laughter you will like this movie. If you are young or old then you will love this movie, hell even my mom liked it.",
//! ]; //! ];
//! let output = sequence_classification_model.predict(&input); //! let output = sequence_classification_model.predict(&input);
//!# Ok(()) //! # Ok(())
//!# } //! # }
//! ``` //! ```
//! (Example courtesy of [IMDb](http://www.imdb.com)) //! (Example courtesy of [IMDb](http://www.imdb.com))
//! //!
//! Output: \ //! Output: \
//! ```no_run //! ```no_run
//!# use rust_bert::pipelines::sequence_classification::Label; //! # use rust_bert::pipelines::sequence_classification::Label;
//! let output = //! let output =
//! [ //! [
//! Label { text: String::from("POSITIVE"), score: 0.9986, id: 1, sentence: 0}, //! Label { text: String::from("POSITIVE"), score: 0.9986, id: 1, sentence: 0},
//! Label { text: String::from("NEGATIVE"), score: 0.9985, id: 0, sentence: 1}, //! Label { text: String::from("NEGATIVE"), score: 0.9985, id: 0, sentence: 1},
//! Label { text: String::from("POSITIVE"), score: 0.9988, id: 1, sentence: 12}, //! Label { text: String::from("POSITIVE"), score: 0.9988, id: 1, sentence: 12},
//! ] //! ]
//!# ; //! # ;
//! ``` //! ```
//! //!
use tch::nn::VarStore;
use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::{TokenizedInput, TruncationStrategy};
use std::collections::HashMap;
use tch::{Tensor, no_grad, Device, Kind};
use crate::bert::BertForSequenceClassification; use crate::bert::BertForSequenceClassification;
use crate::common::resources::{download_resource, RemoteResource, Resource};
use crate::distilbert::{
DistilBertConfigResources, DistilBertModelClassifier, DistilBertModelResources,
DistilBertVocabResources,
};
use crate::pipelines::common::{ConfigOption, ModelType, TokenizerOption};
use crate::roberta::RobertaForSequenceClassification; use crate::roberta::RobertaForSequenceClassification;
use crate::distilbert::{DistilBertModelResources, DistilBertConfigResources, DistilBertVocabResources, DistilBertModelClassifier}; use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::{
use crate::common::resources::{Resource, RemoteResource, download_resource}; TokenizedInput, TruncationStrategy,
use serde::{Serialize, Deserialize}; };
use crate::pipelines::common::{ModelType, ConfigOption, TokenizerOption}; use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use tch::nn::VarStore;
use tch::{no_grad, Device, Kind, Tensor};
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
/// # Label generated by a `SequenceClassificationModel` /// # Label generated by a `SequenceClassificationModel`
@ -112,8 +116,14 @@ impl SequenceClassificationConfig {
/// * vocab - The `Resource` pointing to the tokenizer's vocabulary to load (e.g. vocab.txt/vocab.json) /// * vocab - The `Resource` pointing to the tokenizer's vocabulary to load (e.g. vocab.txt/vocab.json)
/// * merges - An optional `Resource` tuple (`Option<Resource>`) pointing to the tokenizer's merge file to load (e.g. merges.txt), needed only for Roberta. /// * merges - An optional `Resource` tuple (`Option<Resource>`) pointing to the tokenizer's merge file to load (e.g. merges.txt), needed only for Roberta.
/// * lower_case - A `bool` indicating whether the tokenizer should lower case all input (in case of a lower-cased model) /// * lower_case - A `bool` indicating whether the tokenizer should lower case all input (in case of a lower-cased model)
/// pub fn new(
pub fn new(model_type: ModelType, model_resource: Resource, config_resource: Resource, vocab_resource: Resource, merges_resource: Option<Resource>, lower_case: bool) -> SequenceClassificationConfig { model_type: ModelType,
model_resource: Resource,
config_resource: Resource,
vocab_resource: Resource,
merges_resource: Option<Resource>,
lower_case: bool,
) -> SequenceClassificationConfig {
SequenceClassificationConfig { SequenceClassificationConfig {
model_type, model_type,
model_resource, model_resource,
@ -131,9 +141,15 @@ impl Default for SequenceClassificationConfig {
fn default() -> SequenceClassificationConfig { fn default() -> SequenceClassificationConfig {
SequenceClassificationConfig { SequenceClassificationConfig {
model_type: ModelType::DistilBert, model_type: ModelType::DistilBert,
model_resource: Resource::Remote(RemoteResource::from_pretrained(DistilBertModelResources::DISTIL_BERT_SST2)), model_resource: Resource::Remote(RemoteResource::from_pretrained(
config_resource: Resource::Remote(RemoteResource::from_pretrained(DistilBertConfigResources::DISTIL_BERT_SST2)), DistilBertModelResources::DISTIL_BERT_SST2,
vocab_resource: Resource::Remote(RemoteResource::from_pretrained(DistilBertVocabResources::DISTIL_BERT_SST2)), )),
config_resource: Resource::Remote(RemoteResource::from_pretrained(
DistilBertConfigResources::DISTIL_BERT_SST2,
)),
vocab_resource: Resource::Remote(RemoteResource::from_pretrained(
DistilBertVocabResources::DISTIL_BERT_SST2,
)),
merges_resource: None, merges_resource: None,
lower_case: true, lower_case: true,
device: Device::cuda_if_available(), device: Device::cuda_if_available(),
@ -160,31 +176,38 @@ impl SequenceClassificationOption {
/// * `p` - `tch::nn::Path` path to the model file to load (e.g. model.ot) /// * `p` - `tch::nn::Path` path to the model file to load (e.g. model.ot)
/// * `config` - A configuration (the model type of the configuration must be compatible with the value for /// * `config` - A configuration (the model type of the configuration must be compatible with the value for
/// `model_type`) /// `model_type`)
///
pub fn new(model_type: ModelType, p: &tch::nn::Path, config: &ConfigOption) -> Self { pub fn new(model_type: ModelType, p: &tch::nn::Path, config: &ConfigOption) -> Self {
match model_type { match model_type {
ModelType::Bert => { ModelType::Bert => {
if let ConfigOption::Bert(config) = config { if let ConfigOption::Bert(config) = config {
SequenceClassificationOption::Bert(BertForSequenceClassification::new(p, config)) SequenceClassificationOption::Bert(BertForSequenceClassification::new(
p, config,
))
} else { } else {
panic!("You can only supply a BertConfig for Bert!"); panic!("You can only supply a BertConfig for Bert!");
} }
} }
ModelType::DistilBert => { ModelType::DistilBert => {
if let ConfigOption::DistilBert(config) = config { if let ConfigOption::DistilBert(config) = config {
SequenceClassificationOption::DistilBert(DistilBertModelClassifier::new(p, config)) SequenceClassificationOption::DistilBert(DistilBertModelClassifier::new(
p, config,
))
} else { } else {
panic!("You can only supply a DistilBertConfig for DistilBert!"); panic!("You can only supply a DistilBertConfig for DistilBert!");
} }
} }
ModelType::Roberta => { ModelType::Roberta => {
if let ConfigOption::Bert(config) = config { if let ConfigOption::Bert(config) = config {
SequenceClassificationOption::Roberta(RobertaForSequenceClassification::new(p, config)) SequenceClassificationOption::Roberta(RobertaForSequenceClassification::new(
p, config,
))
} else { } else {
panic!("You can only supply a BertConfig for Roberta!"); panic!("You can only supply a BertConfig for Roberta!");
} }
} }
ModelType::Electra => { panic!("SequenceClassification not implemented for Electra!"); } ModelType::Electra => {
panic!("SequenceClassification not implemented for Electra!");
}
} }
} }
@ -193,27 +216,44 @@ impl SequenceClassificationOption {
match *self { match *self {
Self::Bert(_) => ModelType::Bert, Self::Bert(_) => ModelType::Bert,
Self::Roberta(_) => ModelType::Roberta, Self::Roberta(_) => ModelType::Roberta,
Self::DistilBert(_) => ModelType::DistilBert Self::DistilBert(_) => ModelType::DistilBert,
} }
} }
/// Interface method to forward_t() of the particular models. /// Interface method to forward_t() of the particular models.
pub fn forward_t(&self, pub fn forward_t(
input_ids: Option<Tensor>, &self,
mask: Option<Tensor>, input_ids: Option<Tensor>,
token_type_ids: Option<Tensor>, mask: Option<Tensor>,
position_ids: Option<Tensor>, token_type_ids: Option<Tensor>,
input_embeds: Option<Tensor>, position_ids: Option<Tensor>,
train: bool) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) { input_embeds: Option<Tensor>,
train: bool,
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
match *self { match *self {
Self::Bert(ref model) => model.forward_t(input_ids, mask, token_type_ids, position_ids, input_embeds, train), Self::Bert(ref model) => model.forward_t(
Self::DistilBert(ref model) => model.forward_t(input_ids, mask, input_embeds, train).expect("Error in distilbert forward_t"), input_ids,
Self::Roberta(ref model) => model.forward_t(input_ids, mask, token_type_ids, position_ids, input_embeds, train), mask,
token_type_ids,
position_ids,
input_embeds,
train,
),
Self::DistilBert(ref model) => model
.forward_t(input_ids, mask, input_embeds, train)
.expect("Error in distilbert forward_t"),
Self::Roberta(ref model) => model.forward_t(
input_ids,
mask,
token_type_ids,
position_ids,
input_embeds,
train,
),
} }
} }
} }
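The `forward_t` interface above lets callers stay agnostic of which backbone was loaded. An illustrative helper showing the dispatch in use (the function is a sketch, not part of the crate):

```rust
use tch::Tensor;

// Illustrative helper (not part of the crate): run a forward pass through
// whichever backbone variant was loaded. Note the dispatch above drops the
// token type and position ids for DistilBert, which does not use them.
fn classification_logits(
    classifier: &SequenceClassificationOption,
    input_ids: &Tensor,
) -> Tensor {
    let (output, _, _) = classifier.forward_t(
        Some(input_ids.copy()), // input ids
        None,                   // attention mask
        None,                   // token type ids
        None,                   // position ids
        None,                   // precomputed embeddings
        false,                  // inference mode
    );
    output
}
```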
/// # SequenceClassificationModel for Classification (e.g. Sentiment Analysis) /// # SequenceClassificationModel for Classification (e.g. Sentiment Analysis)
pub struct SequenceClassificationModel { pub struct SequenceClassificationModel {
tokenizer: TokenizerOption, tokenizer: TokenizerOption,
@ -232,15 +272,16 @@ impl SequenceClassificationModel {
/// # Example /// # Example
/// ///
/// ```no_run /// ```no_run
///# fn main() -> failure::Fallible<()> { /// # fn main() -> failure::Fallible<()> {
/// use rust_bert::pipelines::sequence_classification::SequenceClassificationModel; /// use rust_bert::pipelines::sequence_classification::SequenceClassificationModel;
/// ///
/// let model = SequenceClassificationModel::new(Default::default())?; /// let model = SequenceClassificationModel::new(Default::default())?;
///# Ok(()) /// # Ok(())
///# } /// # }
/// ``` /// ```
/// pub fn new(
pub fn new(config: SequenceClassificationConfig) -> failure::Fallible<SequenceClassificationModel> { config: SequenceClassificationConfig,
) -> failure::Fallible<SequenceClassificationModel> {
let config_path = download_resource(&config.config_resource)?; let config_path = download_resource(&config.config_resource)?;
let vocab_path = download_resource(&config.vocab_resource)?; let vocab_path = download_resource(&config.vocab_resource)?;
let weights_path = download_resource(&config.model_resource)?; let weights_path = download_resource(&config.model_resource)?;
@ -251,31 +292,44 @@ impl SequenceClassificationModel {
}; };
let device = config.device; let device = config.device;
let tokenizer = TokenizerOption::from_file(config.model_type, vocab_path.to_str().unwrap(), merges_path.map(|path| path.to_str().unwrap()), config.lower_case); let tokenizer = TokenizerOption::from_file(
config.model_type,
vocab_path.to_str().unwrap(),
merges_path.map(|path| path.to_str().unwrap()),
config.lower_case,
);
let mut var_store = VarStore::new(device); let mut var_store = VarStore::new(device);
let model_config = ConfigOption::from_file(config.model_type, config_path); let model_config = ConfigOption::from_file(config.model_type, config_path);
let sequence_classifier = SequenceClassificationOption::new(config.model_type, &var_store.root(), &model_config); let sequence_classifier =
SequenceClassificationOption::new(config.model_type, &var_store.root(), &model_config);
let label_mapping = model_config.get_label_mapping(); let label_mapping = model_config.get_label_mapping();
var_store.load(weights_path)?; var_store.load(weights_path)?;
Ok(SequenceClassificationModel { tokenizer, sequence_classifier, label_mapping, var_store }) Ok(SequenceClassificationModel {
tokenizer,
sequence_classifier,
label_mapping,
var_store,
})
} }
fn prepare_for_model(&self, input: Vec<&str>) -> Tensor { fn prepare_for_model(&self, input: Vec<&str>) -> Tensor {
let tokenized_input: Vec<TokenizedInput> = self.tokenizer.encode_list(input.to_vec(), let tokenized_input: Vec<TokenizedInput> =
128, self.tokenizer
&TruncationStrategy::LongestFirst, .encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
0); let max_len = tokenized_input
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap(); .iter()
let tokenized_input_tensors: Vec<tch::Tensor> = tokenized_input. .map(|input| input.token_ids.len())
iter(). .max()
map(|input| input.token_ids.clone()). .unwrap();
map(|mut input| { let tokenized_input_tensors: Vec<tch::Tensor> = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]); input.extend(vec![0; max_len - input.len()]);
input input
}). })
map(|input| .map(|input| Tensor::of_slice(&(input)))
Tensor::of_slice(&(input))). .collect::<Vec<_>>();
collect::<Vec<_>>();
Tensor::stack(tokenized_input_tensors.as_slice(), 0).to(self.var_store.device()) Tensor::stack(tokenized_input_tensors.as_slice(), 0).to(self.var_store.device())
} }
@ -292,8 +346,8 @@ impl SequenceClassificationModel {
/// # Example /// # Example
/// ///
/// ```no_run /// ```no_run
///# fn main() -> failure::Fallible<()> { /// # fn main() -> failure::Fallible<()> {
///# use rust_bert::pipelines::sequence_classification::SequenceClassificationModel; /// # use rust_bert::pipelines::sequence_classification::SequenceClassificationModel;
/// ///
/// let sequence_classification_model = SequenceClassificationModel::new(Default::default())?; /// let sequence_classification_model = SequenceClassificationModel::new(Default::default())?;
/// let input = [ /// let input = [
@ -302,29 +356,36 @@ impl SequenceClassificationModel {
/// "If you like original gut wrenching laughter you will like this movie. If you are young or old then you will love this movie, hell even my mom liked it.", /// "If you like original gut wrenching laughter you will like this movie. If you are young or old then you will love this movie, hell even my mom liked it.",
/// ]; /// ];
/// let output = sequence_classification_model.predict(&input); /// let output = sequence_classification_model.predict(&input);
///# Ok(()) /// # Ok(())
///# } /// # }
/// ``` /// ```
pub fn predict(&self, input: &[&str]) -> Vec<Label> { pub fn predict(&self, input: &[&str]) -> Vec<Label> {
let input_tensor = self.prepare_for_model(input.to_vec()); let input_tensor = self.prepare_for_model(input.to_vec());
let output = no_grad(|| { let output = no_grad(|| {
let (output, _, _) = self.sequence_classifier let (output, _, _) = self.sequence_classifier.forward_t(
.forward_t(Some(input_tensor.copy()), Some(input_tensor.copy()),
None, None,
None, None,
None, None,
None, None,
false); false,
);
output.softmax(-1, Kind::Float).detach().to(Device::Cpu) output.softmax(-1, Kind::Float).detach().to(Device::Cpu)
}); });
let label_indices = output.as_ref().argmax(-1, true).squeeze1(1); let label_indices = output.as_ref().argmax(-1, true).squeeze1(1);
let scores = output.gather(1, &label_indices.unsqueeze(-1), false).squeeze1(1); let scores = output
.gather(1, &label_indices.unsqueeze(-1), false)
.squeeze1(1);
let label_indices = label_indices.iter::<i64>().unwrap().collect::<Vec<i64>>(); let label_indices = label_indices.iter::<i64>().unwrap().collect::<Vec<i64>>();
let scores = scores.iter::<f64>().unwrap().collect::<Vec<f64>>(); let scores = scores.iter::<f64>().unwrap().collect::<Vec<f64>>();
let mut labels: Vec<Label> = vec!(); let mut labels: Vec<Label> = vec![];
for sentence_idx in 0..label_indices.len() { for sentence_idx in 0..label_indices.len() {
let label_string = self.label_mapping.get(&label_indices[sentence_idx]).unwrap().clone(); let label_string = self
.label_mapping
.get(&label_indices[sentence_idx])
.unwrap()
.clone();
let label = Label { let label = Label {
text: label_string, text: label_string,
score: scores[sentence_idx], score: scores[sentence_idx],
@ -350,8 +411,8 @@ impl SequenceClassificationModel {
/// # Example /// # Example
/// ///
/// ```no_run /// ```no_run
///# fn main() -> failure::Fallible<()> { /// # fn main() -> failure::Fallible<()> {
///# use rust_bert::pipelines::sequence_classification::SequenceClassificationModel; /// # use rust_bert::pipelines::sequence_classification::SequenceClassificationModel;
/// ///
/// let sequence_classification_model = SequenceClassificationModel::new(Default::default())?; /// let sequence_classification_model = SequenceClassificationModel::new(Default::default())?;
/// let input = [ /// let input = [
@ -360,34 +421,37 @@ impl SequenceClassificationModel {
/// "If you like original gut wrenching laughter you will like this movie. If you are young or old then you will love this movie, hell even my mom liked it.", /// "If you like original gut wrenching laughter you will like this movie. If you are young or old then you will love this movie, hell even my mom liked it.",
/// ]; /// ];
/// let output = sequence_classification_model.predict_multilabel(&input, 0.5); /// let output = sequence_classification_model.predict_multilabel(&input, 0.5);
///# Ok(()) /// # Ok(())
///# } /// # }
/// ``` /// ```
pub fn predict_multilabel(&self, input: &[&str], threshold: f64) -> Vec<Vec<Label>> { pub fn predict_multilabel(&self, input: &[&str], threshold: f64) -> Vec<Vec<Label>> {
let input_tensor = self.prepare_for_model(input.to_vec()); let input_tensor = self.prepare_for_model(input.to_vec());
let output = no_grad(|| { let output = no_grad(|| {
let (output, _, _) = self.sequence_classifier let (output, _, _) = self.sequence_classifier.forward_t(
.forward_t(Some(input_tensor.copy()), Some(input_tensor.copy()),
None, None,
None, None,
None, None,
None, None,
false); false,
);
output.sigmoid().detach().to(Device::Cpu) output.sigmoid().detach().to(Device::Cpu)
}); });
let label_indices = output.as_ref().ge(threshold).nonzero(); let label_indices = output.as_ref().ge(threshold).nonzero();
let mut labels: Vec<Vec<Label>> = vec!(); let mut labels: Vec<Vec<Label>> = vec![];
let mut sequence_labels: Vec<Label> = vec!(); let mut sequence_labels: Vec<Label> = vec![];
for sentence_idx in 0..label_indices.size()[0] { for sentence_idx in 0..label_indices.size()[0] {
let label_index_tensor = label_indices.get(sentence_idx); let label_index_tensor = label_indices.get(sentence_idx);
let sentence_label = label_index_tensor.iter::<i64>().unwrap().collect::<Vec<i64>>(); let sentence_label = label_index_tensor
.iter::<i64>()
.unwrap()
.collect::<Vec<i64>>();
let (sentence, id) = (sentence_label[0], sentence_label[1]); let (sentence, id) = (sentence_label[0], sentence_label[1]);
if sentence as usize > labels.len() { if sentence as usize > labels.len() {
labels.push(sequence_labels); labels.push(sequence_labels);
sequence_labels = vec!(); sequence_labels = vec![];
} }
let score = output.double_value(sentence_label.as_slice()); let score = output.double_value(sentence_label.as_slice());
let label_string = self.label_mapping.get(&id).unwrap().to_owned(); let label_string = self.label_mapping.get(&id).unwrap().to_owned();
@ -398,7 +462,6 @@ impl SequenceClassificationModel {
sentence: sentence as usize, sentence: sentence as usize,
}; };
sequence_labels.push(label); sequence_labels.push(label);
} }
if sequence_labels.len() > 0 { if sequence_labels.len() > 0 {
labels.push(sequence_labels); labels.push(sequence_labels);
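To round off the multi-label path: each inner `Vec<Label>` holds every label whose sigmoid score clears the threshold for one input sentence. A hedged end-to-end sketch using the default configuration (input text is illustrative):

```rust
use rust_bert::pipelines::sequence_classification::SequenceClassificationModel;

fn main() -> failure::Fallible<()> {
    let model = SequenceClassificationModel::new(Default::default())?;
    let input = ["Probably my all-time favorite movie."];
    // One Vec<Label> per input sentence; every label whose sigmoid score
    // reaches the 0.5 threshold is kept.
    for labels in model.predict_multilabel(&input, 0.5) {
        for label in labels {
            println!("{}: {:.4}", label.text, label.score);
        }
    }
    Ok(())
}
```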
@ -11,7 +11,6 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
//! # Summarization pipeline //! # Summarization pipeline
//! Abstractive summarization of texts based on the BART encoder-decoder architecture. //! Abstractive summarization of texts based on the BART encoder-decoder architecture.
//! Includes techniques such as beam search, top-k and nucleus sampling, temperature setting and repetition penalty. //! Includes techniques such as beam search, top-k and nucleus sampling, temperature setting and repetition penalty.
@ -21,52 +20,54 @@
//! //!
//! //!
//! ```no_run //! ```no_run
//!# fn main() -> failure::Fallible<()> { //! # fn main() -> failure::Fallible<()> {
//!# use rust_bert::pipelines::generation::LanguageGenerator; //! # use rust_bert::pipelines::generation::LanguageGenerator;
//! use rust_bert::pipelines::summarization::SummarizationModel; //! use rust_bert::pipelines::summarization::SummarizationModel;
//! let mut model = SummarizationModel::new(Default::default())?; //! let mut model = SummarizationModel::new(Default::default())?;
//! //!
//! let input = ["In findings published Tuesday in Cornell University's arXiv by a team of scientists //! let input = ["In findings published Tuesday in Cornell University's arXiv by a team of scientists
//!from the University of Montreal and a separate report published Wednesday in Nature Astronomy by a team //! from the University of Montreal and a separate report published Wednesday in Nature Astronomy by a team
//!from University College London (UCL), the presence of water vapour was confirmed in the atmosphere of K2-18b, //! from University College London (UCL), the presence of water vapour was confirmed in the atmosphere of K2-18b,
//!a planet circling a star in the constellation Leo. This is the first such discovery in a planet in its star's //! a planet circling a star in the constellation Leo. This is the first such discovery in a planet in its star's
//!habitable zone — not too hot and not too cold for liquid water to exist. The Montreal team, led by Björn Benneke, //! habitable zone — not too hot and not too cold for liquid water to exist. The Montreal team, led by Björn Benneke,
//!used data from the NASA's Hubble telescope to assess changes in the light coming from K2-18b's star as the planet //! used data from the NASA's Hubble telescope to assess changes in the light coming from K2-18b's star as the planet
//!passed between it and Earth. They found that certain wavelengths of light, which are usually absorbed by water, //! passed between it and Earth. They found that certain wavelengths of light, which are usually absorbed by water,
//!weakened when the planet was in the way, indicating not only does K2-18b have an atmosphere, but the atmosphere //! weakened when the planet was in the way, indicating not only does K2-18b have an atmosphere, but the atmosphere
//!contains water in vapour form. The team from UCL then analyzed the Montreal team's data using their own software //! contains water in vapour form. The team from UCL then analyzed the Montreal team's data using their own software
//!and confirmed their conclusion. This was not the first time scientists have found signs of water on an exoplanet, //! and confirmed their conclusion. This was not the first time scientists have found signs of water on an exoplanet,
//!but previous discoveries were made on planets with high temperatures or other pronounced differences from Earth. //! but previous discoveries were made on planets with high temperatures or other pronounced differences from Earth.
//!\"This is the first potentially habitable planet where the temperature is right and where we now know there is water,\" //! \"This is the first potentially habitable planet where the temperature is right and where we now know there is water,\"
//!said UCL astronomer Angelos Tsiaras. \"It's the best candidate for habitability right now.\" \"It's a good sign\", //! said UCL astronomer Angelos Tsiaras. \"It's the best candidate for habitability right now.\" \"It's a good sign\",
//!said Ryan Cloutier of the Harvard-Smithsonian Center for Astrophysics, who was not one of either study's authors. //! said Ryan Cloutier of the Harvard-Smithsonian Center for Astrophysics, who was not one of either study's authors.
//!\"Overall,\" he continued, \"the presence of water in its atmosphere certainly improves the prospect of K2-18b being //! \"Overall,\" he continued, \"the presence of water in its atmosphere certainly improves the prospect of K2-18b being
//!a potentially habitable planet, but further observations will be required to say for sure. \" //! a potentially habitable planet, but further observations will be required to say for sure. \"
//!K2-18b was first identified in 2015 by the Kepler space telescope. It is about 110 light-years from Earth and larger //! K2-18b was first identified in 2015 by the Kepler space telescope. It is about 110 light-years from Earth and larger
//!but less dense. Its star, a red dwarf, is cooler than the Sun, but the planet's orbit is much closer, such that a year //! but less dense. Its star, a red dwarf, is cooler than the Sun, but the planet's orbit is much closer, such that a year
//!on K2-18b lasts 33 Earth days. According to The Guardian, astronomers were optimistic that NASA's James Webb space //! on K2-18b lasts 33 Earth days. According to The Guardian, astronomers were optimistic that NASA's James Webb space
//!telescope — scheduled for launch in 2021 — and the European Space Agency's 2028 ARIEL program, could reveal more //! telescope — scheduled for launch in 2021 — and the European Space Agency's 2028 ARIEL program, could reveal more
//!about exoplanets like K2-18b."]; //! about exoplanets like K2-18b."];
//! //!
//! let output = model.summarize(&input); //! let output = model.summarize(&input);
//!# Ok(()) //! # Ok(())
//!# } //! # }
//! ``` //! ```
//! (New sample credits: [WikiNews](https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b)) //! (New sample credits: [WikiNews](https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b))
//! //!
//! Example output: \ //! Example output: \
//! ```no_run //! ```no_run
//!# let output = //! # let output =
//! "Scientists have found water vapour on K2-18b, a planet 110 light-years from Earth. //! "Scientists have found water vapour on K2-18b, a planet 110 light-years from Earth.
//! This is the first such discovery in a planet in its star's habitable zone. //! This is the first such discovery in a planet in its star's habitable zone.
//! The planet is not too hot and not too cold for liquid water to exist." //! The planet is not too hot and not too cold for liquid water to exist."
//!# ; //! # ;
//!``` //! ```
use crate::bart::{
BartConfigResources, BartMergesResources, BartModelResources, BartVocabResources,
};
use crate::common::resources::{RemoteResource, Resource};
use crate::pipelines::generation::{BartGenerator, GenerateConfig, LanguageGenerator}; use crate::pipelines::generation::{BartGenerator, GenerateConfig, LanguageGenerator};
use tch::Device; use tch::Device;
use crate::common::resources::{Resource, RemoteResource};
use crate::bart::{BartModelResources, BartConfigResources, BartVocabResources, BartMergesResources};
/// # Configuration for text summarization /// # Configuration for text summarization
/// Contains information regarding the model to load, mirrors the GenerationConfig, with a /// Contains information regarding the model to load, mirrors the GenerationConfig, with a
@ -111,10 +112,18 @@ pub struct SummarizationConfig {
impl Default for SummarizationConfig { impl Default for SummarizationConfig {
fn default() -> SummarizationConfig { fn default() -> SummarizationConfig {
SummarizationConfig { SummarizationConfig {
model_resource: Resource::Remote(RemoteResource::from_pretrained(BartModelResources::BART_CNN)), model_resource: Resource::Remote(RemoteResource::from_pretrained(
config_resource: Resource::Remote(RemoteResource::from_pretrained(BartConfigResources::BART_CNN)), BartModelResources::BART_CNN,
vocab_resource: Resource::Remote(RemoteResource::from_pretrained(BartVocabResources::BART_CNN)), )),
merges_resource: Resource::Remote(RemoteResource::from_pretrained(BartMergesResources::BART_CNN)), config_resource: Resource::Remote(RemoteResource::from_pretrained(
BartConfigResources::BART_CNN,
)),
vocab_resource: Resource::Remote(RemoteResource::from_pretrained(
BartVocabResources::BART_CNN,
)),
merges_resource: Resource::Remote(RemoteResource::from_pretrained(
BartMergesResources::BART_CNN,
)),
min_length: 56, min_length: 56,
max_length: 142, max_length: 142,
do_sample: false, do_sample: false,
@ -134,7 +143,7 @@ impl Default for SummarizationConfig {
/// # SummarizationModel to perform summarization /// # SummarizationModel to perform summarization
pub struct SummarizationModel { pub struct SummarizationModel {
model: BartGenerator model: BartGenerator,
} }
impl SummarizationModel { impl SummarizationModel {
@ -147,16 +156,14 @@ impl SummarizationModel {
/// # Example /// # Example
/// ///
/// ```no_run /// ```no_run
///# fn main() -> failure::Fallible<()> { /// # fn main() -> failure::Fallible<()> {
/// use rust_bert::pipelines::summarization::SummarizationModel; /// use rust_bert::pipelines::summarization::SummarizationModel;
/// ///
/// let mut summarization_model = SummarizationModel::new(Default::default())?; /// let mut summarization_model = SummarizationModel::new(Default::default())?;
///# Ok(()) /// # Ok(())
///# } /// # }
/// ``` /// ```
/// pub fn new(summarization_config: SummarizationConfig) -> failure::Fallible<SummarizationModel> {
pub fn new(summarization_config: SummarizationConfig)
-> failure::Fallible<SummarizationModel> {
let generate_config = GenerateConfig { let generate_config = GenerateConfig {
model_resource: summarization_config.model_resource, model_resource: summarization_config.model_resource,
config_resource: summarization_config.config_resource, config_resource: summarization_config.config_resource,
@ -194,40 +201,39 @@ impl SummarizationModel {
/// # Example /// # Example
/// ///
/// ```no_run /// ```no_run
///# fn main() -> failure::Fallible<()> { /// # fn main() -> failure::Fallible<()> {
/// use rust_bert::pipelines::generation::LanguageGenerator; /// use rust_bert::pipelines::generation::LanguageGenerator;
/// use rust_bert::pipelines::summarization::SummarizationModel; /// use rust_bert::pipelines::summarization::SummarizationModel;
/// let model = SummarizationModel::new(Default::default())?; /// let model = SummarizationModel::new(Default::default())?;
/// ///
/// let input = ["In findings published Tuesday in Cornell University's arXiv by a team of scientists /// let input = ["In findings published Tuesday in Cornell University's arXiv by a team of scientists
///from the University of Montreal and a separate report published Wednesday in Nature Astronomy by a team /// from the University of Montreal and a separate report published Wednesday in Nature Astronomy by a team
///from University College London (UCL), the presence of water vapour was confirmed in the atmosphere of K2-18b, /// from University College London (UCL), the presence of water vapour was confirmed in the atmosphere of K2-18b,
///a planet circling a star in the constellation Leo. This is the first such discovery in a planet in its star's /// a planet circling a star in the constellation Leo. This is the first such discovery in a planet in its star's
///habitable zone — not too hot and not too cold for liquid water to exist. The Montreal team, led by Björn Benneke, /// habitable zone — not too hot and not too cold for liquid water to exist. The Montreal team, led by Björn Benneke,
///used data from the NASA's Hubble telescope to assess changes in the light coming from K2-18b's star as the planet /// used data from the NASA's Hubble telescope to assess changes in the light coming from K2-18b's star as the planet
///passed between it and Earth. They found that certain wavelengths of light, which are usually absorbed by water, /// passed between it and Earth. They found that certain wavelengths of light, which are usually absorbed by water,
///weakened when the planet was in the way, indicating not only does K2-18b have an atmosphere, but the atmosphere /// weakened when the planet was in the way, indicating not only does K2-18b have an atmosphere, but the atmosphere
///contains water in vapour form. The team from UCL then analyzed the Montreal team's data using their own software /// contains water in vapour form. The team from UCL then analyzed the Montreal team's data using their own software
///and confirmed their conclusion. This was not the first time scientists have found signs of water on an exoplanet, /// and confirmed their conclusion. This was not the first time scientists have found signs of water on an exoplanet,
///but previous discoveries were made on planets with high temperatures or other pronounced differences from Earth. /// but previous discoveries were made on planets with high temperatures or other pronounced differences from Earth.
///\"This is the first potentially habitable planet where the temperature is right and where we now know there is water,\" /// \"This is the first potentially habitable planet where the temperature is right and where we now know there is water,\"
///said UCL astronomer Angelos Tsiaras. \"It's the best candidate for habitability right now.\" \"It's a good sign\", /// said UCL astronomer Angelos Tsiaras. \"It's the best candidate for habitability right now.\" \"It's a good sign\",
///said Ryan Cloutier of the Harvard-Smithsonian Center for Astrophysics, who was not one of either study's authors. /// said Ryan Cloutier of the Harvard-Smithsonian Center for Astrophysics, who was not one of either study's authors.
///\"Overall,\" he continued, \"the presence of water in its atmosphere certainly improves the prospect of K2-18b being /// \"Overall,\" he continued, \"the presence of water in its atmosphere certainly improves the prospect of K2-18b being
///a potentially habitable planet, but further observations will be required to say for sure. \" /// a potentially habitable planet, but further observations will be required to say for sure. \"
///K2-18b was first identified in 2015 by the Kepler space telescope. It is about 110 light-years from Earth and larger /// K2-18b was first identified in 2015 by the Kepler space telescope. It is about 110 light-years from Earth and larger
///but less dense. Its star, a red dwarf, is cooler than the Sun, but the planet's orbit is much closer, such that a year /// but less dense. Its star, a red dwarf, is cooler than the Sun, but the planet's orbit is much closer, such that a year
///on K2-18b lasts 33 Earth days. According to The Guardian, astronomers were optimistic that NASA's James Webb space /// on K2-18b lasts 33 Earth days. According to The Guardian, astronomers were optimistic that NASA's James Webb space
///telescope — scheduled for launch in 2021 — and the European Space Agency's 2028 ARIEL program, could reveal more /// telescope — scheduled for launch in 2021 — and the European Space Agency's 2028 ARIEL program, could reveal more
///about exoplanets like K2-18b."]; /// about exoplanets like K2-18b."];
/// ///
/// let output = model.summarize(&input); /// let output = model.summarize(&input);
///# Ok(()) /// # Ok(())
///# } /// # }
/// ``` /// ```
/// (New sample credits: [WikiNews](https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b)) /// (New sample credits: [WikiNews](https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b))
///
pub fn summarize(&self, texts: &[&str]) -> Vec<String> { pub fn summarize(&self, texts: &[&str]) -> Vec<String> {
self.model.generate(Some(texts.to_vec()), None) self.model.generate(Some(texts.to_vec()), None)
} }
} }
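Because `SummarizationConfig` sets its generation bounds directly in the `Default` implementation above, callers can start from the defaults and override individual fields. A sketch assuming the fields are public, as that implementation suggests (the article text is a placeholder):

```rust
use rust_bert::pipelines::summarization::{SummarizationConfig, SummarizationModel};

fn main() -> failure::Fallible<()> {
    // Start from the BART CNN/DailyMail defaults and tighten the output length.
    let mut config = SummarizationConfig::default();
    config.min_length = 30;
    config.max_length = 100;
    let model = SummarizationModel::new(config)?;
    let summaries = model.summarize(&["A long article to summarize goes here."]);
    println!("{:?}", summaries);
    Ok(())
}
```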
@ -19,7 +19,7 @@
//! use rust_bert::resources::{Resource,RemoteResource}; //! use rust_bert::resources::{Resource,RemoteResource};
//! use rust_bert::bert::{BertModelResources, BertVocabResources, BertConfigResources}; //! use rust_bert::bert::{BertModelResources, BertVocabResources, BertConfigResources};
//! use rust_bert::pipelines::common::ModelType; //! use rust_bert::pipelines::common::ModelType;
//!# fn main() -> failure::Fallible<()> { //! # fn main() -> failure::Fallible<()> {
//! //!
//! //Load a configuration //! //Load a configuration
//! use rust_bert::pipelines::token_classification::LabelAggregationOption; //! use rust_bert::pipelines::token_classification::LabelAggregationOption;
@ -40,41 +40,94 @@
//! "Paris is a city in France." //! "Paris is a city in France."
//! ]; //! ];
//! let output = token_classification_model.predict(&input, true, true); //ignore_first_label = true (only returns the NER parts, ignoring first label O) //! let output = token_classification_model.predict(&input, true, true); //ignore_first_label = true (only returns the NER parts, ignoring first label O)
//!# Ok(()) //! # Ok(())
//!# } //! # }
//! ``` //! ```
//! Output: \ //! Output: \
//! ```no_run //! ```no_run
//!# use rust_bert::pipelines::token_classification::Token; //! # use rust_bert::pipelines::token_classification::Token;
//! use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::Mask::Special; //! use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::Mask::Special;
//! use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::{Offset, Mask}; //! use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::{Mask, Offset};
//!# let output = //! # let output =
//! [ //! [
//! Token { text: String::from("[CLS]"), score: 0.9995001554489136, label: String::from("O"), label_index: 0, sentence: 0, index: 0, word_index: 0, offset: None, mask: Special }, //! Token {
//! Token { text: String::from("My"), score: 0.9980450868606567, label: String::from("O"), label_index: 0, sentence: 0, index: 1, word_index: 1, offset: Some(Offset { begin: 0, end: 2 }), mask: Mask::None }, //! text: String::from("[CLS]"),
//! Token { text: String::from("name"), score: 0.9995062351226807, label: String::from("O"), label_index: 0, sentence: 0, index: 2, word_index: 2, offset: Some(Offset { begin: 3, end: 7 }), mask: Mask::None }, //! score: 0.9995001554489136,
//! Token { text: String::from("is"), score: 0.9997343420982361, label: String::from("O"), label_index: 0, sentence: 0, index: 3, word_index: 3, offset: Some(Offset { begin: 8, end: 10 }), mask: Mask::None }, //! label: String::from("O"),
//! Token { text: String::from("Amélie"), score: 0.9913727683112525, label: String::from("I-PER"), label_index: 4, sentence: 0, index: 4, word_index: 4, offset: Some(Offset { begin: 11, end: 17 }), mask: Mask::None } //! label_index: 0,
//! // ... //! sentence: 0,
//! index: 0,
//! word_index: 0,
//! offset: None,
//! mask: Special,
//! },
//! Token {
//! text: String::from("My"),
//! score: 0.9980450868606567,
//! label: String::from("O"),
//! label_index: 0,
//! sentence: 0,
//! index: 1,
//! word_index: 1,
//! offset: Some(Offset { begin: 0, end: 2 }),
//! mask: Mask::None,
//! },
//! Token {
//! text: String::from("name"),
//! score: 0.9995062351226807,
//! label: String::from("O"),
//! label_index: 0,
//! sentence: 0,
//! index: 2,
//! word_index: 2,
//! offset: Some(Offset { begin: 3, end: 7 }),
//! mask: Mask::None,
//! },
//! Token {
//! text: String::from("is"),
//! score: 0.9997343420982361,
//! label: String::from("O"),
//! label_index: 0,
//! sentence: 0,
//! index: 3,
//! word_index: 3,
//! offset: Some(Offset { begin: 8, end: 10 }),
//! mask: Mask::None,
//! },
//! Token {
//! text: String::from("Amélie"),
//! score: 0.9913727683112525,
//! label: String::from("I-PER"),
//! label_index: 4,
//! sentence: 0,
//! index: 4,
//! word_index: 4,
//! offset: Some(Offset { begin: 11, end: 17 }),
//! mask: Mask::None,
//! }, // ...
//! ] //! ]
//!# ; //! # ;
//! ``` //! ```
use crate::bert::{
    BertConfigResources, BertForTokenClassification, BertModelResources, BertVocabResources,
};
use crate::common::resources::{download_resource, RemoteResource, Resource};
use crate::distilbert::DistilBertForTokenClassification;
use crate::electra::ElectraForTokenClassification;
use crate::pipelines::common::{ConfigOption, ModelType, TokenizerOption};
use crate::roberta::RobertaForTokenClassification;
use itertools::Itertools;
use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::{
    ConsolidatableTokens, ConsolidatedTokenIterator, Mask, Offset, TokenTrait, TokenizedInput,
    Tokenizer, TruncationStrategy,
};
use serde::{Deserialize, Serialize};
use std::cmp::min;
use std::collections::HashMap;
use tch::kind::Kind::Float;
use tch::nn::VarStore;
use tch::{no_grad, Device, Tensor};
#[derive(Debug, Clone, Serialize, Deserialize)]
/// # Token generated by a `TokenClassificationModel`
@@ -140,7 +193,6 @@ pub enum LabelAggregationOption {
    Custom(Box<dyn Fn(&[Token]) -> (i64, String)>),
}
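The `Custom` variant accepts any boxed closure from the sub-tokens of a word to a `(label_index, label)` pair. A minimal sketch (a hypothetical closure, assuming the `Token` type above is in scope) that reuses the label of the highest-scoring sub-token:

```no_run
# use rust_bert::pipelines::token_classification::{LabelAggregationOption, Token};
let highest_score = LabelAggregationOption::Custom(Box::new(|tokens: &[Token]| {
    // Keep the label of the sub-token with the largest score.
    let best = tokens
        .iter()
        .max_by(|a, b| a.score.partial_cmp(&b.score).unwrap())
        .unwrap();
    (best.label_index, best.label.clone())
}));
```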
/// # Configuration for TokenClassificationModel
/// Contains information regarding the model to load and device to place the model on.
pub struct TokenClassificationConfig {
@@ -173,14 +225,15 @@ impl TokenClassificationConfig {
    /// * vocab - The `Resource` pointing to the tokenizer's vocabulary to load (e.g. vocab.txt/vocab.json)
    /// * merges - An optional `Resource` tuple (`Option<Resource>`) pointing to the tokenizer's merge file to load (e.g. merges.txt), needed only for Roberta.
    /// * lower_case - A `bool` indicating whether the tokenizer should lower case all input (in case of a lower-cased model)
    pub fn new(
        model_type: ModelType,
        model_resource: Resource,
        config_resource: Resource,
        vocab_resource: Resource,
        merges_resource: Option<Resource>,
        lower_case: bool,
        label_aggregation_function: LabelAggregationOption,
    ) -> TokenClassificationConfig {
        TokenClassificationConfig {
            model_type,
            model_resource,
@@ -199,9 +252,15 @@ impl Default for TokenClassificationConfig {
    fn default() -> TokenClassificationConfig {
        TokenClassificationConfig {
            model_type: ModelType::Bert,
            model_resource: Resource::Remote(RemoteResource::from_pretrained(
                BertModelResources::BERT_NER,
            )),
            config_resource: Resource::Remote(RemoteResource::from_pretrained(
                BertConfigResources::BERT_NER,
            )),
            vocab_resource: Resource::Remote(RemoteResource::from_pretrained(
                BertVocabResources::BERT_NER,
            )),
            merges_resource: None,
            lower_case: false,
            device: Device::cuda_if_available(),
@@ -231,7 +290,6 @@ impl TokenClassificationOption {
    /// * `p` - `tch::nn::Path` path to the model file to load (e.g. model.ot)
    /// * `config` - A configuration (the model type of the configuration must be compatible with the value for
    ///   `model_type`)
    pub fn new(model_type: ModelType, p: &tch::nn::Path, config: &ConfigOption) -> Self {
        match model_type {
            ModelType::Bert => {
@@ -243,21 +301,27 @@
            }
            ModelType::DistilBert => {
                if let ConfigOption::DistilBert(config) = config {
                    TokenClassificationOption::DistilBert(DistilBertForTokenClassification::new(
                        p, config,
                    ))
                } else {
                    panic!("You can only supply a DistilBertConfig for DistilBert!");
                }
            }
            ModelType::Roberta => {
                if let ConfigOption::Bert(config) = config {
                    TokenClassificationOption::Roberta(RobertaForTokenClassification::new(
                        p, config,
                    ))
                } else {
                    panic!("You can only supply a BertConfig for Roberta!");
                }
            }
            ModelType::Electra => {
                if let ConfigOption::Electra(config) = config {
                    TokenClassificationOption::Electra(ElectraForTokenClassification::new(
                        p, config,
                    ))
                } else {
                    panic!("You can only supply an ElectraConfig for Electra!");
                }
            }
@@ -271,27 +335,51 @@ impl TokenClassificationOption {
            Self::Bert(_) => ModelType::Bert,
            Self::Roberta(_) => ModelType::Roberta,
            Self::DistilBert(_) => ModelType::DistilBert,
            Self::Electra(_) => ModelType::Electra,
        }
    }
    fn forward_t(
        &self,
        input_ids: Option<Tensor>,
        mask: Option<Tensor>,
        token_type_ids: Option<Tensor>,
        position_ids: Option<Tensor>,
        input_embeds: Option<Tensor>,
        train: bool,
    ) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
        match *self {
            Self::Bert(ref model) => model.forward_t(
                input_ids,
                mask,
                token_type_ids,
                position_ids,
                input_embeds,
                train,
            ),
            Self::DistilBert(ref model) => model
                .forward_t(input_ids, mask, input_embeds, train)
                .expect("Error in distilbert forward_t"),
            Self::Roberta(ref model) => model.forward_t(
                input_ids,
                mask,
                token_type_ids,
                position_ids,
                input_embeds,
                train,
            ),
            Self::Electra(ref model) => model.forward_t(
                input_ids,
                mask,
                token_type_ids,
                position_ids,
                input_embeds,
                train,
            ),
        }
    }
}
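`forward_t` hides the per-architecture differences behind a single tuple of (logits, optional hidden states, optional attentions), which is why `predict` below can destructure the result uniformly. A crate-internal sketch of the calling pattern, assuming a built `token_sequence_classifier` and an `input_tensor` (hypothetical bindings):

```no_run
# use tch::no_grad;
// Sketch only: `token_sequence_classifier` and `input_tensor` are assumed to exist.
let (logits, _all_hidden_states, _all_attentions) = no_grad(|| {
    token_sequence_classifier.forward_t(Some(input_tensor.copy()), None, None, None, None, false)
});
// `logits` has shape (batch_size, sequence_length, num_labels).
```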
/// # TokenClassificationModel for Named Entity Recognition or Part-of-Speech tagging
pub struct TokenClassificationModel {
    tokenizer: TokenizerOption,
@@ -311,14 +399,13 @@ impl TokenClassificationModel {
    /// # Example
    ///
    /// ```no_run
    /// # fn main() -> failure::Fallible<()> {
    /// use rust_bert::pipelines::token_classification::TokenClassificationModel;
    ///
    /// let model = TokenClassificationModel::new(Default::default())?;
    /// # Ok(())
    /// # }
    /// ```
    pub fn new(config: TokenClassificationConfig) -> failure::Fallible<TokenClassificationModel> {
        let config_path = download_resource(&config.config_resource)?;
        let vocab_path = download_resource(&config.vocab_resource)?;
@@ -331,32 +418,49 @@ impl TokenClassificationModel {
        let device = config.device;
        let label_aggregation_function = config.label_aggregation_function;

        let tokenizer = TokenizerOption::from_file(
            config.model_type,
            vocab_path.to_str().unwrap(),
            merges_path.map(|path| path.to_str().unwrap()),
            config.lower_case,
        );
        let mut var_store = VarStore::new(device);
        let model_config = ConfigOption::from_file(config.model_type, config_path);
        let token_sequence_classifier =
            TokenClassificationOption::new(config.model_type, &var_store.root(), &model_config);
        let label_mapping = model_config.get_label_mapping();
        var_store.load(weights_path)?;
        Ok(TokenClassificationModel {
            tokenizer,
            token_sequence_classifier,
            label_mapping,
            var_store,
            label_aggregation_function,
        })
    }

    fn prepare_for_model(&self, input: Vec<&str>) -> (Vec<TokenizedInput>, Tensor) {
        let tokenized_input: Vec<TokenizedInput> =
            self.tokenizer
                .encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
        let max_len = tokenized_input
            .iter()
            .map(|input| input.token_ids.len())
            .max()
            .unwrap();
        let tokenized_input_tensors: Vec<tch::Tensor> = tokenized_input
            .iter()
            .map(|input| input.token_ids.clone())
            .map(|mut input| {
                input.extend(vec![0; max_len - input.len()]);
                input
            })
            .map(|input| Tensor::of_slice(&(input)))
            .collect::<Vec<_>>();
        (
            tokenized_input,
            Tensor::stack(tokenized_input_tensors.as_slice(), 0).to(self.var_store.device()),
        )
    }
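`prepare_for_model` pads every sequence with zeros up to the longest sequence in the batch before stacking the tensors. The same padding step in isolation, as a runnable sketch on plain vectors:

```no_run
// The padding step from `prepare_for_model`, isolated on plain vectors.
let mut batch: Vec<Vec<i64>> = vec![vec![101, 2023, 102], vec![101, 102]];
let max_len = batch.iter().map(|ids| ids.len()).max().unwrap();
for ids in batch.iter_mut() {
    ids.extend(vec![0; max_len - ids.len()]); // pad with token id 0
}
assert_eq!(batch[1], vec![101, 102, 0]);
```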
    /// Classify tokens in a text sequence
@@ -374,33 +478,39 @@ impl TokenClassificationModel {
    /// # Example
    ///
    /// ```no_run
    /// # fn main() -> failure::Fallible<()> {
    /// # use rust_bert::pipelines::token_classification::TokenClassificationModel;
    ///
    /// let ner_model = TokenClassificationModel::new(Default::default())?;
    /// let input = [
    ///     "My name is Amy. I live in Paris.",
    ///     "Paris is a city in France.",
    /// ];
    /// let output = ner_model.predict(&input, true, true);
    /// # Ok(())
    /// # }
    /// ```
    pub fn predict(
        &self,
        input: &[&str],
        consolidate_sub_tokens: bool,
        return_special: bool,
    ) -> Vec<Token> {
        let (tokenized_input, input_tensor) = self.prepare_for_model(input.to_vec());
        let (output, _, _) = no_grad(|| {
            self.token_sequence_classifier.forward_t(
                Some(input_tensor.copy()),
                None,
                None,
                None,
                None,
                false,
            )
        });
        let output = output.detach().to(Device::Cpu);
        let score: Tensor = output.exp() / output.exp().sum1(&[-1], true, Float);
        let labels_idx = &score.argmax(-1, true);
        let mut tokens: Vec<Token> = vec![];
        for sentence_idx in 0..labels_idx.size()[0] {
            let labels = labels_idx.get(sentence_idx);
            let sentence_tokens = &tokenized_input[sentence_idx as usize];
@@ -415,7 +525,16 @@ impl TokenClassificationModel {
                    word_idx += 1;
                }
                let token = {
                    self.decode_token(
                        &original_chars,
                        sentence_tokens,
                        &input_tensor,
                        &labels,
                        &score,
                        sentence_idx,
                        position_idx as i64,
                        word_idx - 1,
                    )
                };
                tokens.push(token);
            }
@@ -426,8 +545,17 @@ impl TokenClassificationModel {
        tokens
    }
    fn decode_token(
        &self,
        original_sentence_chars: &Vec<char>,
        sentence_tokens: &TokenizedInput,
        input_tensor: &Tensor,
        labels: &Tensor,
        score: &Tensor,
        sentence_idx: i64,
        position_idx: i64,
        word_index: u16,
    ) -> Token {
        let label_id = labels.int64_value(&[position_idx as i64]);
        let token_id = input_tensor.int64_value(&[sentence_idx, position_idx as i64]);
@@ -435,13 +563,19 @@ impl TokenClassificationModel {
        let text = match offsets {
            None => match self.tokenizer {
                TokenizerOption::Bert(ref tokenizer) => {
                    Tokenizer::decode(tokenizer, vec![token_id], false, false)
                }
                TokenizerOption::Roberta(ref tokenizer) => {
                    Tokenizer::decode(tokenizer, vec![token_id], false, false)
                }
            },
            Some(offsets) => {
                let (start_char, end_char) = (offsets.begin as usize, offsets.end as usize);
                let end_char = min(end_char, original_sentence_chars.len());
                let text = original_sentence_chars[start_char..end_char]
                    .iter()
                    .collect();
                text
            }
        };
@@ -449,7 +583,11 @@ impl TokenClassificationModel {
        Token {
            text,
            score: score.double_value(&[sentence_idx, position_idx, label_id]),
            label: self
                .label_mapping
                .get(&label_id)
                .expect("Index out of vocabulary bounds.")
                .to_owned(),
            label_index: label_id,
            sentence: sentence_idx as usize,
            index: position_idx as u16,
@@ -459,24 +597,29 @@ impl TokenClassificationModel {
        }
    }
    fn consolidate_tokens(
        &self,
        tokens: &mut Vec<Token>,
        label_aggregation_function: &LabelAggregationOption,
    ) {
        let mut tokens_to_replace = vec![];
        let mut token_iter = tokens.iter_consolidate_tokens();
        let mut cursor = 0;

        while let Some(sub_tokens) = token_iter.next() {
            if sub_tokens.len() > 1 {
                let (label_index, label) =
                    self.consolidate_labels(sub_tokens, label_aggregation_function);
                let sentence = (&sub_tokens[0]).sentence;
                let index = (&sub_tokens[0]).index;
                let word_index = (&sub_tokens[0]).word_index;
                let offset_start = match &sub_tokens.first().unwrap().offset {
                    Some(offset) => Some(offset.begin),
                    None => None,
                };
                let offset_end = match &sub_tokens.last().unwrap().offset {
                    Some(offset) => Some(offset.end),
                    None => None,
                };
                let offset = if offset_start.is_some() & offset_end.is_some() {
                    Some(Offset::new(offset_start.unwrap(), offset_end.unwrap()))
@@ -513,7 +656,11 @@ impl TokenClassificationModel {
        }
    }

    fn consolidate_labels(
        &self,
        tokens: &[Token],
        aggregation: &LabelAggregationOption,
    ) -> (i64, String) {
        match aggregation {
            LabelAggregationOption::First => {
                let token = tokens.first().unwrap();
@@ -524,22 +671,17 @@ impl TokenClassificationModel {
                (token.label_index, token.label.clone())
            }
            LabelAggregationOption::Mode => {
                let counts = tokens.iter().fold(HashMap::new(), |mut m, c| {
                    *m.entry((c.label_index, c.label.as_str())).or_insert(0) += 1;
                    m
                });
                counts
                    .into_iter()
                    .max_by(|a, b| a.1.cmp(&b.1))
                    .map(|((label_index, label), _)| (label_index, label.to_owned()))
                    .unwrap()
            }
            LabelAggregationOption::Custom(function) => function(tokens),
        }
    }
}
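Putting the pieces together, the aggregation strategy is chosen when the configuration is built. A hedged usage sketch selecting the majority-vote (`Mode`) behaviour above, reusing the same remote NER resources as the `Default` implementation (assuming these items are publicly exported as shown):

```no_run
# fn main() -> failure::Fallible<()> {
use rust_bert::bert::{BertConfigResources, BertModelResources, BertVocabResources};
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::token_classification::{
    LabelAggregationOption, TokenClassificationConfig, TokenClassificationModel,
};
use rust_bert::resources::{RemoteResource, Resource};

let config = TokenClassificationConfig::new(
    ModelType::Bert,
    Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT_NER)),
    Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT_NER)),
    Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT_NER)),
    None,
    false,
    LabelAggregationOption::Mode,
);
let model = TokenClassificationModel::new(config)?;
// With consolidate_sub_tokens = true, word pieces are merged using the Mode rule above.
let tokens = model.predict(&["My name is Amélie"], true, false);
# Ok(())
# }
```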


@@ -11,7 +11,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.

//! # Translation pipeline
//! Translation based on the Marian encoder-decoder architecture.
//! Includes techniques such as beam search, top-k and nucleus sampling, temperature setting and repetition penalty.
@@ -33,31 +32,35 @@
//!
//!
//! ```no_run
//! # fn main() -> failure::Fallible<()> {
//! # use rust_bert::pipelines::generation::LanguageGenerator;
//! use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
//! use tch::Device;
//! let translation_config =
//!     TranslationConfig::new(Language::EnglishToFrench, Device::cuda_if_available());
//! let mut model = TranslationModel::new(translation_config)?;
//!
//! let input = ["This is a sentence to be translated"];
//!
//! let output = model.translate(&input);
//! # Ok(())
//! # }
//! ```
//!
//! Output: \
//! ```no_run
//! # let output =
//! "Il s'agit d'une phrase à traduire"
//! # ;
//! ```

use crate::common::resources::{RemoteResource, Resource};
use crate::marian::{
    MarianConfigResources, MarianModelResources, MarianPrefix, MarianSpmResources,
    MarianVocabResources,
};
use crate::pipelines::generation::{GenerateConfig, LanguageGenerator, MarianGenerator};
use tch::Device;

/// Pretrained languages available for direct use
pub enum Language {
@@ -84,47 +87,244 @@ pub enum Language {
struct RemoteTranslationResources;

impl RemoteTranslationResources {
    pub const ENGLISH2FRENCH: (
        (&'static str, &'static str),
        (&'static str, &'static str),
        (&'static str, &'static str),
        (&'static str, &'static str),
        Option<&'static str>,
    ) = (
        MarianModelResources::ENGLISH2ROMANCE,
        MarianConfigResources::ENGLISH2ROMANCE,
        MarianVocabResources::ENGLISH2ROMANCE,
        MarianSpmResources::ENGLISH2ROMANCE,
        MarianPrefix::ENGLISH2FRENCH,
    );
    pub const ENGLISH2CATALAN: (
        (&'static str, &'static str),
        (&'static str, &'static str),
        (&'static str, &'static str),
        (&'static str, &'static str),
        Option<&'static str>,
    ) = (
        MarianModelResources::ENGLISH2ROMANCE,
        MarianConfigResources::ENGLISH2ROMANCE,
        MarianVocabResources::ENGLISH2ROMANCE,
        MarianSpmResources::ENGLISH2ROMANCE,
        MarianPrefix::ENGLISH2CATALAN,
    );
    pub const ENGLISH2SPANISH: (
        (&'static str, &'static str),
        (&'static str, &'static str),
        (&'static str, &'static str),
        (&'static str, &'static str),
        Option<&'static str>,
    ) = (
        MarianModelResources::ENGLISH2ROMANCE,
        MarianConfigResources::ENGLISH2ROMANCE,
        MarianVocabResources::ENGLISH2ROMANCE,
        MarianSpmResources::ENGLISH2ROMANCE,
        MarianPrefix::ENGLISH2SPANISH,
    );
    pub const ENGLISH2PORTUGUESE: (
        (&'static str, &'static str),
        (&'static str, &'static str),
        (&'static str, &'static str),
        (&'static str, &'static str),
        Option<&'static str>,
    ) = (
        MarianModelResources::ENGLISH2ROMANCE,
        MarianConfigResources::ENGLISH2ROMANCE,
        MarianVocabResources::ENGLISH2ROMANCE,
        MarianSpmResources::ENGLISH2ROMANCE,
        MarianPrefix::ENGLISH2PORTUGUESE,
    );
    pub const ENGLISH2ITALIAN: (
        (&'static str, &'static str),
        (&'static str, &'static str),
        (&'static str, &'static str),
        (&'static str, &'static str),
        Option<&'static str>,
    ) = (
        MarianModelResources::ENGLISH2ROMANCE,
        MarianConfigResources::ENGLISH2ROMANCE,
        MarianVocabResources::ENGLISH2ROMANCE,
        MarianSpmResources::ENGLISH2ROMANCE,
        MarianPrefix::ENGLISH2ITALIAN,
    );
    pub const ENGLISH2ROMANIAN: (
        (&'static str, &'static str),
        (&'static str, &'static str),
        (&'static str, &'static str),
        (&'static str, &'static str),
        Option<&'static str>,
    ) = (
        MarianModelResources::ENGLISH2ROMANCE,
        MarianConfigResources::ENGLISH2ROMANCE,
        MarianVocabResources::ENGLISH2ROMANCE,
        MarianSpmResources::ENGLISH2ROMANCE,
        MarianPrefix::ENGLISH2ROMANIAN,
    );
    pub const ENGLISH2GERMAN: (
        (&'static str, &'static str),
        (&'static str, &'static str),
        (&'static str, &'static str),
        (&'static str, &'static str),
        Option<&'static str>,
    ) = (
        MarianModelResources::ENGLISH2GERMAN,
        MarianConfigResources::ENGLISH2GERMAN,
        MarianVocabResources::ENGLISH2GERMAN,
        MarianSpmResources::ENGLISH2GERMAN,
        MarianPrefix::ENGLISH2GERMAN,
    );
    pub const ENGLISH2RUSSIAN: (
        (&'static str, &'static str),
        (&'static str, &'static str),
        (&'static str, &'static str),
        (&'static str, &'static str),
        Option<&'static str>,
    ) = (
        MarianModelResources::ENGLISH2RUSSIAN,
        MarianConfigResources::ENGLISH2RUSSIAN,
        MarianVocabResources::ENGLISH2RUSSIAN,
        MarianSpmResources::ENGLISH2RUSSIAN,
        MarianPrefix::ENGLISH2RUSSIAN,
    );

    pub const FRENCH2ENGLISH: (
        (&'static str, &'static str),
        (&'static str, &'static str),
        (&'static str, &'static str),
        (&'static str, &'static str),
        Option<&'static str>,
    ) = (
        MarianModelResources::ROMANCE2ENGLISH,
        MarianConfigResources::ROMANCE2ENGLISH,
        MarianVocabResources::ROMANCE2ENGLISH,
        MarianSpmResources::ROMANCE2ENGLISH,
        MarianPrefix::FRENCH2ENGLISH,
    );
    pub const CATALAN2ENGLISH: (
        (&'static str, &'static str),
        (&'static str, &'static str),
        (&'static str, &'static str),
        (&'static str, &'static str),
        Option<&'static str>,
    ) = (
        MarianModelResources::ROMANCE2ENGLISH,
        MarianConfigResources::ROMANCE2ENGLISH,
        MarianVocabResources::ROMANCE2ENGLISH,
        MarianSpmResources::ROMANCE2ENGLISH,
        MarianPrefix::CATALAN2ENGLISH,
    );
    pub const SPANISH2ENGLISH: (
        (&'static str, &'static str),
        (&'static str, &'static str),
        (&'static str, &'static str),
        (&'static str, &'static str),
        Option<&'static str>,
    ) = (
        MarianModelResources::ROMANCE2ENGLISH,
        MarianConfigResources::ROMANCE2ENGLISH,
        MarianVocabResources::ROMANCE2ENGLISH,
        MarianSpmResources::ROMANCE2ENGLISH,
        MarianPrefix::SPANISH2ENGLISH,
    );
    pub const PORTUGUESE2ENGLISH: (
        (&'static str, &'static str),
        (&'static str, &'static str),
        (&'static str, &'static str),
        (&'static str, &'static str),
        Option<&'static str>,
    ) = (
        MarianModelResources::ROMANCE2ENGLISH,
        MarianConfigResources::ROMANCE2ENGLISH,
        MarianVocabResources::ROMANCE2ENGLISH,
        MarianSpmResources::ROMANCE2ENGLISH,
        MarianPrefix::PORTUGUESE2ENGLISH,
    );
    pub const ITALIAN2ENGLISH: (
        (&'static str, &'static str),
        (&'static str, &'static str),
        (&'static str, &'static str),
        (&'static str, &'static str),
        Option<&'static str>,
    ) = (
        MarianModelResources::ROMANCE2ENGLISH,
        MarianConfigResources::ROMANCE2ENGLISH,
        MarianVocabResources::ROMANCE2ENGLISH,
        MarianSpmResources::ROMANCE2ENGLISH,
        MarianPrefix::ITALIAN2ENGLISH,
    );
    pub const ROMANIAN2ENGLISH: (
        (&'static str, &'static str),
        (&'static str, &'static str),
        (&'static str, &'static str),
        (&'static str, &'static str),
        Option<&'static str>,
    ) = (
        MarianModelResources::ROMANCE2ENGLISH,
        MarianConfigResources::ROMANCE2ENGLISH,
        MarianVocabResources::ROMANCE2ENGLISH,
        MarianSpmResources::ROMANCE2ENGLISH,
        MarianPrefix::ROMANIAN2ENGLISH,
    );
    pub const GERMAN2ENGLISH: (
        (&'static str, &'static str),
        (&'static str, &'static str),
        (&'static str, &'static str),
        (&'static str, &'static str),
        Option<&'static str>,
    ) = (
        MarianModelResources::GERMAN2ENGLISH,
        MarianConfigResources::GERMAN2ENGLISH,
        MarianVocabResources::GERMAN2ENGLISH,
        MarianSpmResources::GERMAN2ENGLISH,
        MarianPrefix::GERMAN2ENGLISH,
    );
    pub const RUSSIAN2ENGLISH: (
        (&'static str, &'static str),
        (&'static str, &'static str),
        (&'static str, &'static str),
        (&'static str, &'static str),
        Option<&'static str>,
    ) = (
        MarianModelResources::RUSSIAN2ENGLISH,
        MarianConfigResources::RUSSIAN2ENGLISH,
        MarianVocabResources::RUSSIAN2ENGLISH,
        MarianSpmResources::RUSSIAN2ENGLISH,
        MarianPrefix::RUSSIAN2ENGLISH,
    );

    pub const FRENCH2GERMAN: (
        (&'static str, &'static str),
        (&'static str, &'static str),
        (&'static str, &'static str),
        (&'static str, &'static str),
        Option<&'static str>,
    ) = (
        MarianModelResources::FRENCH2GERMAN,
        MarianConfigResources::FRENCH2GERMAN,
        MarianVocabResources::FRENCH2GERMAN,
        MarianSpmResources::FRENCH2GERMAN,
        MarianPrefix::FRENCH2GERMAN,
    );
    pub const GERMAN2FRENCH: (
        (&'static str, &'static str),
        (&'static str, &'static str),
        (&'static str, &'static str),
        (&'static str, &'static str),
        Option<&'static str>,
    ) = (
        MarianModelResources::GERMAN2FRENCH,
        MarianConfigResources::GERMAN2FRENCH,
        MarianVocabResources::GERMAN2FRENCH,
        MarianSpmResources::GERMAN2FRENCH,
        MarianPrefix::GERMAN2FRENCH,
    );
}
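Each constant is a 5-tuple of (model, config, vocabulary, sentence-piece) resource identifiers plus an optional target-language prefix; `TranslationConfig::new` below simply destructures it. Shown in isolation as a crate-internal sketch (the struct is not public, and the concrete prefix value is an assumption):

```no_run
// Sketch (crate-internal): the tuple layout is (model, config, vocab, spm, prefix).
let (_model, _config, _vocab, _spm, prefix) = RemoteTranslationResources::ENGLISH2FRENCH;
assert_eq!(prefix, MarianPrefix::ENGLISH2FRENCH); // e.g. Some(">>fr<<") for the shared English-to-Romance weights
```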
/// # Configuration for text translation
/// Contains information regarding the model to load, mirrors the GenerationConfig, with a
/// different set of default parameters and sets the device to place the model on.
@@ -178,45 +378,46 @@ impl TranslationConfig {
    /// # Example
    ///
    /// ```no_run
    /// # fn main() -> failure::Fallible<()> {
    /// use rust_bert::pipelines::translation::{Language, TranslationConfig};
    /// use tch::Device;
    ///
    /// let translation_config =
    ///     TranslationConfig::new(Language::FrenchToEnglish, Device::cuda_if_available());
    /// # Ok(())
    /// # }
    /// ```
    pub fn new(language: Language, device: Device) -> TranslationConfig {
        let (model_resource, config_resource, vocab_resource, merges_resource, prefix) =
            match language {
                Language::EnglishToFrench => RemoteTranslationResources::ENGLISH2FRENCH,
                Language::EnglishToCatalan => RemoteTranslationResources::ENGLISH2CATALAN,
                Language::EnglishToSpanish => RemoteTranslationResources::ENGLISH2SPANISH,
                Language::EnglishToPortuguese => RemoteTranslationResources::ENGLISH2PORTUGUESE,
                Language::EnglishToItalian => RemoteTranslationResources::ENGLISH2ITALIAN,
                Language::EnglishToRomanian => RemoteTranslationResources::ENGLISH2ROMANIAN,
                Language::EnglishToGerman => RemoteTranslationResources::ENGLISH2GERMAN,
                Language::EnglishToRussian => RemoteTranslationResources::ENGLISH2RUSSIAN,
                Language::FrenchToEnglish => RemoteTranslationResources::FRENCH2ENGLISH,
                Language::CatalanToEnglish => RemoteTranslationResources::CATALAN2ENGLISH,
                Language::SpanishToEnglish => RemoteTranslationResources::SPANISH2ENGLISH,
                Language::PortugueseToEnglish => RemoteTranslationResources::PORTUGUESE2ENGLISH,
                Language::ItalianToEnglish => RemoteTranslationResources::ITALIAN2ENGLISH,
                Language::RomanianToEnglish => RemoteTranslationResources::ROMANIAN2ENGLISH,
                Language::GermanToEnglish => RemoteTranslationResources::GERMAN2ENGLISH,
                Language::RussianToEnglish => RemoteTranslationResources::RUSSIAN2ENGLISH,
                Language::FrenchToGerman => RemoteTranslationResources::FRENCH2GERMAN,
                Language::GermanToFrench => RemoteTranslationResources::GERMAN2FRENCH,
            };
        let model_resource = Resource::Remote(RemoteResource::from_pretrained(model_resource));
        let config_resource = Resource::Remote(RemoteResource::from_pretrained(config_resource));
        let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(vocab_resource));
        let merges_resource = Resource::Remote(RemoteResource::from_pretrained(merges_resource));
        let prefix = match prefix {
            Some(value) => Some(value.to_string()),
            None => None,
        };
        TranslationConfig {
            model_resource,
@@ -253,33 +454,44 @@ impl TranslationConfig {
    /// # Example
    ///
    /// ```no_run
    /// # fn main() -> failure::Fallible<()> {
    /// use rust_bert::pipelines::translation::TranslationConfig;
    /// use rust_bert::resources::{LocalResource, Resource};
    /// use std::path::PathBuf;
    /// use tch::Device;
    ///
    /// let config_resource = Resource::Local(LocalResource {
    ///     local_path: PathBuf::from("path/to/config.json"),
    /// });
    /// let model_resource = Resource::Local(LocalResource {
    ///     local_path: PathBuf::from("path/to/model.ot"),
    /// });
    /// let vocab_resource = Resource::Local(LocalResource {
    ///     local_path: PathBuf::from("path/to/vocab.json"),
    /// });
    /// let sentence_piece_resource = Resource::Local(LocalResource {
    ///     local_path: PathBuf::from("path/to/spiece.model"),
    /// });
    ///
    /// let translation_config = TranslationConfig::new_from_resources(
    ///     model_resource,
    ///     config_resource,
    ///     vocab_resource,
    ///     sentence_piece_resource,
    ///     Some(">>fr<<".to_string()),
    ///     Device::cuda_if_available(),
    /// );
    /// # Ok(())
    /// # }
    /// ```
    pub fn new_from_resources(
        model_resource: Resource,
        config_resource: Resource,
        vocab_resource: Resource,
        sentence_piece_resource: Resource,
        prefix: Option<String>,
        device: Device,
    ) -> TranslationConfig {
        TranslationConfig {
            model_resource,
            config_resource,
@@ -319,18 +531,17 @@ impl TranslationModel {
    /// # Example
    ///
    /// ```no_run
    /// # fn main() -> failure::Fallible<()> {
    /// use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
    /// use tch::Device;
    ///
    /// let translation_config =
    ///     TranslationConfig::new(Language::FrenchToEnglish, Device::cuda_if_available());
    /// let mut translation_model = TranslationModel::new(translation_config)?;
    /// # Ok(())
    /// # }
    /// ```
    pub fn new(translation_config: TranslationConfig) -> failure::Fallible<TranslationModel> {
        let generate_config = GenerateConfig {
            model_resource: translation_config.model_resource,
            config_resource: translation_config.config_resource,
@@ -353,7 +564,10 @@ impl TranslationModel {
        let model = MarianGenerator::new(generate_config)?;

        Ok(TranslationModel {
            model,
            prefix: translation_config.prefix,
        })
    }
    /// Translates texts provided
@@ -368,31 +582,32 @@ impl TranslationModel {
    /// # Example
    ///
    /// ```no_run
    /// # fn main() -> failure::Fallible<()> {
    /// use rust_bert::pipelines::generation::LanguageGenerator;
    /// use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
    /// use tch::Device;
    ///
    /// let translation_config =
    ///     TranslationConfig::new(Language::EnglishToFrench, Device::cuda_if_available());
    /// let model = TranslationModel::new(translation_config)?;
    ///
    /// let input = ["This is a sentence to be translated"];
    ///
    /// let output = model.translate(&input);
    /// # Ok(())
    /// # }
    /// ```
    pub fn translate(&self, texts: &[&str]) -> Vec<String> {
        match &self.prefix {
            Some(value) => {
                let texts: Vec<String> = texts
                    .into_iter()
                    .map(|&v| format!("{} {}", value, v))
                    .collect();
                self.model
                    .generate(Some(texts.iter().map(AsRef::as_ref).collect()), None)
            }
            None => self.model.generate(Some(texts.to_vec()), None),
        }
    }
}
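For the multilingual English-to-Romance weights, the stored prefix (of the `>>fr<<` form) is what routes a sentence to a specific target language, and `translate` prepends it automatically. A usage sketch under that assumption (the Spanish prefix value is assumed, not shown in this hunk):

```no_run
# fn main() -> failure::Fallible<()> {
use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
use tch::Device;

let translation_config =
    TranslationConfig::new(Language::EnglishToSpanish, Device::cuda_if_available());
let model = TranslationModel::new(translation_config)?;
// The Spanish-style prefix (assumed ">>es<<") is prepended internally before generation.
let output = model.translate(&["This is a sentence to be translated"]);
# Ok(())
# }
```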


@@ -11,10 +11,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use crate::bert::{BertConfig, BertEmbedding};
use crate::common::dropout::Dropout;
use tch::nn::{embedding, EmbeddingConfig};
use tch::{nn, Kind, Tensor};

#[derive(Debug)]
/// # BertEmbeddings implementation for RoBERTa model
@@ -36,8 +36,12 @@ impl RobertaEmbeddings {
    fn create_position_ids_from_embeddings(&self, x: &Tensor) -> Tensor {
        let input_shape = x.size();
        let input_shape = vec![input_shape[0], input_shape[1]];

        let position_ids: Tensor = Tensor::arange1(
            self.padding_index + 1,
            input_shape[0],
            (Kind::Int64, x.device()),
        );
        position_ids.unsqueeze(0).expand(&input_shape, true)
    }
}
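RoBERTa reserves index 1 as padding (`padding_index = 1`), so real positions start at `padding_index + 1`. The companion path for token ids (`create_position_ids_from_input_ids`, referenced in `forward_t` below but elided by this diff) follows the same convention; a hedged sketch of that computation, mirroring the HuggingFace implementation this port follows and assuming tch's `ne`/`cumsum` bindings:

```no_run
// Sketch: positions count only non-padding tokens, offset past the padding index.
// `input_ids` is an assumed (batch_size, sequence_length) Int64 tensor.
let mask = input_ids.ne(1); // 1 is RoBERTa's padding index
let position_ids = mask.cumsum(1, Kind::Int64) * mask.to_kind(Kind::Int64) + 1;
```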
@@ -54,10 +58,10 @@ impl BertEmbedding for RobertaEmbeddings {
    ///
    /// ```no_run
    /// use rust_bert::bert::{BertConfig, BertEmbedding};
    /// use rust_bert::roberta::RobertaEmbeddings;
    /// use rust_bert::Config;
    /// use std::path::Path;
    /// use tch::{nn, Device};
    ///
    /// let config_path = Path::new("path/to/config.json");
    /// let device = Device::Cpu;
@@ -65,29 +69,48 @@ impl BertEmbedding for RobertaEmbeddings {
    /// let config = BertConfig::from_file(config_path);
    /// let roberta_embeddings = RobertaEmbeddings::new(&(&p.root() / "bert_embeddings"), &config);
    /// ```
    fn new(p: &nn::Path, config: &BertConfig) -> RobertaEmbeddings {
        let embedding_config = EmbeddingConfig {
            padding_idx: 1,
            ..Default::default()
        };

        let word_embeddings: nn::Embedding = embedding(
            p / "word_embeddings",
            config.vocab_size,
            config.hidden_size,
            embedding_config,
        );

        let position_embeddings: nn::Embedding = embedding(
            p / "position_embeddings",
            config.max_position_embeddings,
            config.hidden_size,
            Default::default(),
        );

        let token_type_embeddings: nn::Embedding = embedding(
            p / "token_type_embeddings",
            config.type_vocab_size,
            config.hidden_size,
            Default::default(),
        );

        let layer_norm_config = nn::LayerNormConfig {
            eps: 1e-12,
            ..Default::default()
        };
        let layer_norm: nn::LayerNorm =
            nn::layer_norm(p / "LayerNorm", vec![config.hidden_size], layer_norm_config);
        let dropout: Dropout = Dropout::new(config.hidden_dropout_prob);
        RobertaEmbeddings {
            word_embeddings,
            position_embeddings,
            token_type_embeddings,
            layer_norm,
            dropout,
            padding_index: 1,
        }
    }
/// Forward pass through the embedding layer. /// Forward pass through the embedding layer.
@ -108,68 +131,82 @@ impl BertEmbedding for RobertaEmbeddings {
/// # Example /// # Example
/// ///
/// ```no_run /// ```no_run
///# use rust_bert::bert::{BertConfig, BertEmbedding}; /// # use rust_bert::bert::{BertConfig, BertEmbedding};
///# use tch::{nn, Device, Tensor, no_grad}; /// # use tch::{nn, Device, Tensor, no_grad};
///# use rust_bert::Config; /// # use rust_bert::Config;
///# use std::path::Path; /// # use std::path::Path;
///# use tch::kind::Kind::Int64; /// # use tch::kind::Kind::Int64;
/// use rust_bert::roberta::RobertaEmbeddings; /// use rust_bert::roberta::RobertaEmbeddings;
///# let config_path = Path::new("path/to/config.json"); /// # let config_path = Path::new("path/to/config.json");
///# let vocab_path = Path::new("path/to/vocab.txt"); /// # let vocab_path = Path::new("path/to/vocab.txt");
///# let device = Device::Cpu; /// # let device = Device::Cpu;
///# let vs = nn::VarStore::new(device); /// # let vs = nn::VarStore::new(device);
///# let config = BertConfig::from_file(config_path); /// # let config = BertConfig::from_file(config_path);
///# let roberta_embeddings = RobertaEmbeddings::new(&vs.root(), &config); /// # let roberta_embeddings = RobertaEmbeddings::new(&vs.root(), &config);
/// let (batch_size, sequence_length) = (64, 128); /// let (batch_size, sequence_length) = (64, 128);
    /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
    /// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
    /// let position_ids = Tensor::arange(sequence_length, (Int64, device))
    ///     .expand(&[batch_size, sequence_length], true);
    ///
    /// let embedded_output = no_grad(|| {
    ///     roberta_embeddings
    ///         .forward_t(
    ///             Some(input_tensor),
    ///             Some(token_type_ids),
    ///             Some(position_ids),
    ///             None,
    ///             false,
    ///         )
    ///         .unwrap()
    /// });
    /// ```
    fn forward_t(
        &self,
        input_ids: Option<Tensor>,
        token_type_ids: Option<Tensor>,
        position_ids: Option<Tensor>,
        input_embeds: Option<Tensor>,
        train: bool,
    ) -> Result<Tensor, &'static str> {
        let (input_embeddings, input_shape) = match &input_ids {
            Some(input_value) => match &input_embeds {
                Some(_) => {
                    return Err("Only one of input ids or input embeddings may be set");
                }
                None => (
                    input_value.apply_t(&self.word_embeddings, train),
                    input_value.size(),
                ),
            },
            None => match &input_embeds {
                Some(embeds) => (embeds.copy(), vec![embeds.size()[0], embeds.size()[1]]),
                None => {
                    return Err("Only one of input ids or input embeddings may be set");
                }
            },
        };
        let position_ids = match position_ids {
            Some(value) => value,
            None => match input_ids {
                Some(value) => self.create_position_ids_from_input_ids(&value),
                None => self.create_position_ids_from_embeddings(&input_embeds.unwrap()),
            },
        };
        let token_type_ids = match token_type_ids {
            Some(value) => value,
            None => Tensor::zeros(&input_shape, (Kind::Int64, input_embeddings.device())),
        };
        let position_embeddings = position_ids.apply(&self.position_embeddings);
        let token_type_embeddings = token_type_ids.apply(&self.token_type_embeddings);
        let input_embeddings: Tensor =
            input_embeddings + position_embeddings + token_type_embeddings;
        Ok(input_embeddings
            .apply(&self.layer_norm)
            .apply_t(&self.dropout, train))
    }
}
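The reformatted match above is also the input-validation contract: exactly one of `input_ids` / `input_embeds` may be set. A minimal sketch of that behavior (hypothetical setup; `roberta_embeddings` is assumed to be an already-constructed `RobertaEmbeddings` with the embedding trait in scope):

```rust
use tch::{Device, Kind, Tensor};

// Exactly one of `input_ids` / `input_embeds` must be provided.
let ids = Tensor::of_slice(&[0i64, 31414, 232, 2]).unsqueeze(0);
let embeds = Tensor::rand(&[1, 4, 768], (Kind::Float, Device::Cpu));

// Both set -> Err("Only one of input ids or input embeddings may be set").
assert!(roberta_embeddings
    .forward_t(Some(ids.copy()), None, None, Some(embeds.copy()), false)
    .is_err());
// Neither set -> the same error.
assert!(roberta_embeddings
    .forward_t(None, None, None, None, false)
    .is_err());
// Exactly one set -> Ok(embedded output).
let embedded = roberta_embeddings
    .forward_t(Some(ids), None, None, None, false)
    .unwrap();
```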

View File

@ -19,20 +19,28 @@
//! Pretrained models are available and can be downloaded using RemoteResources.
//!
//! ```no_run
//! # fn main() -> failure::Fallible<()> {
//! #
//! use rust_tokenizers::RobertaTokenizer;
//! use tch::{nn, Device};
//! # use std::path::PathBuf;
//! use rust_bert::bert::BertConfig;
//! use rust_bert::resources::{download_resource, LocalResource, Resource};
//! use rust_bert::roberta::RobertaForMaskedLM;
//! use rust_bert::Config;
//!
//! let config_resource = Resource::Local(LocalResource {
//!     local_path: PathBuf::from("path/to/config.json"),
//! });
//! let vocab_resource = Resource::Local(LocalResource {
//!     local_path: PathBuf::from("path/to/vocab.txt"),
//! });
//! let merges_resource = Resource::Local(LocalResource {
//!     local_path: PathBuf::from("path/to/vocab.txt"),
//! });
//! let weights_resource = Resource::Local(LocalResource {
//!     local_path: PathBuf::from("path/to/model.ot"),
//! });
//! let config_path = download_resource(&config_resource)?;
//! let vocab_path = download_resource(&vocab_resource)?;
//! let merges_path = download_resource(&merges_resource)?;
@ -40,19 +48,25 @@
//!
//! let device = Device::cuda_if_available();
//! let mut vs = nn::VarStore::new(device);
//! let tokenizer: RobertaTokenizer = RobertaTokenizer::from_file(
//!     vocab_path.to_str().unwrap(),
//!     merges_path.to_str().unwrap(),
//!     true,
//! );
//! let config = BertConfig::from_file(config_path);
//! let bert_model = RobertaForMaskedLM::new(&vs.root(), &config);
//! vs.load(weights_path)?;
//!
//! # Ok(())
//! # }
//! ```
mod embeddings;
mod roberta;

pub use embeddings::RobertaEmbeddings;
pub use roberta::{
    RobertaConfigResources, RobertaForMaskedLM, RobertaForMultipleChoice,
    RobertaForQuestionAnswering, RobertaForSequenceClassification, RobertaForTokenClassification,
    RobertaMergesResources, RobertaModelResources, RobertaVocabResources,
};

View File

@ -11,13 +11,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use crate::bert::{BertConfig, BertModel};
use crate::common::activations::_gelu;
use crate::common::dropout::Dropout;
use crate::common::linear::{linear_no_bias, LinearNoBias};
use crate::roberta::embeddings::RobertaEmbeddings;
use tch::nn::Init;
use tch::{nn, Tensor};

/// # RoBERTa Pretrained model weight files
pub struct RobertaModelResources;
@ -33,22 +33,34 @@ pub struct RobertaMergesResources;
impl RobertaModelResources {
    /// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
    pub const ROBERTA: (&'static str, &'static str) = (
        "roberta/model.ot",
        "https://cdn.huggingface.co/roberta-base-rust_model.ot",
    );
}

impl RobertaConfigResources {
    /// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
    pub const ROBERTA: (&'static str, &'static str) = (
        "roberta/config.json",
        "https://cdn.huggingface.co/roberta-base-config.json",
    );
}

impl RobertaVocabResources {
    /// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
    pub const ROBERTA: (&'static str, &'static str) = (
        "roberta/vocab.txt",
        "https://cdn.huggingface.co/roberta-base-vocab.json",
    );
}

impl RobertaMergesResources {
    /// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
    pub const ROBERTA: (&'static str, &'static str) = (
        "roberta/merges.txt",
        "https://cdn.huggingface.co/roberta-base-merges.txt",
    );
}
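Each of these constants pairs a local cache path with a download URL, and plugs straight into the `Resource` / `download_resource` machinery used throughout this commit. A sketch of fetching the full RoBERTa file set (the function name is illustrative; error handling via `failure::Fallible`, as in the examples above):

```rust
use rust_bert::resources::{download_resource, RemoteResource, Resource};
use rust_bert::roberta::{
    RobertaConfigResources, RobertaMergesResources, RobertaModelResources, RobertaVocabResources,
};

fn fetch_roberta() -> failure::Fallible<()> {
    let config = Resource::Remote(RemoteResource::from_pretrained(
        RobertaConfigResources::ROBERTA,
    ));
    let vocab = Resource::Remote(RemoteResource::from_pretrained(
        RobertaVocabResources::ROBERTA,
    ));
    let merges = Resource::Remote(RemoteResource::from_pretrained(
        RobertaMergesResources::ROBERTA,
    ));
    let weights = Resource::Remote(RemoteResource::from_pretrained(
        RobertaModelResources::ROBERTA,
    ));
    // Each call returns the local path, downloading the file on first use.
    let _config_path = download_resource(&config)?;
    let _vocab_path = download_resource(&vocab)?;
    let _merges_path = download_resource(&merges)?;
    let _weights_path = download_resource(&weights)?;
    Ok(())
}
```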
pub struct RobertaLMHead {
@ -60,17 +72,42 @@ pub struct RobertaLMHead {
impl RobertaLMHead {
    pub fn new(p: &nn::Path, config: &BertConfig) -> RobertaLMHead {
        let dense = nn::linear(
            p / "dense",
            config.hidden_size,
            config.hidden_size,
            Default::default(),
        );
        let layer_norm_config = nn::LayerNormConfig {
            eps: 1e-12,
            ..Default::default()
        };
        let layer_norm = nn::layer_norm(
            p / "layer_norm",
            vec![config.hidden_size],
            layer_norm_config,
        );
        let decoder = linear_no_bias(
            &(p / "decoder"),
            config.hidden_size,
            config.vocab_size,
            Default::default(),
        );
        let bias = p.var("bias", &[config.vocab_size], Init::KaimingUniform);
        RobertaLMHead {
            dense,
            decoder,
            layer_norm,
            bias,
        }
    }

    pub fn forward(&self, hidden_states: &Tensor) -> Tensor {
        (_gelu(&hidden_states.apply(&self.dense)))
            .apply(&self.layer_norm)
            .apply(&self.decoder)
            + &self.bias
    }
}
@ -96,10 +133,10 @@ impl RobertaForMaskedLM {
    ///
    /// ```no_run
    /// use rust_bert::bert::BertConfig;
    /// use rust_bert::roberta::RobertaForMaskedLM;
    /// use rust_bert::Config;
    /// use std::path::Path;
    /// use tch::{nn, Device};
    ///
    /// let config_path = Path::new("path/to/config.json");
    /// let device = Device::Cpu;
@ -107,7 +144,6 @@ impl RobertaForMaskedLM {
    /// let config = BertConfig::from_file(config_path);
    /// let roberta = RobertaForMaskedLM::new(&(&p.root() / "roberta"), &config);
    /// ```
    pub fn new(p: &nn::Path, config: &BertConfig) -> RobertaForMaskedLM {
        let roberta = BertModel::<RobertaEmbeddings>::new(&(p / "roberta"), config);
        let lm_head = RobertaLMHead::new(&(p / "lm_head"), config);
@ -137,49 +173,62 @@ impl RobertaForMaskedLM {
    /// # Example
    ///
    /// ```no_run
    /// # use rust_bert::bert::BertConfig;
    /// # use tch::{nn, Device, Tensor, no_grad};
    /// # use rust_bert::Config;
    /// # use std::path::Path;
    /// # use tch::kind::Kind::Int64;
    /// use rust_bert::roberta::RobertaForMaskedLM;
    /// # let config_path = Path::new("path/to/config.json");
    /// # let vocab_path = Path::new("path/to/vocab.txt");
    /// # let device = Device::Cpu;
    /// # let vs = nn::VarStore::new(device);
    /// # let config = BertConfig::from_file(config_path);
    /// # let roberta_model = RobertaForMaskedLM::new(&vs.root(), &config);
    /// let (batch_size, sequence_length) = (64, 128);
    /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
    /// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
    /// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
    /// let position_ids = Tensor::arange(sequence_length, (Int64, device))
    ///     .expand(&[batch_size, sequence_length], true);
    ///
    /// let (output, all_hidden_states, all_attentions) = no_grad(|| {
    ///     roberta_model.forward_t(
    ///         Some(input_tensor),
    ///         Some(mask),
    ///         Some(token_type_ids),
    ///         Some(position_ids),
    ///         None,
    ///         &None,
    ///         &None,
    ///         false,
    ///     )
    /// });
    /// ```
    pub fn forward_t(
        &self,
        input_ids: Option<Tensor>,
        mask: Option<Tensor>,
        token_type_ids: Option<Tensor>,
        position_ids: Option<Tensor>,
        input_embeds: Option<Tensor>,
        encoder_hidden_states: &Option<Tensor>,
        encoder_mask: &Option<Tensor>,
        train: bool,
    ) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
        let (hidden_state, _, all_hidden_states, all_attentions) = self
            .roberta
            .forward_t(
                input_ids,
                mask,
                token_type_ids,
                position_ids,
                input_embeds,
                encoder_hidden_states,
                encoder_mask,
                train,
            )
            .unwrap();
        let prediction_scores = self.lm_head.forward(&hidden_state);
        (prediction_scores, all_hidden_states, all_attentions)
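The returned `prediction_scores` have shape `[batch_size, sequence_length, vocab_size]`; decoding a masked position follows the same argmax idiom the ALBERT tests later in this commit use. A sketch (continuing the doc example above; `output` and a `RobertaTokenizer` named `tokenizer` are assumed, and position 5 is chosen arbitrarily for illustration):

```rust
// Most likely token at position 5 of the first sequence.
let index = output.get(0).get(5).argmax(0, false);
let token = tokenizer.vocab().id_to_token(&index.int64_value(&[]));
println!("predicted token: {}", token);
```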
@ -194,12 +243,30 @@ pub struct RobertaClassificationHead {
impl RobertaClassificationHead {
    pub fn new(p: &nn::Path, config: &BertConfig) -> RobertaClassificationHead {
        let dense = nn::linear(
            p / "dense",
            config.hidden_size,
            config.hidden_size,
            Default::default(),
        );
        let num_labels = config
            .id2label
            .as_ref()
            .expect("num_labels not provided in configuration")
            .len() as i64;
        let out_proj = nn::linear(
            p / "out_proj",
            config.hidden_size,
            num_labels,
            Default::default(),
        );
        let dropout = Dropout::new(config.hidden_dropout_prob);
        RobertaClassificationHead {
            dense,
            dropout,
            out_proj,
        }
    }

    pub fn forward_t(&self, hidden_states: &Tensor, train: bool) -> Tensor {
@ -235,10 +302,10 @@ impl RobertaForSequenceClassification {
    ///
    /// ```no_run
    /// use rust_bert::bert::BertConfig;
    /// use rust_bert::roberta::RobertaForSequenceClassification;
    /// use rust_bert::Config;
    /// use std::path::Path;
    /// use tch::{nn, Device};
    ///
    /// let config_path = Path::new("path/to/config.json");
    /// let device = Device::Cpu;
@ -246,12 +313,14 @@ impl RobertaForSequenceClassification {
    /// let config = BertConfig::from_file(config_path);
    /// let roberta = RobertaForSequenceClassification::new(&(&p.root() / "roberta"), &config);
    /// ```
    pub fn new(p: &nn::Path, config: &BertConfig) -> RobertaForSequenceClassification {
        let roberta = BertModel::<RobertaEmbeddings>::new(&(p / "roberta"), config);
        let classifier = RobertaClassificationHead::new(&(p / "classifier"), config);
        RobertaForSequenceClassification {
            roberta,
            classifier,
        }
    }

    /// Forward pass through the model
@ -274,45 +343,58 @@ impl RobertaForSequenceClassification {
    /// # Example
    ///
    /// ```no_run
    /// # use rust_bert::bert::BertConfig;
    /// # use tch::{nn, Device, Tensor, no_grad};
    /// # use rust_bert::Config;
    /// # use std::path::Path;
    /// # use tch::kind::Kind::Int64;
    /// use rust_bert::roberta::RobertaForSequenceClassification;
    /// # let config_path = Path::new("path/to/config.json");
    /// # let vocab_path = Path::new("path/to/vocab.txt");
    /// # let device = Device::Cpu;
    /// # let vs = nn::VarStore::new(device);
    /// # let config = BertConfig::from_file(config_path);
    /// # let roberta_model = RobertaForSequenceClassification::new(&vs.root(), &config);
    /// let (batch_size, sequence_length) = (64, 128);
    /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
    /// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
    /// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
    /// let position_ids = Tensor::arange(sequence_length, (Int64, device))
    ///     .expand(&[batch_size, sequence_length], true);
    ///
    /// let (labels, all_hidden_states, all_attentions) = no_grad(|| {
    ///     roberta_model.forward_t(
    ///         Some(input_tensor),
    ///         Some(mask),
    ///         Some(token_type_ids),
    ///         Some(position_ids),
    ///         None,
    ///         false,
    ///     )
    /// });
    /// ```
    pub fn forward_t(
        &self,
        input_ids: Option<Tensor>,
        mask: Option<Tensor>,
        token_type_ids: Option<Tensor>,
        position_ids: Option<Tensor>,
        input_embeds: Option<Tensor>,
        train: bool,
    ) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
        let (hidden_state, _, all_hidden_states, all_attentions) = self
            .roberta
            .forward_t(
                input_ids,
                mask,
                token_type_ids,
                position_ids,
                input_embeds,
                &None,
                &None,
                train,
            )
            .unwrap();
        let output = self.classifier.forward_t(&hidden_state, train);
        (output, all_hidden_states, all_attentions)
@ -344,10 +426,10 @@ impl RobertaForMultipleChoice {
    ///
    /// ```no_run
    /// use rust_bert::bert::BertConfig;
    /// use rust_bert::roberta::RobertaForMultipleChoice;
    /// use rust_bert::Config;
    /// use std::path::Path;
    /// use tch::{nn, Device};
    ///
    /// let config_path = Path::new("path/to/config.json");
    /// let device = Device::Cpu;
@ -355,13 +437,16 @@ impl RobertaForMultipleChoice {
    /// let config = BertConfig::from_file(config_path);
    /// let roberta = RobertaForMultipleChoice::new(&(&p.root() / "roberta"), &config);
    /// ```
    pub fn new(p: &nn::Path, config: &BertConfig) -> RobertaForMultipleChoice {
        let roberta = BertModel::<RobertaEmbeddings>::new(&(p / "roberta"), config);
        let dropout = Dropout::new(config.hidden_dropout_prob);
        let classifier = nn::linear(p / "classifier", config.hidden_size, 1, Default::default());
        RobertaForMultipleChoice {
            roberta,
            dropout,
            classifier,
        }
    }

    /// Forward pass through the model
@ -383,61 +468,77 @@ impl RobertaForMultipleChoice {
    /// # Example
    ///
    /// ```no_run
    /// # use rust_bert::bert::BertConfig;
    /// # use tch::{nn, Device, Tensor, no_grad};
    /// # use rust_bert::Config;
    /// # use std::path::Path;
    /// # use tch::kind::Kind::Int64;
    /// use rust_bert::roberta::RobertaForMultipleChoice;
    /// # let config_path = Path::new("path/to/config.json");
    /// # let vocab_path = Path::new("path/to/vocab.txt");
    /// # let device = Device::Cpu;
    /// # let vs = nn::VarStore::new(device);
    /// # let config = BertConfig::from_file(config_path);
    /// # let roberta_model = RobertaForMultipleChoice::new(&vs.root(), &config);
    /// let (num_choices, sequence_length) = (3, 128);
    /// let input_tensor = Tensor::rand(&[num_choices, sequence_length], (Int64, device));
    /// let mask = Tensor::zeros(&[num_choices, sequence_length], (Int64, device));
    /// let token_type_ids = Tensor::zeros(&[num_choices, sequence_length], (Int64, device));
    /// let position_ids = Tensor::arange(sequence_length, (Int64, device))
    ///     .expand(&[num_choices, sequence_length], true);
    ///
    /// let (choices, all_hidden_states, all_attentions) = no_grad(|| {
    ///     roberta_model.forward_t(
    ///         input_tensor,
    ///         Some(mask),
    ///         Some(token_type_ids),
    ///         Some(position_ids),
    ///         false,
    ///     )
    /// });
    /// ```
    pub fn forward_t(
        &self,
        input_ids: Tensor,
        mask: Option<Tensor>,
        token_type_ids: Option<Tensor>,
        position_ids: Option<Tensor>,
        train: bool,
    ) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
        let num_choices = input_ids.size()[1];
        let flat_input_ids = Some(input_ids.view((-1i64, *input_ids.size().last().unwrap())));
        let flat_position_ids = match position_ids {
            Some(value) => Some(value.view((-1i64, *value.size().last().unwrap()))),
            None => None,
        };
        let flat_token_type_ids = match token_type_ids {
            Some(value) => Some(value.view((-1i64, *value.size().last().unwrap()))),
            None => None,
        };
        let flat_mask = match mask {
            Some(value) => Some(value.view((-1i64, *value.size().last().unwrap()))),
            None => None,
        };
        let (_, pooled_output, all_hidden_states, all_attentions) = self
            .roberta
            .forward_t(
                flat_input_ids,
                flat_mask,
                flat_token_type_ids,
                flat_position_ids,
                None,
                &None,
                &None,
                train,
            )
            .unwrap();
        let output = pooled_output
            .apply_t(&self.dropout, train)
            .apply(&self.classifier)
            .view((-1, num_choices));
        (output, all_hidden_states, all_attentions)
    }
}
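Since `num_choices` is read from dimension 1 and all inputs are flattened before the underlying pass, callers stack per-choice sequences and add a leading batch dimension, just as the ALBERT test later in this commit does with `.unsqueeze(0)`. A minimal sketch (model construction elided; `roberta_model` is assumed, both choices padded to equal length, and the token ids are placeholders):

```rust
use tch::{no_grad, Tensor};

let choice_a = Tensor::of_slice(&[0i64, 7596, 2]);
let choice_b = Tensor::of_slice(&[0i64, 8923, 2]);
// Shape [1, 2, 3]: one example, two choices, three tokens per choice.
let input_tensor = Tensor::stack(&[choice_a, choice_b], 0).unsqueeze(0);

let (choice_scores, _, _) =
    no_grad(|| roberta_model.forward_t(input_tensor, None, None, None, false));
assert_eq!(choice_scores.size(), &[1, 2]); // one logit per choice
```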
@ -466,10 +567,10 @@ impl RobertaForTokenClassification {
    ///
    /// ```no_run
    /// use rust_bert::bert::BertConfig;
    /// use rust_bert::roberta::RobertaForTokenClassification;
    /// use rust_bert::Config;
    /// use std::path::Path;
    /// use tch::{nn, Device};
    ///
    /// let config_path = Path::new("path/to/config.json");
    /// let device = Device::Cpu;
@ -477,14 +578,26 @@ impl RobertaForTokenClassification {
    /// let config = BertConfig::from_file(config_path);
    /// let roberta = RobertaForTokenClassification::new(&(&p.root() / "roberta"), &config);
    /// ```
    pub fn new(p: &nn::Path, config: &BertConfig) -> RobertaForTokenClassification {
        let roberta = BertModel::<RobertaEmbeddings>::new(&(p / "roberta"), config);
        let dropout = Dropout::new(config.hidden_dropout_prob);
        let num_labels = config
            .id2label
            .as_ref()
            .expect("num_labels not provided in configuration")
            .len() as i64;
        let classifier = nn::linear(
            p / "classifier",
            config.hidden_size,
            num_labels,
            Default::default(),
        );
        RobertaForTokenClassification {
            roberta,
            dropout,
            classifier,
        }
    }

    /// Forward pass through the model
@ -507,47 +620,62 @@ impl RobertaForTokenClassification {
    /// # Example
    ///
    /// ```no_run
    /// # use rust_bert::bert::BertConfig;
    /// # use tch::{nn, Device, Tensor, no_grad};
    /// # use rust_bert::Config;
    /// # use std::path::Path;
    /// # use tch::kind::Kind::Int64;
    /// use rust_bert::roberta::RobertaForTokenClassification;
    /// # let config_path = Path::new("path/to/config.json");
    /// # let vocab_path = Path::new("path/to/vocab.txt");
    /// # let device = Device::Cpu;
    /// # let vs = nn::VarStore::new(device);
    /// # let config = BertConfig::from_file(config_path);
    /// # let roberta_model = RobertaForTokenClassification::new(&vs.root(), &config);
    /// let (batch_size, sequence_length) = (64, 128);
    /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
    /// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
    /// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
    /// let position_ids = Tensor::arange(sequence_length, (Int64, device))
    ///     .expand(&[batch_size, sequence_length], true);
    ///
    /// let (token_labels, all_hidden_states, all_attentions) = no_grad(|| {
    ///     roberta_model.forward_t(
    ///         Some(input_tensor),
    ///         Some(mask),
    ///         Some(token_type_ids),
    ///         Some(position_ids),
    ///         None,
    ///         false,
    ///     )
    /// });
    /// ```
    pub fn forward_t(
        &self,
        input_ids: Option<Tensor>,
        mask: Option<Tensor>,
        token_type_ids: Option<Tensor>,
        position_ids: Option<Tensor>,
        input_embeds: Option<Tensor>,
        train: bool,
    ) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
        let (hidden_state, _, all_hidden_states, all_attentions) = self
            .roberta
            .forward_t(
                input_ids,
                mask,
                token_type_ids,
                position_ids,
                input_embeds,
                &None,
                &None,
                train,
            )
            .unwrap();
        let sequence_output = hidden_state
            .apply_t(&self.dropout, train)
            .apply(&self.classifier);
        (sequence_output, all_hidden_states, all_attentions)
    }
}
@ -576,10 +704,10 @@ impl RobertaForQuestionAnswering {
    ///
    /// ```no_run
    /// use rust_bert::bert::BertConfig;
    /// use rust_bert::roberta::RobertaForQuestionAnswering;
    /// use rust_bert::Config;
    /// use std::path::Path;
    /// use tch::{nn, Device};
    ///
    /// let config_path = Path::new("path/to/config.json");
    /// let device = Device::Cpu;
@ -587,13 +715,20 @@ impl RobertaForQuestionAnswering {
    /// let config = BertConfig::from_file(config_path);
    /// let roberta = RobertaForQuestionAnswering::new(&(&p.root() / "roberta"), &config);
    /// ```
    pub fn new(p: &nn::Path, config: &BertConfig) -> RobertaForQuestionAnswering {
        let roberta = BertModel::<RobertaEmbeddings>::new(&(p / "roberta"), config);
        let num_labels = 2;
        let qa_outputs = nn::linear(
            p / "qa_outputs",
            config.hidden_size,
            num_labels,
            Default::default(),
        );
        RobertaForQuestionAnswering {
            roberta,
            qa_outputs,
        }
    }

    /// Forward pass through the model
@ -617,45 +752,58 @@ impl RobertaForQuestionAnswering {
    /// # Example
    ///
    /// ```no_run
    /// # use rust_bert::bert::BertConfig;
    /// # use tch::{nn, Device, Tensor, no_grad};
    /// # use rust_bert::Config;
    /// # use std::path::Path;
    /// # use tch::kind::Kind::Int64;
    /// use rust_bert::roberta::RobertaForQuestionAnswering;
    /// # let config_path = Path::new("path/to/config.json");
    /// # let vocab_path = Path::new("path/to/vocab.txt");
    /// # let device = Device::Cpu;
    /// # let vs = nn::VarStore::new(device);
    /// # let config = BertConfig::from_file(config_path);
    /// # let roberta_model = RobertaForQuestionAnswering::new(&vs.root(), &config);
    /// let (batch_size, sequence_length) = (64, 128);
    /// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
    /// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
    /// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
    /// let position_ids = Tensor::arange(sequence_length, (Int64, device))
    ///     .expand(&[batch_size, sequence_length], true);
    ///
    /// let (start_scores, end_scores, all_hidden_states, all_attentions) = no_grad(|| {
    ///     roberta_model.forward_t(
    ///         Some(input_tensor),
    ///         Some(mask),
    ///         Some(token_type_ids),
    ///         Some(position_ids),
    ///         None,
    ///         false,
    ///     )
    /// });
    /// ```
    pub fn forward_t(
        &self,
        input_ids: Option<Tensor>,
        mask: Option<Tensor>,
        token_type_ids: Option<Tensor>,
        position_ids: Option<Tensor>,
        input_embeds: Option<Tensor>,
        train: bool,
    ) -> (Tensor, Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
        let (hidden_state, _, all_hidden_states, all_attentions) = self
            .roberta
            .forward_t(
                input_ids,
                mask,
                token_type_ids,
                position_ids,
                input_embeds,
                &None,
                &None,
                train,
            )
            .unwrap();
        let sequence_output = hidden_state.apply(&self.qa_outputs);
        let logits = sequence_output.split(1, -1);
@ -665,4 +813,4 @@ impl RobertaForQuestionAnswering {
        (start_logits, end_logits, all_hidden_states, all_attentions)
    }
}
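`start_scores` / `end_scores` are `[batch_size, sequence_length]` logits. A common greedy way to turn them into an answer span is an independent argmax on each, sketched below for the first sequence of the doc example (a simplification: production code would search the best valid `(start, end)` pair jointly and mask out non-context tokens):

```rust
let start = start_scores.get(0).argmax(0, false).int64_value(&[]);
let end = end_scores.get(0).argmax(0, false).int64_value(&[]);
if start <= end {
    // The answer is tokens[start..=end] of the input, decoded back to
    // text with the same tokenizer that produced the input ids.
}
```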

View File

@ -1,67 +1,77 @@
extern crate dirs;
extern crate failure;

use rust_bert::albert::{
    AlbertConfig, AlbertConfigResources, AlbertForMaskedLM, AlbertForMultipleChoice,
    AlbertForQuestionAnswering, AlbertForSequenceClassification, AlbertForTokenClassification,
    AlbertModelResources, AlbertVocabResources,
};
use rust_bert::resources::{download_resource, RemoteResource, Resource};
use rust_bert::Config;
use rust_tokenizers::{AlbertTokenizer, Tokenizer, TruncationStrategy, Vocab};
use std::collections::HashMap;
use tch::{nn, no_grad, Device, Tensor};

#[test]
fn albert_masked_lm() -> failure::Fallible<()> {
    // Resources paths
    let config_resource = Resource::Remote(RemoteResource::from_pretrained(
        AlbertConfigResources::ALBERT_BASE_V2,
    ));
    let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
        AlbertVocabResources::ALBERT_BASE_V2,
    ));
    let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
        AlbertModelResources::ALBERT_BASE_V2,
    ));
    let config_path = download_resource(&config_resource)?;
    let vocab_path = download_resource(&vocab_resource)?;
    let weights_path = download_resource(&weights_resource)?;

    // Set-up masked LM model
    let device = Device::Cpu;
    let mut vs = nn::VarStore::new(device);
    let tokenizer: AlbertTokenizer =
        AlbertTokenizer::from_file(vocab_path.to_str().unwrap(), true, false);
    let config = AlbertConfig::from_file(config_path);
    let albert_model = AlbertForMaskedLM::new(&vs.root(), &config);
    vs.load(weights_path)?;

    // Define input
    let input = [
        "Looks like one [MASK] is missing",
        "It\'s like comparing [MASK] to apples",
    ];
    let tokenized_input =
        tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
    let max_len = tokenized_input
        .iter()
        .map(|input| input.token_ids.len())
        .max()
        .unwrap();
    let tokenized_input = tokenized_input
        .iter()
        .map(|input| input.token_ids.clone())
        .map(|mut input| {
            input.extend(vec![0; max_len - input.len()]);
            input
        })
        .map(|input| Tensor::of_slice(&(input)))
        .collect::<Vec<_>>();
    let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);

    // Forward pass
    let (output, _, _) =
        no_grad(|| albert_model.forward_t(Some(input_tensor), None, None, None, None, false));

    // Print masked tokens
    let index_1 = output.get(0).get(4).argmax(0, false);
    let index_2 = output.get(1).get(6).argmax(0, false);
    let word_1 = tokenizer.vocab().id_to_token(&index_1.int64_value(&[]));
    let word_2 = tokenizer.vocab().id_to_token(&index_2.int64_value(&[]));

    assert_eq!("▁them", word_1); // Outputs "_them" : "Looks like one [them] is missing (? this is identical with the original implementation)"
    assert_eq!("▁grapes", word_2); // Outputs "grapes" : "It\'s like comparing [grapes] to apples"
    assert!((output.double_value(&[0, 0, 0]) - 4.6143).abs() < 1e-4);

    Ok(())
}
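The pad-and-stack pipeline above is repeated verbatim in every test in this file; a small helper could factor it out. A sketch (hypothetical, not part of this commit; it assumes the `TokenizedInput` struct that `rust_tokenizers`' `encode_list` returns, with its `token_ids` field):

```rust
use rust_tokenizers::TokenizedInput;
use tch::{Device, Tensor};

// Right-pads each sequence with 0 to the batch maximum length,
// then stacks them into a single [batch, max_len] tensor on `device`.
fn pad_and_stack(tokenized: &[TokenizedInput], device: Device) -> Tensor {
    let max_len = tokenized
        .iter()
        .map(|t| t.token_ids.len())
        .max()
        .unwrap_or(0);
    let tensors = tokenized
        .iter()
        .map(|t| {
            let mut ids = t.token_ids.clone();
            ids.extend(vec![0; max_len - ids.len()]); // pad with 0
            Tensor::of_slice(&ids)
        })
        .collect::<Vec<_>>();
    Tensor::stack(tensors.as_slice(), 0).to(device)
}
```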
@ -69,15 +79,20 @@ fn albert_masked_lm() -> failure::Fallible<()> {
#[test]
fn albert_for_sequence_classification() -> failure::Fallible<()> {
    // Resources paths
    let config_resource = Resource::Remote(RemoteResource::from_pretrained(
        AlbertConfigResources::ALBERT_BASE_V2,
    ));
    let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
        AlbertVocabResources::ALBERT_BASE_V2,
    ));
    let config_path = download_resource(&config_resource)?;
    let vocab_path = download_resource(&vocab_resource)?;

    // Set-up model
    let device = Device::Cpu;
    let vs = nn::VarStore::new(device);
    let tokenizer: AlbertTokenizer =
        AlbertTokenizer::from_file(vocab_path.to_str().unwrap(), true, false);
    let mut config = AlbertConfig::from_file(config_path);
    let mut dummy_label_mapping = HashMap::new();
    dummy_label_mapping.insert(0, String::from("Positive"));
@ -88,37 +103,42 @@ fn albert_for_sequence_classification() -> failure::Fallible<()> {
    config.output_hidden_states = Some(true);
    let albert_model = AlbertForSequenceClassification::new(&vs.root(), &config);

    // Define input
    let input = [
        "Looks like one thing is missing",
        "It\'s like comparing oranges to apples",
    ];
    let tokenized_input =
        tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
    let max_len = tokenized_input
        .iter()
        .map(|input| input.token_ids.len())
        .max()
        .unwrap();
    let tokenized_input = tokenized_input
        .iter()
        .map(|input| input.token_ids.clone())
        .map(|mut input| {
            input.extend(vec![0; max_len - input.len()]);
            input
        })
        .map(|input| Tensor::of_slice(&(input)))
        .collect::<Vec<_>>();
    let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);

    // Forward pass
    let (output, all_hidden_states, all_attentions) =
        no_grad(|| albert_model.forward_t(Some(input_tensor), None, None, None, None, false));

    assert_eq!(output.size(), &[2, 3]);
    assert_eq!(
        config.num_hidden_layers as usize,
        all_hidden_states.unwrap().len()
    );
    assert_eq!(
        config.num_hidden_layers as usize,
        all_attentions.unwrap().len()
    );

    Ok(())
}
@ -126,50 +146,66 @@ fn albert_for_sequence_classification() -> failure::Fallible<()> {
#[test]
fn albert_for_multiple_choice() -> failure::Fallible<()> {
    // Resources paths
    let config_resource = Resource::Remote(RemoteResource::from_pretrained(
        AlbertConfigResources::ALBERT_BASE_V2,
    ));
    let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
        AlbertVocabResources::ALBERT_BASE_V2,
    ));
    let config_path = download_resource(&config_resource)?;
    let vocab_path = download_resource(&vocab_resource)?;

    // Set-up model
    let device = Device::Cpu;
    let vs = nn::VarStore::new(device);
    let tokenizer: AlbertTokenizer =
        AlbertTokenizer::from_file(vocab_path.to_str().unwrap(), true, false);
    let mut config = AlbertConfig::from_file(config_path);
    config.output_attentions = Some(true);
    config.output_hidden_states = Some(true);
    let albert_model = AlbertForMultipleChoice::new(&vs.root(), &config);

    // Define input
    let input = [
        "Looks like one thing is missing",
        "It\'s like comparing oranges to apples",
    ];
    let tokenized_input =
        tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
    let max_len = tokenized_input
        .iter()
        .map(|input| input.token_ids.len())
        .max()
        .unwrap();
    let tokenized_input = tokenized_input
        .iter()
        .map(|input| input.token_ids.clone())
        .map(|mut input| {
            input.extend(vec![0; max_len - input.len()]);
            input
        })
        .map(|input| Tensor::of_slice(&(input)))
        .collect::<Vec<_>>();
    let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0)
        .to(device)
        .unsqueeze(0);

    // Forward pass
    let (output, all_hidden_states, all_attentions) = no_grad(|| {
        albert_model
            .forward_t(Some(input_tensor), None, None, None, None, false)
            .unwrap()
    });

    assert_eq!(output.size(), &[1, 2]);
    assert_eq!(
        config.num_hidden_layers as usize,
        all_hidden_states.unwrap().len()
    );
    assert_eq!(
        config.num_hidden_layers as usize,
        all_attentions.unwrap().len()
    );

    Ok(())
}
@ -177,15 +213,20 @@ fn albert_for_multiple_choice() -> failure::Fallible<()> {
#[test]
fn albert_for_token_classification() -> failure::Fallible<()> {
    // Resources paths
    let config_resource = Resource::Remote(RemoteResource::from_pretrained(
        AlbertConfigResources::ALBERT_BASE_V2,
    ));
    let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
        AlbertVocabResources::ALBERT_BASE_V2,
    ));
    let config_path = download_resource(&config_resource)?;
    let vocab_path = download_resource(&vocab_resource)?;

    // Set-up model
    let device = Device::Cpu;
    let vs = nn::VarStore::new(device);
    let tokenizer: AlbertTokenizer =
        AlbertTokenizer::from_file(vocab_path.to_str().unwrap(), true, false);
    let mut config = AlbertConfig::from_file(config_path);
    let mut dummy_label_mapping = HashMap::new();
    dummy_label_mapping.insert(0, String::from("O"));
@ -197,37 +238,42 @@ fn albert_for_token_classification() -> failure::Fallible<()> {
    config.output_hidden_states = Some(true);
    let bert_model = AlbertForTokenClassification::new(&vs.root(), &config);

    // Define input
    let input = [
        "Looks like one thing is missing",
        "It\'s like comparing oranges to apples",
    ];
    let tokenized_input =
        tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
    let max_len = tokenized_input
        .iter()
        .map(|input| input.token_ids.len())
        .max()
        .unwrap();
    let tokenized_input = tokenized_input
        .iter()
        .map(|input| input.token_ids.clone())
        .map(|mut input| {
            input.extend(vec![0; max_len - input.len()]);
            input
        })
        .map(|input| Tensor::of_slice(&(input)))
        .collect::<Vec<_>>();
    let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);

    // Forward pass
    let (output, all_hidden_states, all_attentions) =
        no_grad(|| bert_model.forward_t(Some(input_tensor), None, None, None, None, false));

    assert_eq!(output.size(), &[2, 12, 4]);
    assert_eq!(
        config.num_hidden_layers as usize,
        all_hidden_states.unwrap().len()
    );
    assert_eq!(
        config.num_hidden_layers as usize,
        all_attentions.unwrap().len()
    );

    Ok(())
}
@ -235,51 +281,62 @@ fn albert_for_token_classification() -> failure::Fallible<()> {
#[test] #[test]
fn albert_for_question_answering() -> failure::Fallible<()> { fn albert_for_question_answering() -> failure::Fallible<()> {
// Resources paths // Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(AlbertConfigResources::ALBERT_BASE_V2)); let config_resource = Resource::Remote(RemoteResource::from_pretrained(
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(AlbertVocabResources::ALBERT_BASE_V2)); AlbertConfigResources::ALBERT_BASE_V2,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
AlbertVocabResources::ALBERT_BASE_V2,
));
let config_path = download_resource(&config_resource)?; let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?; let vocab_path = download_resource(&vocab_resource)?;
// Set-up model // Set-up model
let device = Device::Cpu; let device = Device::Cpu;
let vs = nn::VarStore::new(device); let vs = nn::VarStore::new(device);
let tokenizer: AlbertTokenizer = AlbertTokenizer::from_file(vocab_path.to_str().unwrap(), true, false); let tokenizer: AlbertTokenizer =
AlbertTokenizer::from_file(vocab_path.to_str().unwrap(), true, false);
let mut config = AlbertConfig::from_file(config_path); let mut config = AlbertConfig::from_file(config_path);
config.output_attentions = Some(true); config.output_attentions = Some(true);
config.output_hidden_states = Some(true); config.output_hidden_states = Some(true);
let albert_model = AlbertForQuestionAnswering::new(&vs.root(), &config); let albert_model = AlbertForQuestionAnswering::new(&vs.root(), &config);
// Define input // Define input
let input = ["Looks like one thing is missing", "It\'s like comparing oranges to apples"]; let input = [
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0); "Looks like one thing is missing",
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap(); "It\'s like comparing oranges to apples",
let tokenized_input = tokenized_input. ];
iter(). let tokenized_input =
map(|input| input.token_ids.clone()). tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
map(|mut input| { let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let tokenized_input = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]); input.extend(vec![0; max_len - input.len()]);
input input
}). })
map(|input| .map(|input| Tensor::of_slice(&(input)))
Tensor::of_slice(&(input))). .collect::<Vec<_>>();
collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device); let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass // Forward pass
let (start_scores, end_scores, all_hidden_states, all_attentions) = no_grad(|| { let (start_scores, end_scores, all_hidden_states, all_attentions) =
albert_model no_grad(|| albert_model.forward_t(Some(input_tensor), None, None, None, None, false));
.forward_t(Some(input_tensor),
None,
None,
None,
None,
false)
});
assert_eq!(start_scores.size(), &[2, 12]); assert_eq!(start_scores.size(), &[2, 12]);
assert_eq!(end_scores.size(), &[2, 12]); assert_eq!(end_scores.size(), &[2, 12]);
assert_eq!(config.num_hidden_layers as usize, all_hidden_states.unwrap().len()); assert_eq!(
assert_eq!(config.num_hidden_layers as usize, all_attentions.unwrap().len()); config.num_hidden_layers as usize,
all_hidden_states.unwrap().len()
);
assert_eq!(
config.num_hidden_layers as usize,
all_attentions.unwrap().len()
);
Ok(()) Ok(())
} }
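Every test in these files repeats the same tokenize, pad, and stack preamble before the forward pass. A minimal free-standing sketch of that pattern, using only the tch tensor calls visible above (the helper name and signature are illustrative, not part of the crate):

use tch::Tensor;

// Pad each id sequence with zeros up to the longest sequence in the batch,
// then stack the rows into a single (batch_size, max_len) tensor, exactly as
// the tests do inline after `tokenizer.encode_list`.
fn pad_and_stack(sequences: &[Vec<i64>]) -> Tensor {
    let max_len = sequences.iter().map(|s| s.len()).max().unwrap_or(0);
    let rows = sequences
        .iter()
        .map(|s| {
            let mut padded = s.clone();
            padded.extend(vec![0; max_len - padded.len()]);
            Tensor::of_slice(&padded)
        })
        .collect::<Vec<_>>();
    Tensor::stack(rows.as_slice(), 0)
}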
View File
@@ -1,56 +1,65 @@
use rust_bert::bart::{
    BartConfig, BartConfigResources, BartMergesResources, BartModel, BartModelResources,
    BartVocabResources,
};
use rust_bert::pipelines::summarization::{SummarizationConfig, SummarizationModel};
use rust_bert::resources::{download_resource, RemoteResource, Resource};
use rust_bert::Config;
use rust_tokenizers::{RobertaTokenizer, Tokenizer, TruncationStrategy};
use tch::{nn, Device, Tensor};

#[test]
#[cfg_attr(not(feature = "all-tests"), ignore)]
fn bart_lm_model() -> failure::Fallible<()> {
    // Resources paths
    let config_resource =
        Resource::Remote(RemoteResource::from_pretrained(BartConfigResources::BART));
    let vocab_resource =
        Resource::Remote(RemoteResource::from_pretrained(BartVocabResources::BART));
    let merges_resource =
        Resource::Remote(RemoteResource::from_pretrained(BartMergesResources::BART));
    let weights_resource =
        Resource::Remote(RemoteResource::from_pretrained(BartModelResources::BART));
    let config_path = download_resource(&config_resource)?;
    let vocab_path = download_resource(&vocab_resource)?;
    let merges_path = download_resource(&merges_resource)?;
    let weights_path = download_resource(&weights_resource)?;

    // Set-up masked LM model
    let device = Device::Cpu;
    let mut vs = nn::VarStore::new(device);
    let tokenizer: RobertaTokenizer = RobertaTokenizer::from_file(
        vocab_path.to_str().unwrap(),
        merges_path.to_str().unwrap(),
        false,
    );
    let config = BartConfig::from_file(config_path);
    let bart_model = BartModel::new(&vs.root(), &config, false);
    vs.load(weights_path)?;

    // Define input
    let input = ["One two three four"];
    let tokenized_input =
        tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
    let max_len = tokenized_input
        .iter()
        .map(|input| input.token_ids.len())
        .max()
        .unwrap();
    let tokenized_input = tokenized_input
        .iter()
        .map(|input| input.token_ids.clone())
        .map(|mut input| {
            input.extend(vec![0; max_len - input.len()]);
            input
        })
        .map(|input| Tensor::of_slice(&(input)))
        .collect::<Vec<_>>();
    let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);

    // Forward pass
    let (output, encoder_outputs, _, _, _, _, _) =
        bart_model.forward_t(Some(&input_tensor), None, None, None, None, None, false);

    assert_eq!(output.size(), vec!(1, 6, 1024));
    assert_eq!(encoder_outputs.size(), vec!(1, 6, 1024));

@@ -58,12 +67,10 @@ fn bart_lm_model() -> failure::Fallible<()> {
    Ok(())
}

#[test]
#[cfg_attr(not(feature = "all-tests"), ignore)]
fn bart_summarization_greedy() -> failure::Fallible<()> {
    // Set-up masked LM model
    let summarization_config = SummarizationConfig {
        num_beams: 1,
        device: Device::Cpu,

@@ -93,7 +100,7 @@ on K2-18b lasts 33 Earth days. According to The Guardian, astronomers were optim
telescope scheduled for launch in 2021 and the European Space Agency's 2028 ARIEL program, could reveal more \
about exoplanets like K2-18b."];

    // Credits: WikiNews, CC BY 2.5 license (https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b)
    let output = model.summarize(&input);
    assert_eq!(output.len(), 1);

@@ -107,8 +114,7 @@ about exoplanets like K2-18b."];
#[test]
#[cfg_attr(not(feature = "all-tests"), ignore)]
fn bart_summarization_beam_search() -> failure::Fallible<()> {
    // Set-up masked LM model
    let summarization_config = SummarizationConfig {
        num_beams: 3,
        device: Device::Cpu,

@@ -138,7 +144,7 @@ on K2-18b lasts 33 Earth days. According to The Guardian, astronomers were optim
telescope scheduled for launch in 2021 and the European Space Agency's 2028 ARIEL program, could reveal more \
about exoplanets like K2-18b."];

    // Credits: WikiNews, CC BY 2.5 license (https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b)
    let output = model.summarize(&input);
    assert_eq!(output.len(), 1);

@@ -148,4 +154,4 @@ about exoplanets like K2-18b."];
star as the planet passed between it and Earth.");

    Ok(())
}
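The two summarization tests above differ only in num_beams: with a single beam, beam search degenerates to greedy decoding. A sketch of the two configurations, assuming SummarizationConfig provides a Default implementation for the fields elided from this diff:

// Greedy decoding: beam search restricted to one hypothesis per step.
let greedy_config = SummarizationConfig {
    num_beams: 1,
    device: Device::Cpu,
    ..Default::default()
};

// Beam search proper: the three best partial hypotheses are kept per step.
let beam_config = SummarizationConfig {
    num_beams: 3,
    device: Device::Cpu,
    ..Default::default()
};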
View File
@@ -1,27 +1,32 @@
extern crate dirs;
extern crate failure;

use rust_bert::bert::{
    BertConfig, BertConfigResources, BertForMaskedLM, BertForMultipleChoice,
    BertForQuestionAnswering, BertForSequenceClassification, BertForTokenClassification,
    BertModelResources, BertVocabResources,
};
use rust_bert::pipelines::ner::NERModel;
use rust_bert::resources::{download_resource, RemoteResource, Resource};
use rust_bert::Config;
use rust_tokenizers::{BertTokenizer, Tokenizer, TruncationStrategy, Vocab};
use std::collections::HashMap;
use tch::{nn, no_grad, Device, Tensor};

#[test]
fn bert_masked_lm() -> failure::Fallible<()> {
    // Resources paths
    let config_resource =
        Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT));
    let vocab_resource =
        Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
    let weights_resource =
        Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT));
    let config_path = download_resource(&config_resource)?;
    let vocab_path = download_resource(&vocab_resource)?;
    let weights_path = download_resource(&weights_resource)?;

    // Set-up masked LM model
    let device = Device::Cpu;
    let mut vs = nn::VarStore::new(device);
    let tokenizer: BertTokenizer = BertTokenizer::from_file(vocab_path.to_str().unwrap(), true);
@@ -29,50 +34,58 @@ fn bert_masked_lm() -> failure::Fallible<()> {
    let bert_model = BertForMaskedLM::new(&vs.root(), &config);
    vs.load(weights_path)?;

    // Define input
    let input = [
        "Looks like one thing is missing",
        "It\'s like comparing oranges to apples",
    ];
    let tokenized_input =
        tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
    let max_len = tokenized_input
        .iter()
        .map(|input| input.token_ids.len())
        .max()
        .unwrap();
    let mut tokenized_input = tokenized_input
        .iter()
        .map(|input| input.token_ids.clone())
        .map(|mut input| {
            input.extend(vec![0; max_len - input.len()]);
            input
        })
        .collect::<Vec<_>>();

    // Masking the token [thing] of sentence 1 and [oranges] of sentence 2
    tokenized_input[0][4] = 103;
    tokenized_input[1][6] = 103;
    let tokenized_input = tokenized_input
        .iter()
        .map(|input| Tensor::of_slice(&(input)))
        .collect::<Vec<_>>();
    let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);

    // Forward pass
    let (output, _, _) = no_grad(|| {
        bert_model.forward_t(
            Some(input_tensor),
            None,
            None,
            None,
            None,
            &None,
            &None,
            false,
        )
    });

    // Print masked tokens
    let index_1 = output.get(0).get(4).argmax(0, false);
    let index_2 = output.get(1).get(6).argmax(0, false);
    let word_1 = tokenizer.vocab().id_to_token(&index_1.int64_value(&[]));
    let word_2 = tokenizer.vocab().id_to_token(&index_2.int64_value(&[]));

    assert_eq!("person", word_1); // Outputs "person" : "Looks like one [person] is missing"
    assert_eq!("orange", word_2); // Outputs "orange" : "It\'s like comparing [orange] to apples"

    Ok(())
}
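The hard-coded 103 above is the id of the [MASK] token in the standard BERT vocabulary. A hedged alternative sketch that derives the id from the vocabulary instead, assuming the Vocab trait exposes a token_to_id lookup (the inverse of the id_to_token call used in the test):

// Hypothetical lookup replacing the magic constant 103; reuses the test's
// `tokenizer` and mutable `tokenized_input` bindings from above.
let mask_id = tokenizer.vocab().token_to_id("[MASK]");
tokenized_input[0][4] = mask_id;
tokenized_input[1][6] = mask_id;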
@@ -80,12 +93,14 @@ fn bert_masked_lm() -> failure::Fallible<()> {
#[test]
fn bert_for_sequence_classification() -> failure::Fallible<()> {
    // Resources paths
    let config_resource =
        Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT));
    let vocab_resource =
        Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
    let config_path = download_resource(&config_resource)?;
    let vocab_path = download_resource(&vocab_resource)?;

    // Set-up model
    let device = Device::Cpu;
    let vs = nn::VarStore::new(device);
    let tokenizer: BertTokenizer = BertTokenizer::from_file(vocab_path.to_str().unwrap(), true);

@@ -99,37 +114,42 @@ fn bert_for_sequence_classification() -> failure::Fallible<()> {
    config.output_hidden_states = Some(true);
    let bert_model = BertForSequenceClassification::new(&vs.root(), &config);

    // Define input
    let input = [
        "Looks like one thing is missing",
        "It\'s like comparing oranges to apples",
    ];
    let tokenized_input =
        tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
    let max_len = tokenized_input
        .iter()
        .map(|input| input.token_ids.len())
        .max()
        .unwrap();
    let tokenized_input = tokenized_input
        .iter()
        .map(|input| input.token_ids.clone())
        .map(|mut input| {
            input.extend(vec![0; max_len - input.len()]);
            input
        })
        .map(|input| Tensor::of_slice(&(input)))
        .collect::<Vec<_>>();
    let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);

    // Forward pass
    let (output, all_hidden_states, all_attentions) =
        no_grad(|| bert_model.forward_t(Some(input_tensor), None, None, None, None, false));

    assert_eq!(output.size(), &[2, 3]);
    assert_eq!(
        config.num_hidden_layers as usize,
        all_hidden_states.unwrap().len()
    );
    assert_eq!(
        config.num_hidden_layers as usize,
        all_attentions.unwrap().len()
    );

    Ok(())
}

@@ -137,12 +157,14 @@ fn bert_for_sequence_classification() -> failure::Fallible<()> {
#[test]
fn bert_for_multiple_choice() -> failure::Fallible<()> {
    // Resources paths
    let config_resource =
        Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT));
    let vocab_resource =
        Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
    let config_path = download_resource(&config_resource)?;
    let vocab_path = download_resource(&vocab_resource)?;

    // Set-up model
    let device = Device::Cpu;
    let vs = nn::VarStore::new(device);
    let tokenizer: BertTokenizer = BertTokenizer::from_file(vocab_path.to_str().unwrap(), true);

@@ -151,35 +173,44 @@ fn bert_for_multiple_choice() -> failure::Fallible<()> {
    config.output_hidden_states = Some(true);
    let bert_model = BertForMultipleChoice::new(&vs.root(), &config);

    // Define input
    let input = [
        "Looks like one thing is missing",
        "It\'s like comparing oranges to apples",
    ];
    let tokenized_input =
        tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
    let max_len = tokenized_input
        .iter()
        .map(|input| input.token_ids.len())
        .max()
        .unwrap();
    let tokenized_input = tokenized_input
        .iter()
        .map(|input| input.token_ids.clone())
        .map(|mut input| {
            input.extend(vec![0; max_len - input.len()]);
            input
        })
        .map(|input| Tensor::of_slice(&(input)))
        .collect::<Vec<_>>();
    let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0)
        .to(device)
        .unsqueeze(0);

    // Forward pass
    let (output, all_hidden_states, all_attentions) =
        no_grad(|| bert_model.forward_t(input_tensor, None, None, None, false));

    assert_eq!(output.size(), &[1, 2]);
    assert_eq!(
        config.num_hidden_layers as usize,
        all_hidden_states.unwrap().len()
    );
    assert_eq!(
        config.num_hidden_layers as usize,
        all_attentions.unwrap().len()
    );

    Ok(())
}
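The extra unsqueeze(0) is what sets the multiple-choice test apart: it turns the stacked (num_choices, seq_len) input into (batch = 1, num_choices, seq_len), so the two sentences compete as choices for a single example instead of acting as independent batch rows, which is why the asserted output shape is [1, 2]. A self-contained illustration of the shape change with dummy ids:

use tch::Tensor;

// Two "choices" of five token ids each.
let choices = Tensor::stack(
    &[Tensor::of_slice(&[0i64; 5]), Tensor::of_slice(&[0i64; 5])],
    0,
);
assert_eq!(choices.size(), &[2, 5]);
// unsqueeze(0) prepends the batch dimension the multiple-choice head expects.
assert_eq!(choices.unsqueeze(0).size(), &[1, 2, 5]);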
@@ -187,12 +218,14 @@ fn bert_for_multiple_choice() -> failure::Fallible<()> {
#[test]
fn bert_for_token_classification() -> failure::Fallible<()> {
    // Resources paths
    let config_resource =
        Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT));
    let vocab_resource =
        Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
    let config_path = download_resource(&config_resource)?;
    let vocab_path = download_resource(&vocab_resource)?;

    // Set-up model
    let device = Device::Cpu;
    let vs = nn::VarStore::new(device);
    let tokenizer: BertTokenizer = BertTokenizer::from_file(vocab_path.to_str().unwrap(), true);

@@ -207,37 +240,42 @@ fn bert_for_token_classification() -> failure::Fallible<()> {
    config.output_hidden_states = Some(true);
    let bert_model = BertForTokenClassification::new(&vs.root(), &config);

    // Define input
    let input = [
        "Looks like one thing is missing",
        "It\'s like comparing oranges to apples",
    ];
    let tokenized_input =
        tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
    let max_len = tokenized_input
        .iter()
        .map(|input| input.token_ids.len())
        .max()
        .unwrap();
    let tokenized_input = tokenized_input
        .iter()
        .map(|input| input.token_ids.clone())
        .map(|mut input| {
            input.extend(vec![0; max_len - input.len()]);
            input
        })
        .map(|input| Tensor::of_slice(&(input)))
        .collect::<Vec<_>>();
    let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);

    // Forward pass
    let (output, all_hidden_states, all_attentions) =
        no_grad(|| bert_model.forward_t(Some(input_tensor), None, None, None, None, false));

    assert_eq!(output.size(), &[2, 11, 4]);
    assert_eq!(
        config.num_hidden_layers as usize,
        all_hidden_states.unwrap().len()
    );
    assert_eq!(
        config.num_hidden_layers as usize,
        all_attentions.unwrap().len()
    );

    Ok(())
}

@@ -245,12 +283,14 @@ fn bert_for_token_classification() -> failure::Fallible<()> {
#[test]
fn bert_for_question_answering() -> failure::Fallible<()> {
    // Resources paths
    let config_resource =
        Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT));
    let vocab_resource =
        Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
    let config_path = download_resource(&config_resource)?;
    let vocab_path = download_resource(&vocab_resource)?;

    // Set-up model
    let device = Device::Cpu;
    let vs = nn::VarStore::new(device);
    let tokenizer: BertTokenizer = BertTokenizer::from_file(vocab_path.to_str().unwrap(), true);

@@ -259,53 +299,59 @@ fn bert_for_question_answering() -> failure::Fallible<()> {
    config.output_hidden_states = Some(true);
    let bert_model = BertForQuestionAnswering::new(&vs.root(), &config);

    // Define input
    let input = [
        "Looks like one thing is missing",
        "It\'s like comparing oranges to apples",
    ];
    let tokenized_input =
        tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
    let max_len = tokenized_input
        .iter()
        .map(|input| input.token_ids.len())
        .max()
        .unwrap();
    let tokenized_input = tokenized_input
        .iter()
        .map(|input| input.token_ids.clone())
        .map(|mut input| {
            input.extend(vec![0; max_len - input.len()]);
            input
        })
        .map(|input| Tensor::of_slice(&(input)))
        .collect::<Vec<_>>();
    let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);

    // Forward pass
    let (start_scores, end_scores, all_hidden_states, all_attentions) =
        no_grad(|| bert_model.forward_t(Some(input_tensor), None, None, None, None, false));

    assert_eq!(start_scores.size(), &[2, 11]);
    assert_eq!(end_scores.size(), &[2, 11]);
    assert_eq!(
        config.num_hidden_layers as usize,
        all_hidden_states.unwrap().len()
    );
    assert_eq!(
        config.num_hidden_layers as usize,
        all_attentions.unwrap().len()
    );

    Ok(())
}

#[test]
fn bert_pre_trained_ner() -> failure::Fallible<()> {
    // Set-up model
    let ner_model = NERModel::new(Default::default())?;

    // Define input
    let input = [
        "My name is Amy. I live in Paris.",
        "Paris is a city in France.",
    ];

    // Run model
    let output = ner_model.predict(&input);

    assert_eq!(output.len(), 4);
@@ -327,4 +373,4 @@ fn bert_pre_trained_ner() -> failure::Fallible<()> {
    assert_eq!(output[3].label, "I-LOC");

    Ok(())
}
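For reference, the four entities asserted above are Amy, Paris, Paris and France from the two input sentences. A usage sketch for consuming the predictions; the `label` field is asserted directly above, while `word` and `score` are assumed field names for the surface form and the confidence:

// Iterate over the flat list of predicted entities (two sentences in,
// four entities out for this input).
for entity in ner_model.predict(&input) {
    println!("{} ({:.3}) -> {}", entity.word, entity.score, entity.label);
}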
View File
@@ -1,22 +1,26 @@
use rust_bert::distilbert::{
    DistilBertConfig, DistilBertConfigResources, DistilBertForQuestionAnswering,
    DistilBertForTokenClassification, DistilBertModelMaskedLM, DistilBertModelResources,
    DistilBertVocabResources,
};
use rust_bert::pipelines::question_answering::{QaInput, QuestionAnsweringModel};
use rust_bert::pipelines::sentiment::{SentimentModel, SentimentPolarity};
use rust_bert::resources::{download_resource, RemoteResource, Resource};
use rust_bert::Config;
use rust_tokenizers::bert_tokenizer::BertTokenizer;
use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::{Tokenizer, TruncationStrategy};
use rust_tokenizers::preprocessing::vocab::base_vocab::Vocab;
use std::collections::HashMap;
use tch::{nn, no_grad, Device, Tensor};

extern crate failure;

#[test]
fn distilbert_sentiment_classifier() -> failure::Fallible<()> {
    // Set-up classifier
    let sentiment_classifier = SentimentModel::new(Default::default())?;

    // Get sentiments
    let input = [
        "Probably my all-time favorite movie, a story of selflessness, sacrifice and dedication to a noble cause, but it's not preachy or boring.",
        "This film tried to be too many things all at once: stinging political satire, Hollywood blockbuster, sappy romantic comedy, family values promo...",

@@ -36,18 +40,23 @@ fn distilbert_sentiment_classifier() -> failure::Fallible<()> {
    Ok(())
}

#[test]
fn distilbert_masked_lm() -> failure::Fallible<()> {
    // Resources paths
    let config_resource = Resource::Remote(RemoteResource::from_pretrained(
        DistilBertConfigResources::DISTIL_BERT,
    ));
    let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
        DistilBertVocabResources::DISTIL_BERT,
    ));
    let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
        DistilBertModelResources::DISTIL_BERT,
    ));
    let config_path = download_resource(&config_resource)?;
    let vocab_path = download_resource(&vocab_resource)?;
    let weights_path = download_resource(&weights_resource)?;

    // Set-up masked LM model
    let device = Device::cuda_if_available();
    let mut vs = nn::VarStore::new(device);
    let tokenizer: BertTokenizer = BertTokenizer::from_file(vocab_path.to_str().unwrap(), true);
@@ -55,59 +64,68 @@ fn distilbert_masked_lm() -> failure::Fallible<()> {
    let distil_bert_model = DistilBertModelMaskedLM::new(&vs.root(), &config);
    vs.load(weights_path)?;

    // Define input
    let input = [
        "Looks like one thing is missing",
        "It\'s like comparing oranges to apples",
    ];
    let tokenized_input =
        tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
    let max_len = tokenized_input
        .iter()
        .map(|input| input.token_ids.len())
        .max()
        .unwrap();
    let mut tokenized_input = tokenized_input
        .iter()
        .map(|input| input.token_ids.clone())
        .map(|mut input| {
            input.extend(vec![0; max_len - input.len()]);
            input
        })
        .collect::<Vec<_>>();

    // Masking the token [thing] of sentence 1 and [oranges] of sentence 2
    tokenized_input[0][4] = 103;
    tokenized_input[1][6] = 103;
    let tokenized_input = tokenized_input
        .iter()
        .map(|input| Tensor::of_slice(&(input)))
        .collect::<Vec<_>>();
    let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);

    // Forward pass
    let (output, _, _) = no_grad(|| {
        distil_bert_model
            .forward_t(Some(input_tensor), None, None, false)
            .unwrap()
    });

    // Print masked tokens
    let index_1 = output.get(0).get(4).argmax(0, false);
    let index_2 = output.get(1).get(6).argmax(0, false);
    let word_1 = tokenizer.vocab().id_to_token(&index_1.int64_value(&[]));
    let word_2 = tokenizer.vocab().id_to_token(&index_2.int64_value(&[]));

    assert_eq!("person", word_1); // Outputs "person" : "Looks like one [person] is missing"
    assert_eq!("pear", word_2); // Outputs "pear" : "It\'s like comparing [pear] to apples"

    Ok(())
}
#[test]
fn distilbert_for_question_answering() -> failure::Fallible<()> {
    // Resources paths
    let config_resource = Resource::Remote(RemoteResource::from_pretrained(
        DistilBertConfigResources::DISTIL_BERT_SQUAD,
    ));
    let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
        DistilBertVocabResources::DISTIL_BERT_SQUAD,
    ));
    let config_path = download_resource(&config_resource)?;
    let vocab_path = download_resource(&vocab_resource)?;

    // Set-up masked LM model
    let device = Device::cuda_if_available();
    let vs = nn::VarStore::new(device);
    let tokenizer: BertTokenizer = BertTokenizer::from_file(vocab_path.to_str().unwrap(), true);

@@ -116,23 +134,30 @@ fn distilbert_for_question_answering() -> failure::Fallible<()> {
    config.output_hidden_states = Some(true);
    let distil_bert_model = DistilBertForQuestionAnswering::new(&vs.root(), &config);

    // Define input
    let input = [
        "Looks like one thing is missing",
        "It\'s like comparing oranges to apples",
    ];
    let tokenized_input =
        tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
    let max_len = tokenized_input
        .iter()
        .map(|input| input.token_ids.len())
        .max()
        .unwrap();
    let tokenized_input = tokenized_input
        .iter()
        .map(|input| input.token_ids.clone())
        .map(|mut input| {
            input.extend(vec![0; max_len - input.len()]);
            input
        })
        .map(|input| Tensor::of_slice(&(input)))
        .collect::<Vec<_>>();
    let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);

    // Forward pass
    let (start_scores, end_scores, all_hidden_states, all_attentions) = no_grad(|| {
        distil_bert_model
            .forward_t(Some(input_tensor), None, None, false)
@@ -149,14 +174,17 @@ fn distilbert_for_question_answering() -> failure::Fallible<()> {
#[test]
fn distilbert_for_token_classification() -> failure::Fallible<()> {
    // Resources paths
    let config_resource = Resource::Remote(RemoteResource::from_pretrained(
        DistilBertConfigResources::DISTIL_BERT,
    ));
    let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
        DistilBertVocabResources::DISTIL_BERT,
    ));
    let config_path = download_resource(&config_resource)?;
    let vocab_path = download_resource(&vocab_resource)?;

    // Set-up masked LM model
    let device = Device::cuda_if_available();
    let vs = nn::VarStore::new(device);
    let tokenizer: BertTokenizer = BertTokenizer::from_file(vocab_path.to_str().unwrap(), true);

@@ -171,23 +199,30 @@ fn distilbert_for_token_classification() -> failure::Fallible<()> {
    config.id2label = Some(dummy_label_mapping);
    let distil_bert_model = DistilBertForTokenClassification::new(&vs.root(), &config);

    // Define input
    let input = [
        "Looks like one thing is missing",
        "It\'s like comparing oranges to apples",
    ];
    let tokenized_input =
        tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
    let max_len = tokenized_input
        .iter()
        .map(|input| input.token_ids.len())
        .max()
        .unwrap();
    let tokenized_input = tokenized_input
        .iter()
        .map(|input| input.token_ids.clone())
        .map(|mut input| {
            input.extend(vec![0; max_len - input.len()]);
            input
        })
        .map(|input| Tensor::of_slice(&(input)))
        .collect::<Vec<_>>();
    let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);

    // Forward pass
    let (output, all_hidden_states, all_attentions) = no_grad(|| {
        distil_bert_model
            .forward_t(Some(input_tensor), None, None, false)
@@ -203,15 +238,15 @@ fn distilbert_for_token_classification() -> failure::Fallible<()> {
#[test]
fn distilbert_question_answering() -> failure::Fallible<()> {
    // Set-up question answering model
    let qa_model = QuestionAnsweringModel::new(Default::default())?;

    // Define input
    let question = String::from("Where does Amy live ?");
    let context = String::from("Amy lives in Amsterdam");
    let qa_input = QaInput { question, context };

    let answers = qa_model.predict(&vec![qa_input], 1, 32);

    assert_eq!(answers.len(), 1 as usize);
    assert_eq!(answers[0].len(), 1 as usize);

@@ -221,4 +256,4 @@ fn distilbert_question_answering() -> failure::Fallible<()> {
    assert_eq!(answers[0][0].answer, "Amsterdam");

    Ok(())
}
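The predict call above takes the question/context pairs plus two numeric arguments, here 1 and 32, assumed to be the number of answers kept per question and the internal batch size. A sketch of batching several questions in a single call, using only the API surface shown above:

let qa_input_1 = QaInput {
    question: String::from("Where does Amy live ?"),
    context: String::from("Amy lives in Amsterdam"),
};
let qa_input_2 = QaInput {
    question: String::from("Where does Eric live ?"),
    context: String::from("Eric lives in The Hague"),
};
// One Vec of candidate answers comes back per input question.
let answers = qa_model.predict(&vec![qa_input_1, qa_input_2], 1, 32);
assert_eq!(answers.len(), 2);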
View File
@@ -1,73 +1,100 @@
use rust_bert::gpt2::{
    GPT2LMHeadModel, Gpt2Config, Gpt2ConfigResources, Gpt2MergesResources, Gpt2ModelResources,
    Gpt2VocabResources,
};
use rust_bert::pipelines::generation::{Cache, LMHeadModel};
use rust_bert::resources::{download_resource, RemoteResource, Resource};
use rust_bert::Config;
use rust_tokenizers::{Gpt2Tokenizer, Tokenizer, TruncationStrategy};
use tch::{nn, Device, Tensor};

#[test]
fn distilgpt2_lm_model() -> failure::Fallible<()> {
    // Resources paths
    let config_resource = Resource::Remote(RemoteResource::from_pretrained(
        Gpt2ConfigResources::DISTIL_GPT2,
    ));
    let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
        Gpt2VocabResources::DISTIL_GPT2,
    ));
    let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
        Gpt2MergesResources::DISTIL_GPT2,
    ));
    let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
        Gpt2ModelResources::DISTIL_GPT2,
    ));
    let config_path = download_resource(&config_resource)?;
    let vocab_path = download_resource(&vocab_resource)?;
    let merges_path = download_resource(&merges_resource)?;
    let weights_path = download_resource(&weights_resource)?;

    // Set-up masked LM model
    let device = Device::Cpu;
    let mut vs = nn::VarStore::new(device);
    let tokenizer: Gpt2Tokenizer = Gpt2Tokenizer::from_file(
        vocab_path.to_str().unwrap(),
        merges_path.to_str().unwrap(),
        false,
    );
    let config = Gpt2Config::from_file(config_path);
    let gpt2_model = GPT2LMHeadModel::new(&vs.root(), &config);
    vs.load(weights_path)?;

    // Define input
    let input = ["One two three four five six seven eight nine ten eleven"];
    let tokenized_input =
        tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
    let max_len = tokenized_input
        .iter()
        .map(|input| input.token_ids.len())
        .max()
        .unwrap();
    let tokenized_input = tokenized_input
        .iter()
        .map(|input| input.token_ids.clone())
        .map(|mut input| {
            input.extend(vec![0; max_len - input.len()]);
            input
        })
        .map(|input| Tensor::of_slice(&(input)))
        .collect::<Vec<_>>();
    let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);

    // Forward pass
    let (output, _, past, _, _) = gpt2_model
        .forward_t(
            &Some(input_tensor),
            Cache::None,
            &None,
            &None,
            &None,
            &None,
            None,
            &None,
            false,
        )
        .unwrap();

    let next_word_id = output.get(0).get(-1).argmax(-1, true).int64_value(&[0]);
    let next_word = tokenizer.decode(vec![next_word_id], true, true);

    assert_eq!(output.size(), vec!(1, 11, 50257));
    match past {
        Cache::GPT2Cache(past) => {
            assert!(past.is_some());
            assert_eq!(past.as_ref().unwrap().len(), config.n_layer as usize);
            assert_eq!(
                past.as_ref().unwrap()[0].size(),
                vec!(2, 1, config.n_head, 11, 64)
            );
        }
        _ => panic!("Wrong cache returned for GPT2"),
    }
    assert!(
        (output.double_value(&[0, output.size()[1] - 1, next_word_id]) - (-48.7065)).abs() < 1e-4
    );
    assert_eq!(next_word_id, 14104i64);
    assert_eq!(next_word, String::from(" twelve"));

    Ok(())
}
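The (2, 1, config.n_head, 11, 64) cache shape asserted above decomposes as: 2 for the stacked key and value tensors, a batch of 1, the attention heads, the 11 input positions, and the per-head dimension. That last 64 is the hidden size divided by the head count (768 / 12 for DistilGPT-2), which can be made explicit from the loaded config:

// head_dim = n_embd / n_head; for DistilGPT-2 this is 768 / 12 = 64, the
// trailing dimension of each cached key/value tensor.
let head_dim = config.n_embd / config.n_head;
assert_eq!(head_dim, 64);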
View File
@@ -1,20 +1,29 @@
use rust_bert::electra::{
    ElectraConfig, ElectraConfigResources, ElectraDiscriminator, ElectraForMaskedLM,
    ElectraModelResources, ElectraVocabResources,
};
use rust_bert::resources::{download_resource, RemoteResource, Resource};
use rust_bert::Config;
use rust_tokenizers::{BertTokenizer, Tokenizer, TruncationStrategy, Vocab};
use tch::{nn, no_grad, Device, Tensor};

#[test]
fn electra_masked_lm() -> failure::Fallible<()> {
    // Resources paths
    let config_resource = Resource::Remote(RemoteResource::from_pretrained(
        ElectraConfigResources::BASE_GENERATOR,
    ));
    let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
        ElectraVocabResources::BASE_GENERATOR,
    ));
    let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
        ElectraModelResources::BASE_GENERATOR,
    ));
    let config_path = download_resource(&config_resource)?;
    let vocab_path = download_resource(&vocab_resource)?;
    let weights_path = download_resource(&weights_resource)?;

    // Set-up masked LM model
    let device = Device::Cpu;
    let mut vs = nn::VarStore::new(device);
    let tokenizer: BertTokenizer = BertTokenizer::from_file(vocab_path.to_str().unwrap(), true);
@@ -24,60 +33,70 @@ fn electra_masked_lm() -> failure::Fallible<()> {
    let electra_model = ElectraForMaskedLM::new(&vs.root(), &config);
    vs.load(weights_path)?;

    // Define input
    let input = [
        "Looks like one [MASK] is missing",
        "It was a very nice and [MASK] day",
    ];
    let tokenized_input =
        tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
    let max_len = tokenized_input
        .iter()
        .map(|input| input.token_ids.len())
        .max()
        .unwrap();
    let tokenized_input = tokenized_input
        .iter()
        .map(|input| input.token_ids.clone())
        .map(|mut input| {
            input.extend(vec![0; max_len - input.len()]);
            input
        })
        .map(|input| Tensor::of_slice(&(input)))
        .collect::<Vec<_>>();
    let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);

    // Forward pass
    let (output, all_hidden_states, all_attentions) =
        no_grad(|| electra_model.forward_t(Some(input_tensor), None, None, None, None, false));

    // Decode output
    let index_1 = output.get(0).get(4).argmax(0, false);
    let index_2 = output.get(1).get(7).argmax(0, false);
    let word_1 = tokenizer.vocab().id_to_token(&index_1.int64_value(&[]));
    let word_2 = tokenizer.vocab().id_to_token(&index_2.int64_value(&[]));

    assert_eq!(output.size(), &[2, 10, config.vocab_size]);
    assert_eq!(
        config.num_hidden_layers as usize,
        all_hidden_states.unwrap().len()
    );
    assert_eq!(
        config.num_hidden_layers as usize,
        all_attentions.unwrap().len()
    );
    assert_eq!("thing", word_1); // Outputs "thing" : "Looks like one [thing] is missing"
    assert_eq!("sunny", word_2); // Outputs "sunny" : "It was a very nice and [sunny] day"

    Ok(())
}
#[test] #[test]
fn electra_discriminator() -> failure::Fallible<()> { fn electra_discriminator() -> failure::Fallible<()> {
// Resources paths // Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(ElectraConfigResources::BASE_DISCRIMINATOR)); let config_resource = Resource::Remote(RemoteResource::from_pretrained(
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(ElectraVocabResources::BASE_DISCRIMINATOR)); ElectraConfigResources::BASE_DISCRIMINATOR,
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(ElectraModelResources::BASE_DISCRIMINATOR)); ));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
ElectraVocabResources::BASE_DISCRIMINATOR,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
ElectraModelResources::BASE_DISCRIMINATOR,
));
let config_path = download_resource(&config_resource)?; let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?; let vocab_path = download_resource(&vocab_resource)?;
let weights_path = download_resource(&weights_resource)?; let weights_path = download_resource(&weights_resource)?;
// Set-up masked LM model // Set-up masked LM model
let device = Device::Cpu; let device = Device::Cpu;
let mut vs = nn::VarStore::new(device); let mut vs = nn::VarStore::new(device);
let tokenizer: BertTokenizer = BertTokenizer::from_file(vocab_path.to_str().unwrap(), true); let tokenizer: BertTokenizer = BertTokenizer::from_file(vocab_path.to_str().unwrap(), true);
@ -85,35 +104,34 @@ fn electra_discriminator() -> failure::Fallible<()> {
    let electra_model = ElectraDiscriminator::new(&vs.root(), &config);
    vs.load(weights_path)?;
    // Define input
    let input = ["One Two Three Ten Five Six Seven Eight"];
    let tokenized_input =
        tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
    let max_len = tokenized_input
        .iter()
        .map(|input| input.token_ids.len())
        .max()
        .unwrap();
    let encoded_input = tokenized_input
        .iter()
        .map(|input| input.token_ids.clone())
        .map(|mut input| {
            input.extend(vec![0; max_len - input.len()]);
            input
        })
        .map(|input| Tensor::of_slice(&(input)))
        .collect::<Vec<_>>();
    let input_tensor = Tensor::stack(encoded_input.as_slice(), 0).to(device);
    // Forward pass
    let (output, _, _) =
        no_grad(|| electra_model.forward_t(Some(input_tensor), None, None, None, None, false));
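    // "Ten" is the out-of-place token in this input: counting the leading
    // [CLS] added by the tokenizer, it sits at index 4, where the
    // discriminator assigns by far the highest replaced-token probability
    // (0.9489 in the expected values below).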
    // Validate model predictions
    let expected_probabilities = vec![
        0.0101, 0.0030, 0.0010, 0.0018, 0.9489, 0.0067, 0.0026, 0.0017, 0.0311, 0.0101,
    ];
    let probabilities = output.iter::<f64>().unwrap().collect::<Vec<f64>>();

    assert_eq!(output.size(), &[10]);
@ -122,4 +140,4 @@ fn electra_discriminator() -> failure::Fallible<()> {
    }

    Ok(())
}

View File

@ -1,87 +1,115 @@
use rust_bert::gpt2::{
    GPT2LMHeadModel, Gpt2Config, Gpt2ConfigResources, Gpt2MergesResources, Gpt2ModelResources,
    Gpt2VocabResources,
};
use rust_bert::pipelines::generation::{
    Cache, GPT2Generator, GenerateConfig, LMHeadModel, LanguageGenerator,
};
use rust_bert::resources::{download_resource, RemoteResource, Resource};
use rust_bert::Config;
use rust_tokenizers::{Gpt2Tokenizer, Tokenizer, TruncationStrategy};
use tch::{nn, Device, Tensor};

#[test]
fn gpt2_lm_model() -> failure::Fallible<()> {
    // Resources paths
    let config_resource =
        Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
    let vocab_resource =
        Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
    let merges_resource =
        Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
    let weights_resource =
        Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
    let config_path = download_resource(&config_resource)?;
    let vocab_path = download_resource(&vocab_resource)?;
    let merges_path = download_resource(&merges_resource)?;
    let weights_path = download_resource(&weights_resource)?;
    // Set-up language model
    let device = Device::Cpu;
    let mut vs = nn::VarStore::new(device);
    let tokenizer: Gpt2Tokenizer = Gpt2Tokenizer::from_file(
        vocab_path.to_str().unwrap(),
        merges_path.to_str().unwrap(),
        false,
    );
    let config = Gpt2Config::from_file(config_path);
    let gpt2_model = GPT2LMHeadModel::new(&vs.root(), &config);
    vs.load(weights_path)?;
    // Define input
    let input = ["One two three four"];
    let tokenized_input =
        tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
    let max_len = tokenized_input
        .iter()
        .map(|input| input.token_ids.len())
        .max()
        .unwrap();
    let tokenized_input = tokenized_input
        .iter()
        .map(|input| input.token_ids.clone())
        .map(|mut input| {
            input.extend(vec![0; max_len - input.len()]);
            input
        })
        .map(|input| Tensor::of_slice(&(input)))
        .collect::<Vec<_>>();
    let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
    // Forward pass
    let (output, _, past, _, _) = gpt2_model
        .forward_t(
            &Some(input_tensor),
            Cache::None,
            &None,
            &None,
            &None,
            &None,
            None,
            &None,
            false,
        )
        .unwrap();
    let next_word_id = output.get(0).get(-1).argmax(-1, true).int64_value(&[0]);
    let next_word = tokenizer.decode(vec![next_word_id], true, true);

    assert_eq!(output.size(), vec!(1, 4, 50257));
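    // The cached key/value tensors checked below have shape
    // [2, batch_size, n_head, seq_len, head_dim]: the leading 2 stacks keys
    // and values, seq_len is 4 for this prompt, and head_dim is 64
    // (n_embd 768 / n_head 12 for the base GPT-2 checkpoint).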
    match past {
        Cache::GPT2Cache(past) => {
            assert!(past.is_some());
            assert_eq!(past.as_ref().unwrap().len(), config.n_layer as usize);
            assert_eq!(
                past.as_ref().unwrap()[0].size(),
                vec!(2, 1, config.n_head, 4, 64)
            );
        }
        _ => panic!("Wrong cache returned for GPT2"),
    }
    assert!(
        (output.double_value(&[0, output.size()[1] - 1, next_word_id]) - (-69.4948)).abs() < 1e-4
    );
    assert_eq!(next_word_id, 1936i64);
    assert_eq!(next_word, String::from(" five"));

    Ok(())
}

#[test]
fn gpt2_generation_greedy() -> failure::Fallible<()> {
    // Resources definition
    let config_resource =
        Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
    let vocab_resource =
        Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
    let merges_resource =
        Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
    let model_resource =
        Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));

    // Set-up text generation model
    let generate_config = GenerateConfig {
        model_resource,
        config_resource,
@ -97,7 +125,7 @@ fn gpt2_generation_greedy() -> failure::Fallible<()> {
    let model = GPT2Generator::new(generate_config)?;

    let input_context = "The cat";
    let output = model.generate(Some(vec![input_context]), None);

    assert_eq!(output.len(), 1);
    assert_eq!(output[0], "The cat was found in a field near the town of Keflavik, about 30 miles (48 kilometers) south-east of Moscow.\n\n\n");
@ -108,12 +136,16 @@ fn gpt2_generation_greedy() -> failure::Fallible<()> {
#[test]
fn gpt2_generation_beam_search() -> failure::Fallible<()> {
    // Resources definition
    let config_resource =
        Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
    let vocab_resource =
        Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
    let merges_resource =
        Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
    let model_resource =
        Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));

    // Set-up text generation model
    let generate_config = GenerateConfig {
        model_resource,
        config_resource,
@ -129,12 +161,21 @@ fn gpt2_generation_beam_search() -> failure::Fallible<()> {
    let model = GPT2Generator::new(generate_config)?;

    let input_context = "The dog";
    let output = model.generate(Some(vec![input_context]), None);

    assert_eq!(output.len(), 3);
    assert_eq!(
        output[0],
        "The dog was found in the backyard of a home in the 6200 block of South Main Street."
    );
    assert_eq!(
        output[1],
        "The dog was found in the backyard of a home in the 6500 block of South Main Street."
    );
    assert_eq!(
        output[2],
        "The dog was found in the backyard of a home in the 6200 block of South Main Street,"
    );

    Ok(())
}
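// Beam search returns three hypotheses per prompt in these tests, so a single
// prompt yields three outputs and the two-prompt tests below yield six.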
@ -142,12 +183,16 @@ fn gpt2_generation_beam_search() -> failure::Fallible<()> {
#[test]
fn gpt2_generation_beam_search_multiple_prompts_without_padding() -> failure::Fallible<()> {
    // Resources definition
    let config_resource =
        Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
    let vocab_resource =
        Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
    let merges_resource =
        Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
    let model_resource =
        Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));

    // Set-up text generation model
    let generate_config = GenerateConfig {
        model_resource,
        config_resource,
@ -164,15 +209,33 @@ fn gpt2_generation_beam_search_multiple_prompts_without_padding() -> failure::Fa
    let input_context_1 = "The dog";
    let input_context_2 = "The cat";
    let output = model.generate(Some(vec![input_context_1, input_context_2]), None);

    assert_eq!(output.len(), 6);
    assert_eq!(
        output[0],
        "The dog was found in the backyard of a home in the 6200 block of South Main Street."
    );
    assert_eq!(
        output[1],
        "The dog was found in the backyard of a home in the 6500 block of South Main Street."
    );
    assert_eq!(
        output[2],
        "The dog was found in the backyard of a home in the 6200 block of South Main Street,"
    );
    assert_eq!(
        output[3],
        "The cat-and-mouse game.\n\n\"I think it\'s going to be interesting to"
    );
    assert_eq!(
        output[4],
        "The cat-and-mouse game.\n\n\"I think it\'s going to be a very"
    );
    assert_eq!(
        output[5],
        "The cat-and-mouse game.\n\n\"I think it\'s going to be very interesting"
    );

    Ok(())
}
@ -180,12 +243,16 @@ fn gpt2_generation_beam_search_multiple_prompts_without_padding() -> failure::Fa
#[test]
fn gpt2_generation_beam_search_multiple_prompts_with_padding() -> failure::Fallible<()> {
    // Resources definition
    let config_resource =
        Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
    let vocab_resource =
        Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
    let merges_resource =
        Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
    let model_resource =
        Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));

    // Set-up text generation model
    let generate_config = GenerateConfig {
        model_resource,
        config_resource,
@ -202,15 +269,33 @@ fn gpt2_generation_beam_search_multiple_prompts_with_padding() -> failure::Falli
    let input_context_1 = "The dog";
    let input_context_2 = "The cat was";
    let output = model.generate(Some(vec![input_context_1, input_context_2]), None);

    assert_eq!(output.len(), 6);
    assert_eq!(
        output[0],
        "The dog was found dead on the side of the road in the middle of the night.\n"
    );
    assert_eq!(
        output[1],
        "The dog was found dead on the side of the road in the middle of the night on Sunday"
    );
    assert_eq!(
        output[2],
        "The dog was found dead on the side of the road in the middle of the night on Saturday"
    );
    assert_eq!(
        output[3],
        "The cat was taken to a local hospital, where it was treated and released.\n\nPolice said"
    );
    assert_eq!(
        output[4],
        "The cat was taken to a local hospital, where it was treated and released.\n\n\"It"
    );
    assert_eq!(
        output[5],
        "The cat was taken to a local hospital, where it was treated and released.\n\n\"We"
    );

    Ok(())
}

View File

@ -1,12 +1,11 @@
use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
use tch::Device;

#[test]
#[cfg_attr(not(feature = "all-tests"), ignore)]
fn test_translation() -> failure::Fallible<()> {
    // Set-up translation model
    let translation_config = TranslationConfig::new(Language::EnglishToFrench, Device::Cpu);
    let model = TranslationModel::new(translation_config)?;

    let input_context_1 = "The quick brown fox jumps over the lazy dog";
@ -15,8 +14,11 @@ fn test_translation() -> failure::Fallible<()> {
    let output = model.translate(&[input_context_1, input_context_2]);

    assert_eq!(output.len(), 2);
    assert_eq!(
        output[0],
        " Le rapide renard brun saute sur le chien paresseux"
    );
    assert_eq!(output[1], " Le chien ne s'est pas réveillé.");

    Ok(())
}
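// The same pipeline covers the other provided language pairs; a minimal
// sketch (assuming `Language::FrenchToEnglish` is among the available
// variants, and using a GPU when present):
//
// let translation_config =
//     TranslationConfig::new(Language::FrenchToEnglish, Device::cuda_if_available());
// let model = TranslationModel::new(translation_config)?;
// let output = model.translate(&["Le renard brun saute par-dessus le chien paresseux"]);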

View File

@ -1,64 +1,90 @@
use rust_bert::gpt2::Gpt2Config;
use rust_bert::openai_gpt::{
    OpenAIGPTLMHeadModel, OpenAiGptConfigResources, OpenAiGptMergesResources,
    OpenAiGptModelResources, OpenAiGptVocabResources,
};
use rust_bert::pipelines::generation::{
    Cache, GenerateConfig, LMHeadModel, LanguageGenerator, OpenAIGenerator,
};
use rust_bert::resources::{download_resource, RemoteResource, Resource};
use rust_bert::Config;
use rust_tokenizers::{OpenAiGptTokenizer, Tokenizer, TruncationStrategy};
use tch::{nn, Device, Tensor};

#[test]
fn openai_gpt_lm_model() -> failure::Fallible<()> {
    // Resources paths
    let config_resource = Resource::Remote(RemoteResource::from_pretrained(
        OpenAiGptConfigResources::GPT,
    ));
    let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
        OpenAiGptVocabResources::GPT,
    ));
    let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
        OpenAiGptMergesResources::GPT,
    ));
    let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
        OpenAiGptModelResources::GPT,
    ));
    let config_path = download_resource(&config_resource)?;
    let vocab_path = download_resource(&vocab_resource)?;
    let merges_path = download_resource(&merges_resource)?;
    let weights_path = download_resource(&weights_resource)?;
    // Set-up language model
    let device = Device::Cpu;
    let mut vs = nn::VarStore::new(device);
    let tokenizer = OpenAiGptTokenizer::from_file(
        vocab_path.to_str().unwrap(),
        merges_path.to_str().unwrap(),
        true,
    );
    let config = Gpt2Config::from_file(config_path);
    let openai_gpt = OpenAIGPTLMHeadModel::new(&vs.root(), &config);
    vs.load(weights_path)?;
    // Define input
    let input = ["Wondering what the next word will"];
    let tokenized_input =
        tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
    let max_len = tokenized_input
        .iter()
        .map(|input| input.token_ids.len())
        .max()
        .unwrap();
    let tokenized_input = tokenized_input
        .iter()
        .map(|input| input.token_ids.clone())
        .map(|mut input| {
            input.extend(vec![0; max_len - input.len()]);
            input
        })
        .map(|input| Tensor::of_slice(&(input)))
        .collect::<Vec<_>>();
    let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
    // Forward pass
    let (output, _, _, _, _) = openai_gpt
        .forward_t(
            &Some(input_tensor),
            Cache::None,
            &None,
            &None,
            &None,
            &None,
            None,
            &None,
            false,
        )
        .unwrap();
    let next_word_id = output.get(0).get(-1).argmax(-1, true).int64_value(&[0]);
    let next_word = tokenizer.decode(vec![next_word_id], true, true);

    assert_eq!(output.size(), vec!(1, 6, 40478));
    assert!(
        (output.double_value(&[0, output.size()[1] - 1, next_word_id]) - (9.1056)).abs() < 1e-4
    );
    assert_eq!(next_word_id, 580i64);
    assert_eq!(next_word, String::from("be"));
@ -68,12 +94,20 @@ fn openai_gpt_lm_model() -> failure::Fallible<()> {
#[test]
fn openai_gpt_generation_greedy() -> failure::Fallible<()> {
    // Resources paths
    let config_resource = Resource::Remote(RemoteResource::from_pretrained(
        OpenAiGptConfigResources::GPT,
    ));
    let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
        OpenAiGptVocabResources::GPT,
    ));
    let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
        OpenAiGptMergesResources::GPT,
    ));
    let model_resource = Resource::Remote(RemoteResource::from_pretrained(
        OpenAiGptModelResources::GPT,
    ));

    // Set-up text generation model
    let generate_config = GenerateConfig {
        model_resource,
        config_resource,
@ -90,7 +124,7 @@ fn openai_gpt_generation_greedy() -> failure::Fallible<()> {
    let model = OpenAIGenerator::new(generate_config)?;

    let input_context = "It was an intense machine dialogue. ";
    let output = model.generate(Some(vec![input_context]), None);

    assert_eq!(output.len(), 1);
    assert_eq!(output[0], "it was an intense machine dialogue. \n \" i\'m sorry, but we have to go now! the police are on their way and they\'re going after you - or at least that\'s what my");
@ -101,12 +135,20 @@ fn openai_gpt_generation_greedy() -> failure::Fallible<()> {
#[test]
fn openai_gpt_generation_beam_search() -> failure::Fallible<()> {
    // Resources paths
    let config_resource = Resource::Remote(RemoteResource::from_pretrained(
        OpenAiGptConfigResources::GPT,
    ));
    let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
        OpenAiGptVocabResources::GPT,
    ));
    let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
        OpenAiGptMergesResources::GPT,
    ));
    let model_resource = Resource::Remote(RemoteResource::from_pretrained(
        OpenAiGptModelResources::GPT,
    ));

    // Set-up text generation model
    let generate_config = GenerateConfig {
        model_resource,
        config_resource,
@ -122,12 +164,21 @@ fn openai_gpt_generation_beam_search() -> failure::Fallible<()> {
    let model = OpenAIGenerator::new(generate_config)?;

    let input_context = "The dog is";
    let output = model.generate(Some(vec![input_context]), None);

    assert_eq!(output.len(), 3);
    assert_eq!(
        output[0],
        "the dog isn\'t going anywhere. i\'m going to take care of him. i \'ll be right"
    );
    assert_eq!(
        output[1],
        "the dog isn\'t going anywhere. i\'m going to take care of him. i \'ll be back"
    );
    assert_eq!(
        output[2],
        "the dog isn\'t going anywhere. i\'m going to take care of him. \" \n \" i"
    );

    Ok(())
}
@ -135,12 +186,20 @@ fn openai_gpt_generation_beam_search() -> failure::Fallible<()> {
#[test]
fn openai_gpt_generation_beam_search_multiple_prompts_without_padding() -> failure::Fallible<()> {
    // Resources paths
    let config_resource = Resource::Remote(RemoteResource::from_pretrained(
        OpenAiGptConfigResources::GPT,
    ));
    let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
        OpenAiGptVocabResources::GPT,
    ));
    let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
        OpenAiGptMergesResources::GPT,
    ));
    let model_resource = Resource::Remote(RemoteResource::from_pretrained(
        OpenAiGptModelResources::GPT,
    ));

    // Set-up text generation model
    let generate_config = GenerateConfig {
        model_resource,
        config_resource,
@ -157,18 +216,36 @@ fn openai_gpt_generation_beam_search_multiple_prompts_without_padding() -> failu
    let input_context_1 = "The dog is";
    let input_context_2 = "The cat";
    let output = model.generate(Some(vec![input_context_1, input_context_2]), None);

    assert_eq!(output.len(), 6);
    // Unpadded sequence (generation for `The dog is`) is identical to the
    // single-prompt generation above
    assert_eq!(
        output[0],
        "the dog isn\'t going anywhere. i\'m going to take care of him. i \'ll be right"
    );
    assert_eq!(
        output[1],
        "the dog isn\'t going anywhere. i\'m going to take care of him. i \'ll be back"
    );
    assert_eq!(
        output[2],
        "the dog isn\'t going anywhere. i\'m going to take care of him. \" \n \" i"
    );
    assert_eq!(
        output[3],
        "the cat. \" \n \" i don\'t know what you\'re talking about. i don\'t"
    );
    assert_eq!(
        output[4],
        "the cat. \" \n \" i don\'t know what you\'re talking about. i\'m not"
    );
    assert_eq!(
        output[5],
        "the cat. \" \n \" i don\'t know what you\'re talking about. i do know"
    );

    Ok(())
}
@ -176,12 +253,20 @@ fn openai_gpt_generation_beam_search_multiple_prompts_without_padding() -> failu
#[test]
fn openai_gpt_generation_beam_search_multiple_prompts_with_padding() -> failure::Fallible<()> {
    // Resources paths
    let config_resource = Resource::Remote(RemoteResource::from_pretrained(
        OpenAiGptConfigResources::GPT,
    ));
    let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
        OpenAiGptVocabResources::GPT,
    ));
    let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
        OpenAiGptMergesResources::GPT,
    ));
    let model_resource = Resource::Remote(RemoteResource::from_pretrained(
        OpenAiGptModelResources::GPT,
    ));

    // Set-up text generation model
    let generate_config = GenerateConfig {
        model_resource,
        config_resource,
@ -198,16 +283,34 @@ fn openai_gpt_generation_beam_search_multiple_prompts_with_padding() -> failure:
    let input_context_1 = "The dog is";
    let input_context_2 = "The cat was in";
    let output = model.generate(Some(vec![input_context_1, input_context_2]), None);

    assert_eq!(output.len(), 6);
    // Left padding impacts the generated sentences
    assert_eq!(
        output[0],
        "the dog is a dog. \" \n \" i don\'t know what you\'re talking about."
    );
    assert_eq!(
        output[1],
        "the dog is a dog. \" \n \" i don\'t know what you\'re talking about,"
    );
    assert_eq!(
        output[2],
        "the dog is a dog. \" \n \" i don\'t know what you\'re talking about!"
    );
    assert_eq!(
        output[3],
        "the cat was in the room with them. \n \" what\'s going on? \" i asked."
    );
    assert_eq!(
        output[4],
        "the cat was in the room with them. \n \" what\'s going on? \" she asked."
    );
    assert_eq!(
        output[5],
        "the cat was in the room with them. \n \" what\'s going on? why are you all"
    );

    Ok(())
}

View File

@ -1,75 +1,99 @@
use rust_bert::bert::BertConfig;
use rust_bert::resources::{download_resource, RemoteResource, Resource};
use rust_bert::roberta::{
    RobertaConfigResources, RobertaForMaskedLM, RobertaForMultipleChoice,
    RobertaForQuestionAnswering, RobertaForSequenceClassification, RobertaForTokenClassification,
    RobertaMergesResources, RobertaModelResources, RobertaVocabResources,
};
use rust_bert::Config;
use rust_tokenizers::{RobertaTokenizer, Tokenizer, TruncationStrategy, Vocab};
use std::collections::HashMap;
use tch::{nn, no_grad, Device, Tensor};

fn roberta_masked_lm() -> failure::Fallible<()> { fn roberta_masked_lm() -> failure::Fallible<()> {
// Resources paths // Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaConfigResources::ROBERTA)); let config_resource = Resource::Remote(RemoteResource::from_pretrained(
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaVocabResources::ROBERTA)); RobertaConfigResources::ROBERTA,
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaMergesResources::ROBERTA)); ));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaModelResources::ROBERTA)); let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaVocabResources::ROBERTA,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaMergesResources::ROBERTA,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
RobertaModelResources::ROBERTA,
));
let config_path = download_resource(&config_resource)?; let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?; let vocab_path = download_resource(&vocab_resource)?;
let merges_path = download_resource(&merges_resource)?; let merges_path = download_resource(&merges_resource)?;
let weights_path = download_resource(&weights_resource)?; let weights_path = download_resource(&weights_resource)?;
    // Set-up masked LM model
    let device = Device::Cpu;
    let mut vs = nn::VarStore::new(device);
    let tokenizer: RobertaTokenizer = RobertaTokenizer::from_file(
        vocab_path.to_str().unwrap(),
        merges_path.to_str().unwrap(),
        true,
    );
    let config = BertConfig::from_file(config_path);
    let roberta_model = RobertaForMaskedLM::new(&vs.root(), &config);
    vs.load(weights_path)?;
    // Define input
    let input = [
        "<pad> Looks like one thing is missing",
        "It\'s like comparing oranges to apples",
    ];
    let tokenized_input =
        tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
    let max_len = tokenized_input
        .iter()
        .map(|input| input.token_ids.len())
        .max()
        .unwrap();
    let mut tokenized_input = tokenized_input
        .iter()
        .map(|input| input.token_ids.clone())
        .map(|mut input| {
            input.extend(vec![0; max_len - input.len()]);
            input
        })
        .collect::<Vec<_>>();
    // Masking the token [thing] of sentence 1 and [oranges] of sentence 2
    tokenized_input[0][4] = 103;
    tokenized_input[1][5] = 103;
    let tokenized_input = tokenized_input
        .iter()
        .map(|input| Tensor::of_slice(&(input)))
        .collect::<Vec<_>>();
    let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
    // Forward pass
    let (output, _, _) = no_grad(|| {
        roberta_model.forward_t(
            Some(input_tensor),
            None,
            None,
            None,
            None,
            &None,
            &None,
            false,
        )
    });
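    // "Ġ" in the assertions below is the byte-level BPE marker for a leading
    // space in the RoBERTa vocabulary, so "Ġsome" decodes to " some".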
    // Decode masked tokens
    let index_1 = output.get(0).get(4).argmax(0, false);
    let index_2 = output.get(1).get(5).argmax(0, false);
    let word_1 = tokenizer.vocab().id_to_token(&index_1.int64_value(&[]));
    let word_2 = tokenizer.vocab().id_to_token(&index_2.int64_value(&[]));

    assert_eq!("Ġsome", word_1); // Outputs "some" : "Looks like [some] thing is missing"
    assert_eq!("Ġapples", word_2); // Outputs "apples" : "It\'s like comparing [apples] to apples"

    Ok(())
}
@ -77,17 +101,27 @@ fn roberta_masked_lm() -> failure::Fallible<()> {
#[test]
fn roberta_for_sequence_classification() -> failure::Fallible<()> {
    // Resources paths
    let config_resource = Resource::Remote(RemoteResource::from_pretrained(
        RobertaConfigResources::ROBERTA,
    ));
    let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
        RobertaVocabResources::ROBERTA,
    ));
    let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
        RobertaMergesResources::ROBERTA,
    ));
    let config_path = download_resource(&config_resource)?;
    let vocab_path = download_resource(&vocab_resource)?;
    let merges_path = download_resource(&merges_resource)?;

    // Set-up model
    let device = Device::Cpu;
    let vs = nn::VarStore::new(device);
    let tokenizer: RobertaTokenizer = RobertaTokenizer::from_file(
        vocab_path.to_str().unwrap(),
        merges_path.to_str().unwrap(),
        true,
    );
    let mut config = BertConfig::from_file(config_path);
    let mut dummy_label_mapping = HashMap::new();
    dummy_label_mapping.insert(0, String::from("Positive"));
@ -98,37 +132,42 @@ fn roberta_for_sequence_classification() -> failure::Fallible<()> {
    config.output_hidden_states = Some(true);
    let roberta_model = RobertaForSequenceClassification::new(&vs.root(), &config);

    // Define input
    let input = [
        "Looks like one thing is missing",
        "It\'s like comparing oranges to apples",
    ];
    let tokenized_input =
        tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
    let max_len = tokenized_input
        .iter()
        .map(|input| input.token_ids.len())
        .max()
        .unwrap();
    let tokenized_input = tokenized_input
        .iter()
        .map(|input| input.token_ids.clone())
        .map(|mut input| {
            input.extend(vec![0; max_len - input.len()]);
            input
        })
        .map(|input| Tensor::of_slice(&(input)))
        .collect::<Vec<_>>();
    let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);

    // Forward pass
    let (output, all_hidden_states, all_attentions) =
        no_grad(|| roberta_model.forward_t(Some(input_tensor), None, None, None, None, false));

    assert_eq!(output.size(), &[2, 3]);
    assert_eq!(
        config.num_hidden_layers as usize,
        all_hidden_states.unwrap().len()
    );
    assert_eq!(
        config.num_hidden_layers as usize,
        all_attentions.unwrap().len()
    );

    Ok(())
}
@ -136,52 +175,70 @@ fn roberta_for_sequence_classification() -> failure::Fallible<()> {
#[test]
fn roberta_for_multiple_choice() -> failure::Fallible<()> {
    // Resources paths
    let config_resource = Resource::Remote(RemoteResource::from_pretrained(
        RobertaConfigResources::ROBERTA,
    ));
    let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
        RobertaVocabResources::ROBERTA,
    ));
    let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
        RobertaMergesResources::ROBERTA,
    ));
    let config_path = download_resource(&config_resource)?;
    let vocab_path = download_resource(&vocab_resource)?;
    let merges_path = download_resource(&merges_resource)?;

    // Set-up model
    let device = Device::Cpu;
    let vs = nn::VarStore::new(device);
    let tokenizer: RobertaTokenizer = RobertaTokenizer::from_file(
        vocab_path.to_str().unwrap(),
        merges_path.to_str().unwrap(),
        true,
    );
    let mut config = BertConfig::from_file(config_path);
    config.output_attentions = Some(true);
    config.output_hidden_states = Some(true);
    let roberta_model = RobertaForMultipleChoice::new(&vs.root(), &config);
    // Define input
    let input = [
        "Looks like one thing is missing",
        "It\'s like comparing oranges to apples",
    ];
    let tokenized_input =
        tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
    let max_len = tokenized_input
        .iter()
        .map(|input| input.token_ids.len())
        .max()
        .unwrap();
    let tokenized_input = tokenized_input
        .iter()
        .map(|input| input.token_ids.clone())
        .map(|mut input| {
            input.extend(vec![0; max_len - input.len()]);
            input
        })
        .map(|input| Tensor::of_slice(&(input)))
        .collect::<Vec<_>>();
    let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0)
        .to(device)
        .unsqueeze(0);
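    // unsqueeze(0) adds a leading batch dimension: the two stacked sentences
    // become the two choices of a single multiple-choice example, giving an
    // input of shape [1, 2, sequence_length] and an output of shape [1, 2].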
    // Forward pass
    let (output, all_hidden_states, all_attentions) =
        no_grad(|| roberta_model.forward_t(input_tensor, None, None, None, false));

    assert_eq!(output.size(), &[1, 2]);
    assert_eq!(
        config.num_hidden_layers as usize,
        all_hidden_states.unwrap().len()
    );
    assert_eq!(
        config.num_hidden_layers as usize,
        all_attentions.unwrap().len()
    );

    Ok(())
}
@ -189,17 +246,27 @@ fn roberta_for_multiple_choice() -> failure::Fallible<()> {
#[test]
fn roberta_for_token_classification() -> failure::Fallible<()> {
    // Resources paths
    let config_resource = Resource::Remote(RemoteResource::from_pretrained(
        RobertaConfigResources::ROBERTA,
    ));
    let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
        RobertaVocabResources::ROBERTA,
    ));
    let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
        RobertaMergesResources::ROBERTA,
    ));
    let config_path = download_resource(&config_resource)?;
    let vocab_path = download_resource(&vocab_resource)?;
    let merges_path = download_resource(&merges_resource)?;

    // Set-up model
    let device = Device::Cpu;
    let vs = nn::VarStore::new(device);
    let tokenizer: RobertaTokenizer = RobertaTokenizer::from_file(
        vocab_path.to_str().unwrap(),
        merges_path.to_str().unwrap(),
        true,
    );
    let mut config = BertConfig::from_file(config_path);
    let mut dummy_label_mapping = HashMap::new();
    dummy_label_mapping.insert(0, String::from("O"));
@ -211,55 +278,70 @@ fn roberta_for_token_classification() -> failure::Fallible<()> {
    config.output_hidden_states = Some(true);
    let roberta_model = RobertaForTokenClassification::new(&vs.root(), &config);

    // Define input
    let input = [
        "Looks like one thing is missing",
        "It\'s like comparing oranges to apples",
    ];
    let tokenized_input =
        tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
    let max_len = tokenized_input
        .iter()
        .map(|input| input.token_ids.len())
        .max()
        .unwrap();
    let tokenized_input = tokenized_input
        .iter()
        .map(|input| input.token_ids.clone())
        .map(|mut input| {
            input.extend(vec![0; max_len - input.len()]);
            input
        })
        .map(|input| Tensor::of_slice(&(input)))
        .collect::<Vec<_>>();
    let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);

    // Forward pass
    let (output, all_hidden_states, all_attentions) =
        no_grad(|| roberta_model.forward_t(Some(input_tensor), None, None, None, None, false));

    assert_eq!(output.size(), &[2, 9, 4]);
    assert_eq!(
        config.num_hidden_layers as usize,
        all_hidden_states.unwrap().len()
    );
    assert_eq!(
        config.num_hidden_layers as usize,
        all_attentions.unwrap().len()
    );

    Ok(())
}

#[test]
fn roberta_for_question_answering() -> failure::Fallible<()> {
    // Resources paths
    let config_resource = Resource::Remote(RemoteResource::from_pretrained(
        RobertaConfigResources::ROBERTA,
    ));
    let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
        RobertaVocabResources::ROBERTA,
    ));
    let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
        RobertaMergesResources::ROBERTA,
    ));
    let config_path = download_resource(&config_resource)?;
    let vocab_path = download_resource(&vocab_resource)?;
    let merges_path = download_resource(&merges_resource)?;

    // Set-up model
    let device = Device::Cpu;
    let vs = nn::VarStore::new(device);
    let tokenizer: RobertaTokenizer = RobertaTokenizer::from_file(
        vocab_path.to_str().unwrap(),
        merges_path.to_str().unwrap(),
        true,
    );
    let mut config = BertConfig::from_file(config_path);
    let mut dummy_label_mapping = HashMap::new();
    dummy_label_mapping.insert(0, String::from("Positive"));
@ -270,38 +352,43 @@ fn roberta_for_question_answering() -> failure::Fallible<()> {
    config.output_hidden_states = Some(true);
    let roberta_model = RobertaForQuestionAnswering::new(&vs.root(), &config);

    // Define input
    let input = [
        "Looks like one thing is missing",
        "It\'s like comparing oranges to apples",
    ];
    let tokenized_input =
        tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
    let max_len = tokenized_input
        .iter()
        .map(|input| input.token_ids.len())
        .max()
        .unwrap();
    let tokenized_input = tokenized_input
        .iter()
        .map(|input| input.token_ids.clone())
        .map(|mut input| {
            input.extend(vec![0; max_len - input.len()]);
            input
        })
        .map(|input| Tensor::of_slice(&(input)))
        .collect::<Vec<_>>();
    let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);

    // Forward pass
    let (start_scores, end_scores, all_hidden_states, all_attentions) =
        no_grad(|| roberta_model.forward_t(Some(input_tensor), None, None, None, None, false));

    assert_eq!(start_scores.size(), &[2, 9]);
    assert_eq!(end_scores.size(), &[2, 9]);
    assert_eq!(
        config.num_hidden_layers as usize,
        all_hidden_states.unwrap().len()
    );
    assert_eq!(
        config.num_hidden_layers as usize,
        all_attentions.unwrap().len()
    );

    Ok(())
}