mirror of
https://github.com/guillaume-be/rust-bert.git
synced 2024-10-05 16:47:24 +03:00
Code formatted using rustfmt
This commit is contained in:
parent
0624a5368c
commit
47e36c4e8c
@ -13,58 +13,67 @@
|
|||||||
|
|
||||||
extern crate failure;
|
extern crate failure;
|
||||||
|
|
||||||
use tch::{Device, nn, Tensor, no_grad};
|
use rust_bert::albert::{
|
||||||
use rust_tokenizers::{AlbertTokenizer, TruncationStrategy, Tokenizer, Vocab};
|
AlbertConfig, AlbertConfigResources, AlbertForMaskedLM, AlbertModelResources,
|
||||||
|
AlbertVocabResources,
|
||||||
|
};
|
||||||
|
use rust_bert::resources::{download_resource, RemoteResource, Resource};
|
||||||
use rust_bert::Config;
|
use rust_bert::Config;
|
||||||
use rust_bert::resources::{Resource, download_resource, RemoteResource};
|
use rust_tokenizers::{AlbertTokenizer, Tokenizer, TruncationStrategy, Vocab};
|
||||||
use rust_bert::albert::{AlbertConfig, AlbertForMaskedLM, AlbertConfigResources, AlbertVocabResources, AlbertModelResources};
|
use tch::{nn, no_grad, Device, Tensor};
|
||||||
|
|
||||||
|
|
||||||
fn main() -> failure::Fallible<()> {
|
fn main() -> failure::Fallible<()> {
|
||||||
// Resources paths
|
// Resources paths
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(AlbertConfigResources::ALBERT_BASE_V2));
|
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(AlbertVocabResources::ALBERT_BASE_V2));
|
AlbertConfigResources::ALBERT_BASE_V2,
|
||||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(AlbertModelResources::ALBERT_BASE_V2));
|
));
|
||||||
|
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
|
AlbertVocabResources::ALBERT_BASE_V2,
|
||||||
|
));
|
||||||
|
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
|
AlbertModelResources::ALBERT_BASE_V2,
|
||||||
|
));
|
||||||
let config_path = download_resource(&config_resource)?;
|
let config_path = download_resource(&config_resource)?;
|
||||||
let vocab_path = download_resource(&vocab_resource)?;
|
let vocab_path = download_resource(&vocab_resource)?;
|
||||||
let weights_path = download_resource(&weights_resource)?;
|
let weights_path = download_resource(&weights_resource)?;
|
||||||
|
|
||||||
// Set-up masked LM model
|
// Set-up masked LM model
|
||||||
let device = Device::Cpu;
|
let device = Device::Cpu;
|
||||||
let mut vs = nn::VarStore::new(device);
|
let mut vs = nn::VarStore::new(device);
|
||||||
let tokenizer: AlbertTokenizer = AlbertTokenizer::from_file(vocab_path.to_str().unwrap(), true, false);
|
let tokenizer: AlbertTokenizer =
|
||||||
|
AlbertTokenizer::from_file(vocab_path.to_str().unwrap(), true, false);
|
||||||
let config = AlbertConfig::from_file(config_path);
|
let config = AlbertConfig::from_file(config_path);
|
||||||
let albert_model = AlbertForMaskedLM::new(&vs.root(), &config);
|
let albert_model = AlbertForMaskedLM::new(&vs.root(), &config);
|
||||||
vs.load(weights_path)?;
|
vs.load(weights_path)?;
|
||||||
|
|
||||||
// Define input
|
// Define input
|
||||||
let input = ["Looks like one [MASK] is missing", "It was a very nice and [MASK] day"];
|
let input = [
|
||||||
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
"Looks like one [MASK] is missing",
|
||||||
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
|
"It was a very nice and [MASK] day",
|
||||||
let tokenized_input = tokenized_input.
|
];
|
||||||
iter().
|
let tokenized_input =
|
||||||
map(|input| input.token_ids.clone()).
|
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
||||||
map(|mut input| {
|
let max_len = tokenized_input
|
||||||
|
.iter()
|
||||||
|
.map(|input| input.token_ids.len())
|
||||||
|
.max()
|
||||||
|
.unwrap();
|
||||||
|
let tokenized_input = tokenized_input
|
||||||
|
.iter()
|
||||||
|
.map(|input| input.token_ids.clone())
|
||||||
|
.map(|mut input| {
|
||||||
input.extend(vec![0; max_len - input.len()]);
|
input.extend(vec![0; max_len - input.len()]);
|
||||||
input
|
input
|
||||||
}).
|
})
|
||||||
map(|input|
|
.map(|input| Tensor::of_slice(&(input)))
|
||||||
Tensor::of_slice(&(input))).
|
.collect::<Vec<_>>();
|
||||||
collect::<Vec<_>>();
|
|
||||||
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
|
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
|
||||||
|
|
||||||
// Forward pass
|
// Forward pass
|
||||||
let (output, _, _) = no_grad(|| {
|
let (output, _, _) =
|
||||||
albert_model
|
no_grad(|| albert_model.forward_t(Some(input_tensor), None, None, None, None, false));
|
||||||
.forward_t(Some(input_tensor),
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
false)
|
|
||||||
});
|
|
||||||
println!("{:?}", output.double_value(&[0, 0, 0]));
|
println!("{:?}", output.double_value(&[0, 0, 0]));
|
||||||
// Print masked tokens
|
// Print masked tokens
|
||||||
let index_1 = output.get(0).get(4).argmax(0, false);
|
let index_1 = output.get(0).get(4).argmax(0, false);
|
||||||
let index_2 = output.get(1).get(7).argmax(0, false);
|
let index_2 = output.get(1).get(7).argmax(0, false);
|
||||||
let word_1 = tokenizer.vocab().id_to_token(&index_1.int64_value(&[]));
|
let word_1 = tokenizer.vocab().id_to_token(&index_1.int64_value(&[]));
|
||||||
@ -74,4 +83,4 @@ fn main() -> failure::Fallible<()> {
|
|||||||
println!("{} - {}", &index_2.int64_value(&[]), word_2); // Outputs "_enjoyable" : "It was a very nice and [enjoyable] day"
|
println!("{} - {}", &index_2.int64_value(&[]), word_2); // Outputs "_enjoyable" : "It was a very nice and [enjoyable] day"
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -12,33 +12,43 @@
|
|||||||
|
|
||||||
extern crate failure;
|
extern crate failure;
|
||||||
|
|
||||||
use tch::{Device, nn, Tensor, no_grad};
|
use rust_bert::bart::{
|
||||||
use rust_tokenizers::{RobertaTokenizer, TruncationStrategy, Tokenizer};
|
BartConfig, BartConfigResources, BartMergesResources, BartModel, BartModelResources,
|
||||||
use rust_bert::bart::{BartConfig, BartConfigResources, BartVocabResources, BartMergesResources, BartModelResources, BartModel};
|
BartVocabResources,
|
||||||
|
};
|
||||||
|
use rust_bert::resources::{download_resource, RemoteResource, Resource};
|
||||||
use rust_bert::Config;
|
use rust_bert::Config;
|
||||||
use rust_bert::resources::{Resource, download_resource, RemoteResource};
|
use rust_tokenizers::{RobertaTokenizer, Tokenizer, TruncationStrategy};
|
||||||
|
use tch::{nn, no_grad, Device, Tensor};
|
||||||
|
|
||||||
fn main() -> failure::Fallible<()> {
|
fn main() -> failure::Fallible<()> {
|
||||||
// Resources paths
|
// Resources paths
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(BartConfigResources::BART));
|
let config_resource =
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(BartVocabResources::BART));
|
Resource::Remote(RemoteResource::from_pretrained(BartConfigResources::BART));
|
||||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(BartMergesResources::BART));
|
let vocab_resource =
|
||||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(BartModelResources::BART));
|
Resource::Remote(RemoteResource::from_pretrained(BartVocabResources::BART));
|
||||||
|
let merges_resource =
|
||||||
|
Resource::Remote(RemoteResource::from_pretrained(BartMergesResources::BART));
|
||||||
|
let weights_resource =
|
||||||
|
Resource::Remote(RemoteResource::from_pretrained(BartModelResources::BART));
|
||||||
let config_path = download_resource(&config_resource)?;
|
let config_path = download_resource(&config_resource)?;
|
||||||
let vocab_path = download_resource(&vocab_resource)?;
|
let vocab_path = download_resource(&vocab_resource)?;
|
||||||
let merges_path = download_resource(&merges_resource)?;
|
let merges_path = download_resource(&merges_resource)?;
|
||||||
let weights_path = download_resource(&weights_resource)?;
|
let weights_path = download_resource(&weights_resource)?;
|
||||||
|
|
||||||
// Set-up BART model
|
// Set-up BART model
|
||||||
let device = Device::cuda_if_available();
|
let device = Device::cuda_if_available();
|
||||||
let mut vs = nn::VarStore::new(device);
|
let mut vs = nn::VarStore::new(device);
|
||||||
let tokenizer = RobertaTokenizer::from_file(vocab_path.to_str().unwrap(), merges_path.to_str().unwrap(), false);
|
let tokenizer = RobertaTokenizer::from_file(
|
||||||
|
vocab_path.to_str().unwrap(),
|
||||||
|
merges_path.to_str().unwrap(),
|
||||||
|
false,
|
||||||
|
);
|
||||||
let config = BartConfig::from_file(config_path);
|
let config = BartConfig::from_file(config_path);
|
||||||
let bart_model = BartModel::new(&vs.root(), &config, false);
|
let bart_model = BartModel::new(&vs.root(), &config, false);
|
||||||
vs.load(weights_path)?;
|
vs.load(weights_path)?;
|
||||||
|
|
||||||
// Define input
|
// Define input
|
||||||
let input = ["In findings published Tuesday in Cornell University's arXiv by a team of scientists \
|
let input = ["In findings published Tuesday in Cornell University's arXiv by a team of scientists \
|
||||||
from the University of Montreal and a separate report published Wednesday in Nature Astronomy by a team \
|
from the University of Montreal and a separate report published Wednesday in Nature Astronomy by a team \
|
||||||
from University College London (UCL), the presence of water vapour was confirmed in the atmosphere of K2-18b, \
|
from University College London (UCL), the presence of water vapour was confirmed in the atmosphere of K2-18b, \
|
||||||
@ -61,36 +71,32 @@ on K2-18b lasts 33 Earth days. According to The Guardian, astronomers were optim
|
|||||||
telescope — scheduled for launch in 2021 — and the European Space Agency's 2028 ARIEL program, could reveal more \
|
telescope — scheduled for launch in 2021 — and the European Space Agency's 2028 ARIEL program, could reveal more \
|
||||||
about exoplanets like K2-18b."];
|
about exoplanets like K2-18b."];
|
||||||
|
|
||||||
// Credits: WikiNews, CC BY 2.5 license (https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b)
|
// Credits: WikiNews, CC BY 2.5 license (https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b)
|
||||||
|
|
||||||
let tokenized_input = tokenizer.encode_list(input.to_vec(), 1024, &TruncationStrategy::LongestFirst, 0);
|
let tokenized_input =
|
||||||
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
|
tokenizer.encode_list(input.to_vec(), 1024, &TruncationStrategy::LongestFirst, 0);
|
||||||
let tokenized_input = tokenized_input.
|
let max_len = tokenized_input
|
||||||
iter().
|
.iter()
|
||||||
map(|input| input.token_ids.clone()).
|
.map(|input| input.token_ids.len())
|
||||||
map(|mut input| {
|
.max()
|
||||||
|
.unwrap();
|
||||||
|
let tokenized_input = tokenized_input
|
||||||
|
.iter()
|
||||||
|
.map(|input| input.token_ids.clone())
|
||||||
|
.map(|mut input| {
|
||||||
input.extend(vec![0; max_len - input.len()]);
|
input.extend(vec![0; max_len - input.len()]);
|
||||||
input
|
input
|
||||||
}).
|
})
|
||||||
map(|input|
|
.map(|input| Tensor::of_slice(&(input)))
|
||||||
Tensor::of_slice(&(input))).
|
.collect::<Vec<_>>();
|
||||||
collect::<Vec<_>>();
|
|
||||||
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
|
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
|
||||||
|
|
||||||
// Forward pass
|
// Forward pass
|
||||||
let (decoder_output, encoder_output, _, _, _, _, _) = no_grad(|| {
|
let (decoder_output, encoder_output, _, _, _, _, _) =
|
||||||
bart_model
|
no_grad(|| bart_model.forward_t(Some(&input_tensor), None, None, None, None, None, false));
|
||||||
.forward_t(Some(&input_tensor),
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
false)
|
|
||||||
});
|
|
||||||
|
|
||||||
// Print encoder and decoder outputs
|
// Print encoder and decoder outputs
|
||||||
println!("{:?}", encoder_output);
|
println!("{:?}", encoder_output);
|
||||||
println!("{:?}", decoder_output);
|
println!("{:?}", decoder_output);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -12,23 +12,27 @@
|
|||||||
|
|
||||||
extern crate failure;
|
extern crate failure;
|
||||||
|
|
||||||
use tch::{Device, nn, Tensor, no_grad};
|
use rust_bert::bert::{
|
||||||
use rust_tokenizers::{BertTokenizer, TruncationStrategy, Tokenizer, Vocab};
|
BertConfig, BertConfigResources, BertForMaskedLM, BertModelResources, BertVocabResources,
|
||||||
|
};
|
||||||
|
use rust_bert::resources::{download_resource, RemoteResource, Resource};
|
||||||
use rust_bert::Config;
|
use rust_bert::Config;
|
||||||
use rust_bert::bert::{BertConfig, BertForMaskedLM, BertConfigResources, BertVocabResources, BertModelResources};
|
use rust_tokenizers::{BertTokenizer, Tokenizer, TruncationStrategy, Vocab};
|
||||||
use rust_bert::resources::{Resource, download_resource, RemoteResource};
|
use tch::{nn, no_grad, Device, Tensor};
|
||||||
|
|
||||||
|
|
||||||
fn main() -> failure::Fallible<()> {
|
fn main() -> failure::Fallible<()> {
|
||||||
// Resources paths
|
// Resources paths
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT));
|
let config_resource =
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
|
Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT));
|
||||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT));
|
let vocab_resource =
|
||||||
|
Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
|
||||||
|
let weights_resource =
|
||||||
|
Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT));
|
||||||
let config_path = download_resource(&config_resource)?;
|
let config_path = download_resource(&config_resource)?;
|
||||||
let vocab_path = download_resource(&vocab_resource)?;
|
let vocab_path = download_resource(&vocab_resource)?;
|
||||||
let weights_path = download_resource(&weights_resource)?;
|
let weights_path = download_resource(&weights_resource)?;
|
||||||
|
|
||||||
// Set-up masked LM model
|
// Set-up masked LM model
|
||||||
let device = Device::Cpu;
|
let device = Device::Cpu;
|
||||||
let mut vs = nn::VarStore::new(device);
|
let mut vs = nn::VarStore::new(device);
|
||||||
let tokenizer: BertTokenizer = BertTokenizer::from_file(vocab_path.to_str().unwrap(), true);
|
let tokenizer: BertTokenizer = BertTokenizer::from_file(vocab_path.to_str().unwrap(), true);
|
||||||
@ -36,43 +40,51 @@ fn main() -> failure::Fallible<()> {
|
|||||||
let bert_model = BertForMaskedLM::new(&vs.root(), &config);
|
let bert_model = BertForMaskedLM::new(&vs.root(), &config);
|
||||||
vs.load(weights_path)?;
|
vs.load(weights_path)?;
|
||||||
|
|
||||||
// Define input
|
// Define input
|
||||||
let input = ["Looks like one [MASK] is missing", "It was a very nice and [MASK] day"];
|
let input = [
|
||||||
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
"Looks like one [MASK] is missing",
|
||||||
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
|
"It was a very nice and [MASK] day",
|
||||||
let tokenized_input = tokenized_input.
|
];
|
||||||
iter().
|
let tokenized_input =
|
||||||
map(|input| input.token_ids.clone()).
|
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
||||||
map(|mut input| {
|
let max_len = tokenized_input
|
||||||
|
.iter()
|
||||||
|
.map(|input| input.token_ids.len())
|
||||||
|
.max()
|
||||||
|
.unwrap();
|
||||||
|
let tokenized_input = tokenized_input
|
||||||
|
.iter()
|
||||||
|
.map(|input| input.token_ids.clone())
|
||||||
|
.map(|mut input| {
|
||||||
input.extend(vec![0; max_len - input.len()]);
|
input.extend(vec![0; max_len - input.len()]);
|
||||||
input
|
input
|
||||||
}).
|
})
|
||||||
map(|input|
|
.map(|input| Tensor::of_slice(&(input)))
|
||||||
Tensor::of_slice(&(input))).
|
.collect::<Vec<_>>();
|
||||||
collect::<Vec<_>>();
|
|
||||||
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
|
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
|
||||||
|
|
||||||
// Forward pass
|
// Forward pass
|
||||||
let (output, _, _) = no_grad(|| {
|
let (output, _, _) = no_grad(|| {
|
||||||
bert_model
|
bert_model.forward_t(
|
||||||
.forward_t(Some(input_tensor),
|
Some(input_tensor),
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
&None,
|
&None,
|
||||||
&None,
|
&None,
|
||||||
false)
|
false,
|
||||||
|
)
|
||||||
});
|
});
|
||||||
|
|
||||||
// Print masked tokens
|
// Print masked tokens
|
||||||
let index_1 = output.get(0).get(4).argmax(0, false);
|
let index_1 = output.get(0).get(4).argmax(0, false);
|
||||||
let index_2 = output.get(1).get(7).argmax(0, false);
|
let index_2 = output.get(1).get(7).argmax(0, false);
|
||||||
let word_1 = tokenizer.vocab().id_to_token(&index_1.int64_value(&[]));
|
let word_1 = tokenizer.vocab().id_to_token(&index_1.int64_value(&[]));
|
||||||
let word_2 = tokenizer.vocab().id_to_token(&index_2.int64_value(&[]));
|
let word_2 = tokenizer.vocab().id_to_token(&index_2.int64_value(&[]));
|
||||||
|
|
||||||
println!("{}", word_1); // Outputs "person" : "Looks like one [person] is missing"
|
println!("{}", word_1); // Outputs "person" : "Looks like one [person] is missing"
|
||||||
println!("{}", word_2);// Outputs "pleasant" : "It was a very nice and [pleasant] day"
|
println!("{}", word_2); // Outputs "pleasant" : "It was a very nice and [pleasant] day"
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -11,25 +11,33 @@
|
|||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
extern crate failure;
|
extern crate failure;
|
||||||
|
|
||||||
use tch::{Device, Tensor, nn, no_grad};
|
use rust_bert::distilbert::{
|
||||||
use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::{Tokenizer, TruncationStrategy};
|
DistilBertConfig, DistilBertConfigResources, DistilBertModelMaskedLM, DistilBertModelResources,
|
||||||
use rust_tokenizers::bert_tokenizer::BertTokenizer;
|
DistilBertVocabResources,
|
||||||
use rust_tokenizers::preprocessing::vocab::base_vocab::Vocab;
|
};
|
||||||
|
use rust_bert::resources::{download_resource, RemoteResource, Resource};
|
||||||
use rust_bert::Config;
|
use rust_bert::Config;
|
||||||
use rust_bert::distilbert::{DistilBertConfig, DistilBertModelMaskedLM, DistilBertConfigResources, DistilBertVocabResources, DistilBertModelResources};
|
use rust_tokenizers::bert_tokenizer::BertTokenizer;
|
||||||
use rust_bert::resources::{Resource, download_resource, RemoteResource};
|
use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::{Tokenizer, TruncationStrategy};
|
||||||
|
use rust_tokenizers::preprocessing::vocab::base_vocab::Vocab;
|
||||||
|
use tch::{nn, no_grad, Device, Tensor};
|
||||||
|
|
||||||
fn main() -> failure::Fallible<()> {
|
fn main() -> failure::Fallible<()> {
|
||||||
// Resources paths
|
// Resources paths
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertConfigResources::DISTIL_BERT));
|
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertVocabResources::DISTIL_BERT));
|
DistilBertConfigResources::DISTIL_BERT,
|
||||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertModelResources::DISTIL_BERT));
|
));
|
||||||
|
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
|
DistilBertVocabResources::DISTIL_BERT,
|
||||||
|
));
|
||||||
|
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
|
DistilBertModelResources::DISTIL_BERT,
|
||||||
|
));
|
||||||
let config_path = download_resource(&config_resource)?;
|
let config_path = download_resource(&config_resource)?;
|
||||||
let vocab_path = download_resource(&vocab_resource)?;
|
let vocab_path = download_resource(&vocab_resource)?;
|
||||||
let weights_path = download_resource(&weights_resource)?;
|
let weights_path = download_resource(&weights_resource)?;
|
||||||
|
|
||||||
// Set-up masked LM model
|
// Set-up masked LM model
|
||||||
let device = Device::Cpu;
|
let device = Device::Cpu;
|
||||||
let mut vs = nn::VarStore::new(device);
|
let mut vs = nn::VarStore::new(device);
|
||||||
let tokenizer: BertTokenizer = BertTokenizer::from_file(vocab_path.to_str().unwrap(), true);
|
let tokenizer: BertTokenizer = BertTokenizer::from_file(vocab_path.to_str().unwrap(), true);
|
||||||
@ -37,45 +45,51 @@ fn main() -> failure::Fallible<()> {
|
|||||||
let distil_bert_model = DistilBertModelMaskedLM::new(&vs.root(), &config);
|
let distil_bert_model = DistilBertModelMaskedLM::new(&vs.root(), &config);
|
||||||
vs.load(weights_path)?;
|
vs.load(weights_path)?;
|
||||||
|
|
||||||
// Define input
|
// Define input
|
||||||
let input = ["Looks like one thing is missing", "It\'s like comparing oranges to apples"];
|
let input = [
|
||||||
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
"Looks like one thing is missing",
|
||||||
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
|
"It\'s like comparing oranges to apples",
|
||||||
let mut tokenized_input = tokenized_input.
|
];
|
||||||
iter().
|
let tokenized_input =
|
||||||
map(|input| input.token_ids.clone()).
|
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
||||||
map(|mut input| {
|
let max_len = tokenized_input
|
||||||
|
.iter()
|
||||||
|
.map(|input| input.token_ids.len())
|
||||||
|
.max()
|
||||||
|
.unwrap();
|
||||||
|
let mut tokenized_input = tokenized_input
|
||||||
|
.iter()
|
||||||
|
.map(|input| input.token_ids.clone())
|
||||||
|
.map(|mut input| {
|
||||||
input.extend(vec![0; max_len - input.len()]);
|
input.extend(vec![0; max_len - input.len()]);
|
||||||
input
|
input
|
||||||
}).
|
})
|
||||||
collect::<Vec<_>>();
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
// Masking the token [thing] of sentence 1 and [oranges] of sentence 2
|
// Masking the token [thing] of sentence 1 and [oranges] of sentence 2
|
||||||
tokenized_input[0][4] = 103;
|
tokenized_input[0][4] = 103;
|
||||||
tokenized_input[1][6] = 103;
|
tokenized_input[1][6] = 103;
|
||||||
let tokenized_input = tokenized_input.
|
let tokenized_input = tokenized_input
|
||||||
iter().
|
.iter()
|
||||||
map(|input|
|
.map(|input| Tensor::of_slice(&(input)))
|
||||||
Tensor::of_slice(&(input))).
|
.collect::<Vec<_>>();
|
||||||
collect::<Vec<_>>();
|
|
||||||
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
|
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
|
||||||
|
|
||||||
|
// Forward pass
|
||||||
// Forward pass
|
|
||||||
let (output, _, _) = no_grad(|| {
|
let (output, _, _) = no_grad(|| {
|
||||||
distil_bert_model
|
distil_bert_model
|
||||||
.forward_t(Some(input_tensor), None, None, false)
|
.forward_t(Some(input_tensor), None, None, false)
|
||||||
.unwrap()
|
.unwrap()
|
||||||
});
|
});
|
||||||
|
|
||||||
// Print masked tokens
|
// Print masked tokens
|
||||||
let index_1 = output.get(0).get(4).argmax(0, false);
|
let index_1 = output.get(0).get(4).argmax(0, false);
|
||||||
let index_2 = output.get(1).get(6).argmax(0, false);
|
let index_2 = output.get(1).get(6).argmax(0, false);
|
||||||
let word_1 = tokenizer.vocab().id_to_token(&index_1.int64_value(&[]));
|
let word_1 = tokenizer.vocab().id_to_token(&index_1.int64_value(&[]));
|
||||||
let word_2 = tokenizer.vocab().id_to_token(&index_2.int64_value(&[]));
|
let word_2 = tokenizer.vocab().id_to_token(&index_2.int64_value(&[]));
|
||||||
|
|
||||||
println!("{}", word_1); // Outputs "person" : "Looks like one [person] is missing"
|
println!("{}", word_1); // Outputs "person" : "Looks like one [person] is missing"
|
||||||
println!("{}", word_2);// Outputs "pear" : "It\'s like comparing [pear] to apples"
|
println!("{}", word_2); // Outputs "pear" : "It\'s like comparing [pear] to apples"
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -1,26 +1,44 @@
|
|||||||
extern crate failure;
|
extern crate failure;
|
||||||
|
|
||||||
use rust_bert::gpt2::{Gpt2ConfigResources, Gpt2VocabResources, Gpt2MergesResources, Gpt2ModelResources};
|
use rust_bert::albert::{AlbertConfigResources, AlbertModelResources, AlbertVocabResources};
|
||||||
use rust_bert::distilbert::{DistilBertModelResources, DistilBertConfigResources, DistilBertVocabResources};
|
use rust_bert::bart::{
|
||||||
use rust_bert::openai_gpt::{OpenAiGptConfigResources, OpenAiGptVocabResources, OpenAiGptMergesResources, OpenAiGptModelResources};
|
BartConfigResources, BartMergesResources, BartModelResources, BartVocabResources,
|
||||||
use rust_bert::roberta::{RobertaConfigResources, RobertaVocabResources, RobertaMergesResources, RobertaModelResources};
|
};
|
||||||
use rust_bert::bert::{BertConfigResources, BertVocabResources, BertModelResources};
|
use rust_bert::bert::{BertConfigResources, BertModelResources, BertVocabResources};
|
||||||
use rust_bert::bart::{BartConfigResources, BartVocabResources, BartMergesResources, BartModelResources};
|
use rust_bert::distilbert::{
|
||||||
use rust_bert::resources::{Resource, download_resource, RemoteResource};
|
DistilBertConfigResources, DistilBertModelResources, DistilBertVocabResources,
|
||||||
use rust_bert::electra::{ElectraConfigResources, ElectraVocabResources, ElectraModelResources};
|
};
|
||||||
use rust_bert::albert::{AlbertConfigResources, AlbertVocabResources, AlbertModelResources};
|
use rust_bert::electra::{ElectraConfigResources, ElectraModelResources, ElectraVocabResources};
|
||||||
|
use rust_bert::gpt2::{
|
||||||
|
Gpt2ConfigResources, Gpt2MergesResources, Gpt2ModelResources, Gpt2VocabResources,
|
||||||
|
};
|
||||||
|
use rust_bert::openai_gpt::{
|
||||||
|
OpenAiGptConfigResources, OpenAiGptMergesResources, OpenAiGptModelResources,
|
||||||
|
OpenAiGptVocabResources,
|
||||||
|
};
|
||||||
|
use rust_bert::resources::{download_resource, RemoteResource, Resource};
|
||||||
|
use rust_bert::roberta::{
|
||||||
|
RobertaConfigResources, RobertaMergesResources, RobertaModelResources, RobertaVocabResources,
|
||||||
|
};
|
||||||
|
|
||||||
/// This example downloads and caches all dependencies used in model tests. This allows for safe
|
/// This example downloads and caches all dependencies used in model tests. This allows for safe
|
||||||
/// multi-threaded testing (two tests using the same resource would otherwise download the file to
|
/// multi-threaded testing (two tests using the same resource would otherwise download the file to
|
||||||
/// the same location).
|
/// the same location).
|
||||||
|
|
||||||
|
|
||||||
fn download_distil_gpt2() -> failure::Fallible<()> {
|
fn download_distil_gpt2() -> failure::Fallible<()> {
|
||||||
// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
|
// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::DISTIL_GPT2));
|
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::DISTIL_GPT2));
|
Gpt2ConfigResources::DISTIL_GPT2,
|
||||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::DISTIL_GPT2));
|
));
|
||||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::DISTIL_GPT2));
|
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
|
Gpt2VocabResources::DISTIL_GPT2,
|
||||||
|
));
|
||||||
|
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
|
Gpt2MergesResources::DISTIL_GPT2,
|
||||||
|
));
|
||||||
|
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
|
Gpt2ModelResources::DISTIL_GPT2,
|
||||||
|
));
|
||||||
let _ = download_resource(&config_resource)?;
|
let _ = download_resource(&config_resource)?;
|
||||||
let _ = download_resource(&vocab_resource)?;
|
let _ = download_resource(&vocab_resource)?;
|
||||||
let _ = download_resource(&merges_resource)?;
|
let _ = download_resource(&merges_resource)?;
|
||||||
@ -29,10 +47,16 @@ fn download_distil_gpt2() -> failure::Fallible<()> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn download_distilbert_sst2() -> failure::Fallible<()> {
|
fn download_distilbert_sst2() -> failure::Fallible<()> {
|
||||||
// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
|
// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
|
||||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertModelResources::DISTIL_BERT_SST2));
|
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertConfigResources::DISTIL_BERT_SST2));
|
DistilBertModelResources::DISTIL_BERT_SST2,
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertVocabResources::DISTIL_BERT_SST2));
|
));
|
||||||
|
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
|
DistilBertConfigResources::DISTIL_BERT_SST2,
|
||||||
|
));
|
||||||
|
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
|
DistilBertVocabResources::DISTIL_BERT_SST2,
|
||||||
|
));
|
||||||
let _ = download_resource(&config_resource)?;
|
let _ = download_resource(&config_resource)?;
|
||||||
let _ = download_resource(&vocab_resource)?;
|
let _ = download_resource(&vocab_resource)?;
|
||||||
let _ = download_resource(&weights_resource)?;
|
let _ = download_resource(&weights_resource)?;
|
||||||
@ -40,10 +64,16 @@ fn download_distilbert_sst2() -> failure::Fallible<()> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn download_distilbert_qa() -> failure::Fallible<()> {
|
fn download_distilbert_qa() -> failure::Fallible<()> {
|
||||||
// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
|
// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
|
||||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertModelResources::DISTIL_BERT_SQUAD));
|
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertConfigResources::DISTIL_BERT_SQUAD));
|
DistilBertModelResources::DISTIL_BERT_SQUAD,
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertVocabResources::DISTIL_BERT_SQUAD));
|
));
|
||||||
|
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
|
DistilBertConfigResources::DISTIL_BERT_SQUAD,
|
||||||
|
));
|
||||||
|
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
|
DistilBertVocabResources::DISTIL_BERT_SQUAD,
|
||||||
|
));
|
||||||
let _ = download_resource(&config_resource)?;
|
let _ = download_resource(&config_resource)?;
|
||||||
let _ = download_resource(&vocab_resource)?;
|
let _ = download_resource(&vocab_resource)?;
|
||||||
let _ = download_resource(&weights_resource)?;
|
let _ = download_resource(&weights_resource)?;
|
||||||
@ -51,10 +81,16 @@ fn download_distilbert_qa() -> failure::Fallible<()> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn download_distilbert() -> failure::Fallible<()> {
|
fn download_distilbert() -> failure::Fallible<()> {
|
||||||
// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
|
// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
|
||||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertModelResources::DISTIL_BERT));
|
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertConfigResources::DISTIL_BERT));
|
DistilBertModelResources::DISTIL_BERT,
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertVocabResources::DISTIL_BERT));
|
));
|
||||||
|
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
|
DistilBertConfigResources::DISTIL_BERT,
|
||||||
|
));
|
||||||
|
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
|
DistilBertVocabResources::DISTIL_BERT,
|
||||||
|
));
|
||||||
let _ = download_resource(&config_resource)?;
|
let _ = download_resource(&config_resource)?;
|
||||||
let _ = download_resource(&vocab_resource)?;
|
let _ = download_resource(&vocab_resource)?;
|
||||||
let _ = download_resource(&weights_resource)?;
|
let _ = download_resource(&weights_resource)?;
|
||||||
@ -62,11 +98,15 @@ fn download_distilbert() -> failure::Fallible<()> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn download_gpt2() -> failure::Fallible<()> {
|
fn download_gpt2() -> failure::Fallible<()> {
|
||||||
// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2. Modified with conversion to C-array format.
|
// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2. Modified with conversion to C-array format.
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
|
let config_resource =
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
|
Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
|
||||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
|
let vocab_resource =
|
||||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
|
Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
|
||||||
|
let merges_resource =
|
||||||
|
Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
|
||||||
|
let weights_resource =
|
||||||
|
Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
|
||||||
let _ = download_resource(&config_resource)?;
|
let _ = download_resource(&config_resource)?;
|
||||||
let _ = download_resource(&vocab_resource)?;
|
let _ = download_resource(&vocab_resource)?;
|
||||||
let _ = download_resource(&merges_resource)?;
|
let _ = download_resource(&merges_resource)?;
|
||||||
@ -75,11 +115,19 @@ fn download_gpt2() -> failure::Fallible<()> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn download_gpt() -> failure::Fallible<()> {
|
fn download_gpt() -> failure::Fallible<()> {
|
||||||
// Shared under MIT license by the OpenAI team at https://github.com/openai/finetune-transformer-lm. Modified with conversion to C-array format.
|
// Shared under MIT license by the OpenAI team at https://github.com/openai/finetune-transformer-lm. Modified with conversion to C-array format.
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptConfigResources::GPT));
|
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptVocabResources::GPT));
|
OpenAiGptConfigResources::GPT,
|
||||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptMergesResources::GPT));
|
));
|
||||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptModelResources::GPT));
|
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
|
OpenAiGptVocabResources::GPT,
|
||||||
|
));
|
||||||
|
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
|
OpenAiGptMergesResources::GPT,
|
||||||
|
));
|
||||||
|
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
|
OpenAiGptModelResources::GPT,
|
||||||
|
));
|
||||||
let _ = download_resource(&config_resource)?;
|
let _ = download_resource(&config_resource)?;
|
||||||
let _ = download_resource(&vocab_resource)?;
|
let _ = download_resource(&vocab_resource)?;
|
||||||
let _ = download_resource(&merges_resource)?;
|
let _ = download_resource(&merges_resource)?;
|
||||||
@ -88,11 +136,19 @@ fn download_gpt() -> failure::Fallible<()> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn download_roberta() -> failure::Fallible<()> {
|
fn download_roberta() -> failure::Fallible<()> {
|
||||||
// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
|
// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaConfigResources::ROBERTA));
|
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaVocabResources::ROBERTA));
|
RobertaConfigResources::ROBERTA,
|
||||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaMergesResources::ROBERTA));
|
));
|
||||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaModelResources::ROBERTA));
|
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
|
RobertaVocabResources::ROBERTA,
|
||||||
|
));
|
||||||
|
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
|
RobertaMergesResources::ROBERTA,
|
||||||
|
));
|
||||||
|
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
|
RobertaModelResources::ROBERTA,
|
||||||
|
));
|
||||||
let _ = download_resource(&config_resource)?;
|
let _ = download_resource(&config_resource)?;
|
||||||
let _ = download_resource(&vocab_resource)?;
|
let _ = download_resource(&vocab_resource)?;
|
||||||
let _ = download_resource(&merges_resource)?;
|
let _ = download_resource(&merges_resource)?;
|
||||||
@ -101,10 +157,13 @@ fn download_roberta() -> failure::Fallible<()> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn download_bert() -> failure::Fallible<()> {
|
fn download_bert() -> failure::Fallible<()> {
|
||||||
// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/bert. Modified with conversion to C-array format.
|
// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/bert. Modified with conversion to C-array format.
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT));
|
let config_resource =
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
|
Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT));
|
||||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT));
|
let vocab_resource =
|
||||||
|
Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
|
||||||
|
let weights_resource =
|
||||||
|
Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT));
|
||||||
let _ = download_resource(&config_resource)?;
|
let _ = download_resource(&config_resource)?;
|
||||||
let _ = download_resource(&vocab_resource)?;
|
let _ = download_resource(&vocab_resource)?;
|
||||||
let _ = download_resource(&weights_resource)?;
|
let _ = download_resource(&weights_resource)?;
|
||||||
@ -112,10 +171,16 @@ fn download_bert() -> failure::Fallible<()> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn download_bert_ner() -> failure::Fallible<()> {
|
fn download_bert_ner() -> failure::Fallible<()> {
|
||||||
// Shared under MIT license by the MDZ Digital Library team at the Bavarian State Library at https://github.com/dbmdz/berts. Modified with conversion to C-array format.
|
// Shared under MIT license by the MDZ Digital Library team at the Bavarian State Library at https://github.com/dbmdz/berts. Modified with conversion to C-array format.
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT_NER));
|
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT_NER));
|
BertConfigResources::BERT_NER,
|
||||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT_NER));
|
));
|
||||||
|
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
|
BertVocabResources::BERT_NER,
|
||||||
|
));
|
||||||
|
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
|
BertModelResources::BERT_NER,
|
||||||
|
));
|
||||||
let _ = download_resource(&config_resource)?;
|
let _ = download_resource(&config_resource)?;
|
||||||
let _ = download_resource(&vocab_resource)?;
|
let _ = download_resource(&vocab_resource)?;
|
||||||
let _ = download_resource(&weights_resource)?;
|
let _ = download_resource(&weights_resource)?;
|
||||||
@ -123,11 +188,15 @@ fn download_bert_ner() -> failure::Fallible<()> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn download_bart() -> failure::Fallible<()> {
|
fn download_bart() -> failure::Fallible<()> {
|
||||||
// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
|
// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(BartConfigResources::BART));
|
let config_resource =
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(BartVocabResources::BART));
|
Resource::Remote(RemoteResource::from_pretrained(BartConfigResources::BART));
|
||||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(BartMergesResources::BART));
|
let vocab_resource =
|
||||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(BartModelResources::BART));
|
Resource::Remote(RemoteResource::from_pretrained(BartVocabResources::BART));
|
||||||
|
let merges_resource =
|
||||||
|
Resource::Remote(RemoteResource::from_pretrained(BartMergesResources::BART));
|
||||||
|
let weights_resource =
|
||||||
|
Resource::Remote(RemoteResource::from_pretrained(BartModelResources::BART));
|
||||||
let _ = download_resource(&config_resource)?;
|
let _ = download_resource(&config_resource)?;
|
||||||
let _ = download_resource(&vocab_resource)?;
|
let _ = download_resource(&vocab_resource)?;
|
||||||
let _ = download_resource(&merges_resource)?;
|
let _ = download_resource(&merges_resource)?;
|
||||||
@ -136,11 +205,19 @@ fn download_bart() -> failure::Fallible<()> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn download_bart_cnn() -> failure::Fallible<()> {
|
fn download_bart_cnn() -> failure::Fallible<()> {
|
||||||
// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
|
// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(BartConfigResources::BART_CNN));
|
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(BartVocabResources::BART_CNN));
|
BartConfigResources::BART_CNN,
|
||||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(BartMergesResources::BART_CNN));
|
));
|
||||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(BartModelResources::BART_CNN));
|
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
|
BartVocabResources::BART_CNN,
|
||||||
|
));
|
||||||
|
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
|
BartMergesResources::BART_CNN,
|
||||||
|
));
|
||||||
|
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
|
BartModelResources::BART_CNN,
|
||||||
|
));
|
||||||
let _ = download_resource(&config_resource)?;
|
let _ = download_resource(&config_resource)?;
|
||||||
let _ = download_resource(&vocab_resource)?;
|
let _ = download_resource(&vocab_resource)?;
|
||||||
let _ = download_resource(&merges_resource)?;
|
let _ = download_resource(&merges_resource)?;
|
||||||
@ -149,10 +226,16 @@ fn download_bart_cnn() -> failure::Fallible<()> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn download_electra_generator() -> failure::Fallible<()> {
|
fn download_electra_generator() -> failure::Fallible<()> {
|
||||||
// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/electra. Modified with conversion to C-array format.
|
// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/electra. Modified with conversion to C-array format.
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(ElectraConfigResources::BASE_GENERATOR));
|
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(ElectraVocabResources::BASE_GENERATOR));
|
ElectraConfigResources::BASE_GENERATOR,
|
||||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(ElectraModelResources::BASE_GENERATOR));
|
));
|
||||||
|
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
|
ElectraVocabResources::BASE_GENERATOR,
|
||||||
|
));
|
||||||
|
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
|
ElectraModelResources::BASE_GENERATOR,
|
||||||
|
));
|
||||||
let _ = download_resource(&config_resource)?;
|
let _ = download_resource(&config_resource)?;
|
||||||
let _ = download_resource(&vocab_resource)?;
|
let _ = download_resource(&vocab_resource)?;
|
||||||
let _ = download_resource(&weights_resource)?;
|
let _ = download_resource(&weights_resource)?;
|
||||||
@ -160,10 +243,16 @@ fn download_electra_generator() -> failure::Fallible<()> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn download_electra_discriminator() -> failure::Fallible<()> {
|
fn download_electra_discriminator() -> failure::Fallible<()> {
|
||||||
// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/electra. Modified with conversion to C-array format.
|
// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/electra. Modified with conversion to C-array format.
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(ElectraConfigResources::BASE_DISCRIMINATOR));
|
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(ElectraVocabResources::BASE_DISCRIMINATOR));
|
ElectraConfigResources::BASE_DISCRIMINATOR,
|
||||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(ElectraModelResources::BASE_DISCRIMINATOR));
|
));
|
||||||
|
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
|
ElectraVocabResources::BASE_DISCRIMINATOR,
|
||||||
|
));
|
||||||
|
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
|
ElectraModelResources::BASE_DISCRIMINATOR,
|
||||||
|
));
|
||||||
let _ = download_resource(&config_resource)?;
|
let _ = download_resource(&config_resource)?;
|
||||||
let _ = download_resource(&vocab_resource)?;
|
let _ = download_resource(&vocab_resource)?;
|
||||||
let _ = download_resource(&weights_resource)?;
|
let _ = download_resource(&weights_resource)?;
|
||||||
@ -171,10 +260,16 @@ fn download_electra_discriminator() -> failure::Fallible<()> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn download_albert_base_v2() -> failure::Fallible<()> {
|
fn download_albert_base_v2() -> failure::Fallible<()> {
|
||||||
// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/ALBERT. Modified with conversion to C-array format.
|
// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/ALBERT. Modified with conversion to C-array format.
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(AlbertConfigResources::ALBERT_BASE_V2));
|
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(AlbertVocabResources::ALBERT_BASE_V2));
|
AlbertConfigResources::ALBERT_BASE_V2,
|
||||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(AlbertModelResources::ALBERT_BASE_V2));
|
));
|
||||||
|
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
|
AlbertVocabResources::ALBERT_BASE_V2,
|
||||||
|
));
|
||||||
|
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
|
AlbertModelResources::ALBERT_BASE_V2,
|
||||||
|
));
|
||||||
let _ = download_resource(&config_resource)?;
|
let _ = download_resource(&config_resource)?;
|
||||||
let _ = download_resource(&vocab_resource)?;
|
let _ = download_resource(&vocab_resource)?;
|
||||||
let _ = download_resource(&weights_resource)?;
|
let _ = download_resource(&weights_resource)?;
|
||||||
@ -198,4 +293,4 @@ fn main() -> failure::Fallible<()> {
|
|||||||
let _ = download_albert_base_v2();
|
let _ = download_albert_base_v2();
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -12,23 +12,31 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
|
use rust_bert::electra::{
|
||||||
use rust_bert::resources::{Resource, download_resource, RemoteResource};
|
ElectraConfig, ElectraConfigResources, ElectraDiscriminator, ElectraModelResources,
|
||||||
use rust_bert::electra::{ElectraConfig, ElectraDiscriminator, ElectraConfigResources, ElectraVocabResources, ElectraModelResources};
|
ElectraVocabResources,
|
||||||
|
};
|
||||||
|
use rust_bert::resources::{download_resource, RemoteResource, Resource};
|
||||||
use rust_bert::Config;
|
use rust_bert::Config;
|
||||||
use rust_tokenizers::{BertTokenizer, Tokenizer, TruncationStrategy};
|
use rust_tokenizers::{BertTokenizer, Tokenizer, TruncationStrategy};
|
||||||
use tch::{Tensor, Device, nn, no_grad};
|
use tch::{nn, no_grad, Device, Tensor};
|
||||||
|
|
||||||
fn main() -> failure::Fallible<()> {
|
fn main() -> failure::Fallible<()> {
|
||||||
// Resources paths
|
// Resources paths
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(ElectraConfigResources::BASE_DISCRIMINATOR));
|
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(ElectraVocabResources::BASE_DISCRIMINATOR));
|
ElectraConfigResources::BASE_DISCRIMINATOR,
|
||||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(ElectraModelResources::BASE_DISCRIMINATOR));
|
));
|
||||||
|
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
|
ElectraVocabResources::BASE_DISCRIMINATOR,
|
||||||
|
));
|
||||||
|
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
|
ElectraModelResources::BASE_DISCRIMINATOR,
|
||||||
|
));
|
||||||
let config_path = download_resource(&config_resource)?;
|
let config_path = download_resource(&config_resource)?;
|
||||||
let vocab_path = download_resource(&vocab_resource)?;
|
let vocab_path = download_resource(&vocab_resource)?;
|
||||||
let weights_path = download_resource(&weights_resource)?;
|
let weights_path = download_resource(&weights_resource)?;
|
||||||
|
|
||||||
// Set-up masked LM model
|
// Set-up masked LM model
|
||||||
let device = Device::Cpu;
|
let device = Device::Cpu;
|
||||||
let mut vs = nn::VarStore::new(device);
|
let mut vs = nn::VarStore::new(device);
|
||||||
let tokenizer: BertTokenizer = BertTokenizer::from_file(vocab_path.to_str().unwrap(), true);
|
let tokenizer: BertTokenizer = BertTokenizer::from_file(vocab_path.to_str().unwrap(), true);
|
||||||
@ -36,45 +44,45 @@ fn main() -> failure::Fallible<()> {
|
|||||||
let electra_model = ElectraDiscriminator::new(&vs.root(), &config);
|
let electra_model = ElectraDiscriminator::new(&vs.root(), &config);
|
||||||
vs.load(weights_path)?;
|
vs.load(weights_path)?;
|
||||||
|
|
||||||
// Define input
|
// Define input
|
||||||
let input = ["One Two Three Ten Five Six Seven Eight"];
|
let input = ["One Two Three Ten Five Six Seven Eight"];
|
||||||
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
let tokenized_input =
|
||||||
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
|
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
||||||
let encoded_input = tokenized_input.
|
let max_len = tokenized_input
|
||||||
iter().
|
.iter()
|
||||||
map(|input| input.token_ids.clone()).
|
.map(|input| input.token_ids.len())
|
||||||
map(|mut input| {
|
.max()
|
||||||
|
.unwrap();
|
||||||
|
let encoded_input = tokenized_input
|
||||||
|
.iter()
|
||||||
|
.map(|input| input.token_ids.clone())
|
||||||
|
.map(|mut input| {
|
||||||
input.extend(vec![0; max_len - input.len()]);
|
input.extend(vec![0; max_len - input.len()]);
|
||||||
input
|
input
|
||||||
}).
|
})
|
||||||
map(|input|
|
.map(|input| Tensor::of_slice(&(input)))
|
||||||
Tensor::of_slice(&(input))).
|
.collect::<Vec<_>>();
|
||||||
collect::<Vec<_>>();
|
|
||||||
let input_tensor = Tensor::stack(encoded_input.as_slice(), 0).to(device);
|
let input_tensor = Tensor::stack(encoded_input.as_slice(), 0).to(device);
|
||||||
|
|
||||||
// Forward pass
|
// Forward pass
|
||||||
let (output, _, _) = no_grad(|| {
|
let (output, _, _) =
|
||||||
electra_model
|
no_grad(|| electra_model.forward_t(Some(input_tensor), None, None, None, None, false));
|
||||||
.forward_t(Some(input_tensor),
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
false)
|
|
||||||
});
|
|
||||||
|
|
||||||
// Print model predictions
|
// Print model predictions
|
||||||
for (position, token) in tokenized_input[0].token_ids.iter().enumerate() {
|
for (position, token) in tokenized_input[0].token_ids.iter().enumerate() {
|
||||||
let probability = output.double_value(&[position as i64]);
|
let probability = output.double_value(&[position as i64]);
|
||||||
let generated = if probability > 0.5 { "generated" } else { "original" };
|
let generated = if probability > 0.5 {
|
||||||
println!("{:?}: {} ({:.1}%)",
|
"generated"
|
||||||
tokenizer.decode([*token].to_vec(),
|
} else {
|
||||||
false,
|
"original"
|
||||||
false),
|
};
|
||||||
generated,
|
println!(
|
||||||
100f64 * probability)
|
"{:?}: {} ({:.1}%)",
|
||||||
|
tokenizer.decode([*token].to_vec(), false, false),
|
||||||
|
generated,
|
||||||
|
100f64 * probability
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -12,23 +12,31 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
|
use rust_bert::electra::{
|
||||||
use rust_bert::resources::{Resource, download_resource, RemoteResource};
|
ElectraConfig, ElectraConfigResources, ElectraForMaskedLM, ElectraModelResources,
|
||||||
use rust_bert::electra::{ElectraConfig, ElectraForMaskedLM, ElectraModelResources, ElectraConfigResources, ElectraVocabResources};
|
ElectraVocabResources,
|
||||||
|
};
|
||||||
|
use rust_bert::resources::{download_resource, RemoteResource, Resource};
|
||||||
use rust_bert::Config;
|
use rust_bert::Config;
|
||||||
use rust_tokenizers::{BertTokenizer, Tokenizer, TruncationStrategy, Vocab};
|
use rust_tokenizers::{BertTokenizer, Tokenizer, TruncationStrategy, Vocab};
|
||||||
use tch::{Tensor, Device, nn, no_grad};
|
use tch::{nn, no_grad, Device, Tensor};
|
||||||
|
|
||||||
fn main() -> failure::Fallible<()> {
|
fn main() -> failure::Fallible<()> {
|
||||||
// Resources paths
|
// Resources paths
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(ElectraConfigResources::BASE_GENERATOR));
|
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(ElectraVocabResources::BASE_GENERATOR));
|
ElectraConfigResources::BASE_GENERATOR,
|
||||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(ElectraModelResources::BASE_GENERATOR));
|
));
|
||||||
|
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
|
ElectraVocabResources::BASE_GENERATOR,
|
||||||
|
));
|
||||||
|
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
|
ElectraModelResources::BASE_GENERATOR,
|
||||||
|
));
|
||||||
let config_path = download_resource(&config_resource)?;
|
let config_path = download_resource(&config_resource)?;
|
||||||
let vocab_path = download_resource(&vocab_resource)?;
|
let vocab_path = download_resource(&vocab_resource)?;
|
||||||
let weights_path = download_resource(&weights_resource)?;
|
let weights_path = download_resource(&weights_resource)?;
|
||||||
|
|
||||||
// Set-up masked LM model
|
// Set-up masked LM model
|
||||||
let device = Device::Cpu;
|
let device = Device::Cpu;
|
||||||
let mut vs = nn::VarStore::new(device);
|
let mut vs = nn::VarStore::new(device);
|
||||||
let tokenizer: BertTokenizer = BertTokenizer::from_file(vocab_path.to_str().unwrap(), true);
|
let tokenizer: BertTokenizer = BertTokenizer::from_file(vocab_path.to_str().unwrap(), true);
|
||||||
@ -36,41 +44,41 @@ fn main() -> failure::Fallible<()> {
|
|||||||
let electra_model = ElectraForMaskedLM::new(&vs.root(), &config);
|
let electra_model = ElectraForMaskedLM::new(&vs.root(), &config);
|
||||||
vs.load(weights_path)?;
|
vs.load(weights_path)?;
|
||||||
|
|
||||||
// Define input
|
// Define input
|
||||||
let input = ["Looks like one [MASK] is missing", "It was a very nice and [MASK] day"];
|
let input = [
|
||||||
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
"Looks like one [MASK] is missing",
|
||||||
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
|
"It was a very nice and [MASK] day",
|
||||||
let tokenized_input = tokenized_input.
|
];
|
||||||
iter().
|
let tokenized_input =
|
||||||
map(|input| input.token_ids.clone()).
|
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
||||||
map(|mut input| {
|
let max_len = tokenized_input
|
||||||
|
.iter()
|
||||||
|
.map(|input| input.token_ids.len())
|
||||||
|
.max()
|
||||||
|
.unwrap();
|
||||||
|
let tokenized_input = tokenized_input
|
||||||
|
.iter()
|
||||||
|
.map(|input| input.token_ids.clone())
|
||||||
|
.map(|mut input| {
|
||||||
input.extend(vec![0; max_len - input.len()]);
|
input.extend(vec![0; max_len - input.len()]);
|
||||||
input
|
input
|
||||||
}).
|
})
|
||||||
map(|input|
|
.map(|input| Tensor::of_slice(&(input)))
|
||||||
Tensor::of_slice(&(input))).
|
.collect::<Vec<_>>();
|
||||||
collect::<Vec<_>>();
|
|
||||||
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
|
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
|
||||||
|
|
||||||
// Forward pass
|
// Forward pass
|
||||||
let (output, _, _) = no_grad(|| {
|
let (output, _, _) =
|
||||||
electra_model
|
no_grad(|| electra_model.forward_t(Some(input_tensor), None, None, None, None, false));
|
||||||
.forward_t(Some(input_tensor),
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
false)
|
|
||||||
});
|
|
||||||
|
|
||||||
// Print masked tokens
|
// Print masked tokens
|
||||||
let index_1 = output.get(0).get(4).argmax(0, false);
|
let index_1 = output.get(0).get(4).argmax(0, false);
|
||||||
let index_2 = output.get(1).get(7).argmax(0, false);
|
let index_2 = output.get(1).get(7).argmax(0, false);
|
||||||
let word_1 = tokenizer.vocab().id_to_token(&index_1.int64_value(&[]));
|
let word_1 = tokenizer.vocab().id_to_token(&index_1.int64_value(&[]));
|
||||||
let word_2 = tokenizer.vocab().id_to_token(&index_2.int64_value(&[]));
|
let word_2 = tokenizer.vocab().id_to_token(&index_2.int64_value(&[]));
|
||||||
|
|
||||||
println!("{}", word_1); // Outputs "thing" : "Looks like one [thing] is missing"
|
println!("{}", word_1); // Outputs "thing" : "Looks like one [thing] is missing"
|
||||||
println!("{}", word_2);// Outputs "sunny" : "It was a very nice and [sunny] day"
|
println!("{}", word_2); // Outputs "sunny" : "It was a very nice and [sunny] day"
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -12,12 +12,10 @@
|
|||||||
|
|
||||||
extern crate failure;
|
extern crate failure;
|
||||||
|
|
||||||
use rust_bert::pipelines::generation::{GPT2Generator, LanguageGenerator, GenerateConfig};
|
use rust_bert::pipelines::generation::{GPT2Generator, GenerateConfig, LanguageGenerator};
|
||||||
|
|
||||||
|
|
||||||
fn main() -> failure::Fallible<()> {
|
fn main() -> failure::Fallible<()> {
|
||||||
|
// Set-up masked LM model
|
||||||
// Set-up masked LM model
|
|
||||||
let generate_config = GenerateConfig {
|
let generate_config = GenerateConfig {
|
||||||
max_length: 30,
|
max_length: 30,
|
||||||
do_sample: true,
|
do_sample: true,
|
||||||
@ -30,10 +28,10 @@ fn main() -> failure::Fallible<()> {
|
|||||||
|
|
||||||
let input_context = "The dog";
|
let input_context = "The dog";
|
||||||
let second_input_context = "The cat was";
|
let second_input_context = "The cat was";
|
||||||
let output = model.generate(Some(vec!(input_context, second_input_context)), None);
|
let output = model.generate(Some(vec![input_context, second_input_context]), None);
|
||||||
|
|
||||||
for sentence in output {
|
for sentence in output {
|
||||||
println!("{:?}", sentence);
|
println!("{:?}", sentence);
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -12,65 +12,82 @@
|
|||||||
|
|
||||||
extern crate failure;
|
extern crate failure;
|
||||||
|
|
||||||
use tch::{Device, nn, Tensor};
|
use rust_bert::gpt2::{
|
||||||
use rust_tokenizers::{TruncationStrategy, Tokenizer, Gpt2Tokenizer};
|
GPT2LMHeadModel, Gpt2Config, Gpt2ConfigResources, Gpt2MergesResources, Gpt2ModelResources,
|
||||||
use rust_bert::gpt2::{Gpt2Config, GPT2LMHeadModel, Gpt2ConfigResources, Gpt2VocabResources, Gpt2MergesResources, Gpt2ModelResources};
|
Gpt2VocabResources,
|
||||||
use rust_bert::pipelines::generation::{LMHeadModel, Cache};
|
};
|
||||||
use rust_bert::resources::{Resource, download_resource, RemoteResource};
|
use rust_bert::pipelines::generation::{Cache, LMHeadModel};
|
||||||
|
use rust_bert::resources::{download_resource, RemoteResource, Resource};
|
||||||
use rust_bert::Config;
|
use rust_bert::Config;
|
||||||
|
use rust_tokenizers::{Gpt2Tokenizer, Tokenizer, TruncationStrategy};
|
||||||
|
use tch::{nn, Device, Tensor};
|
||||||
|
|
||||||
fn main() -> failure::Fallible<()> {
|
fn main() -> failure::Fallible<()> {
|
||||||
// Resources set-up
|
// Resources set-up
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
|
let config_resource =
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
|
Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
|
||||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
|
let vocab_resource =
|
||||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
|
Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
|
||||||
|
let merges_resource =
|
||||||
|
Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
|
||||||
|
let weights_resource =
|
||||||
|
Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
|
||||||
let config_path = download_resource(&config_resource)?;
|
let config_path = download_resource(&config_resource)?;
|
||||||
let vocab_path = download_resource(&vocab_resource)?;
|
let vocab_path = download_resource(&vocab_resource)?;
|
||||||
let merges_path = download_resource(&merges_resource)?;
|
let merges_path = download_resource(&merges_resource)?;
|
||||||
let weights_path = download_resource(&weights_resource)?;
|
let weights_path = download_resource(&weights_resource)?;
|
||||||
|
|
||||||
// Set-up masked LM model
|
// Set-up masked LM model
|
||||||
let device = Device::Cpu;
|
let device = Device::Cpu;
|
||||||
let mut vs = nn::VarStore::new(device);
|
let mut vs = nn::VarStore::new(device);
|
||||||
let tokenizer: Gpt2Tokenizer = Gpt2Tokenizer::from_file(vocab_path.to_str().unwrap(), merges_path.to_str().unwrap(), false);
|
let tokenizer: Gpt2Tokenizer = Gpt2Tokenizer::from_file(
|
||||||
|
vocab_path.to_str().unwrap(),
|
||||||
|
merges_path.to_str().unwrap(),
|
||||||
|
false,
|
||||||
|
);
|
||||||
let config = Gpt2Config::from_file(config_path);
|
let config = Gpt2Config::from_file(config_path);
|
||||||
let gpt2_model = GPT2LMHeadModel::new(&vs.root(), &config);
|
let gpt2_model = GPT2LMHeadModel::new(&vs.root(), &config);
|
||||||
vs.load(weights_path)?;
|
vs.load(weights_path)?;
|
||||||
|
|
||||||
// Define input
|
// Define input
|
||||||
let input = ["One two three four five six seven eight nine ten eleven"];
|
let input = ["One two three four five six seven eight nine ten eleven"];
|
||||||
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
let tokenized_input =
|
||||||
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
|
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
||||||
let tokenized_input = tokenized_input.
|
let max_len = tokenized_input
|
||||||
iter().
|
.iter()
|
||||||
map(|input| input.token_ids.clone()).
|
.map(|input| input.token_ids.len())
|
||||||
map(|mut input| {
|
.max()
|
||||||
|
.unwrap();
|
||||||
|
let tokenized_input = tokenized_input
|
||||||
|
.iter()
|
||||||
|
.map(|input| input.token_ids.clone())
|
||||||
|
.map(|mut input| {
|
||||||
input.extend(vec![0; max_len - input.len()]);
|
input.extend(vec![0; max_len - input.len()]);
|
||||||
input
|
input
|
||||||
}).
|
})
|
||||||
map(|input|
|
.map(|input| Tensor::of_slice(&(input)))
|
||||||
Tensor::of_slice(&(input))).
|
.collect::<Vec<_>>();
|
||||||
collect::<Vec<_>>();
|
|
||||||
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
|
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
|
||||||
|
|
||||||
// Forward pass
|
// Forward pass
|
||||||
let (output, _, _, _, _) = gpt2_model.forward_t(
|
let (output, _, _, _, _) = gpt2_model
|
||||||
&Some(input_tensor),
|
.forward_t(
|
||||||
Cache::None,
|
&Some(input_tensor),
|
||||||
&None,
|
Cache::None,
|
||||||
&None,
|
&None,
|
||||||
&None,
|
&None,
|
||||||
&None,
|
&None,
|
||||||
None,
|
&None,
|
||||||
&None,
|
None,
|
||||||
false).unwrap();
|
&None,
|
||||||
|
false,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
let next_word_id = output.get(0).get(-1).argmax(-1, true).int64_value(&[0]);
|
let next_word_id = output.get(0).get(-1).argmax(-1, true).int64_value(&[0]);
|
||||||
let next_word = tokenizer.decode(vec!(next_word_id), true, true);
|
let next_word = tokenizer.decode(vec![next_word_id], true, true);
|
||||||
println!("Provided input: {}", input[0]);
|
println!("Provided input: {}", input[0]);
|
||||||
println!("Next word: {}", next_word);
|
println!("Next word: {}", next_word);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -15,20 +15,20 @@ extern crate failure;
|
|||||||
use rust_bert::pipelines::ner::NERModel;
|
use rust_bert::pipelines::ner::NERModel;
|
||||||
|
|
||||||
fn main() -> failure::Fallible<()> {
|
fn main() -> failure::Fallible<()> {
|
||||||
// Set-up model
|
// Set-up model
|
||||||
let ner_model = NERModel::new(Default::default())?;
|
let ner_model = NERModel::new(Default::default())?;
|
||||||
|
|
||||||
// Define input
|
// Define input
|
||||||
let input = [
|
let input = [
|
||||||
"My name is Amélie. I live in Москва.",
|
"My name is Amélie. I live in Москва.",
|
||||||
"Chongqing is a city in China."
|
"Chongqing is a city in China.",
|
||||||
];
|
];
|
||||||
|
|
||||||
// Run model
|
// Run model
|
||||||
let output = ner_model.predict(&input);
|
let output = ner_model.predict(&input);
|
||||||
for entity in output {
|
for entity in output {
|
||||||
println!("{:?}", entity);
|
println!("{:?}", entity);
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -12,66 +12,87 @@
|
|||||||
|
|
||||||
extern crate failure;
|
extern crate failure;
|
||||||
|
|
||||||
use tch::{Device, nn, Tensor};
|
|
||||||
use rust_tokenizers::{TruncationStrategy, Tokenizer, OpenAiGptTokenizer};
|
|
||||||
use rust_bert::gpt2::Gpt2Config;
|
use rust_bert::gpt2::Gpt2Config;
|
||||||
use rust_bert::openai_gpt::{OpenAIGPTLMHeadModel, OpenAiGptConfigResources, OpenAiGptVocabResources, OpenAiGptMergesResources, OpenAiGptModelResources};
|
use rust_bert::openai_gpt::{
|
||||||
use rust_bert::pipelines::generation::{LMHeadModel, Cache};
|
OpenAIGPTLMHeadModel, OpenAiGptConfigResources, OpenAiGptMergesResources,
|
||||||
use rust_bert::resources::{Resource, download_resource, RemoteResource};
|
OpenAiGptModelResources, OpenAiGptVocabResources,
|
||||||
|
};
|
||||||
|
use rust_bert::pipelines::generation::{Cache, LMHeadModel};
|
||||||
|
use rust_bert::resources::{download_resource, RemoteResource, Resource};
|
||||||
use rust_bert::Config;
|
use rust_bert::Config;
|
||||||
|
use rust_tokenizers::{OpenAiGptTokenizer, Tokenizer, TruncationStrategy};
|
||||||
|
use tch::{nn, Device, Tensor};
|
||||||
|
|
||||||
fn main() -> failure::Fallible<()> {
|
fn main() -> failure::Fallible<()> {
|
||||||
// Resources paths
|
// Resources paths
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptConfigResources::GPT));
|
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptVocabResources::GPT));
|
OpenAiGptConfigResources::GPT,
|
||||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptMergesResources::GPT));
|
));
|
||||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptModelResources::GPT));
|
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
|
OpenAiGptVocabResources::GPT,
|
||||||
|
));
|
||||||
|
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
|
OpenAiGptMergesResources::GPT,
|
||||||
|
));
|
||||||
|
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
|
OpenAiGptModelResources::GPT,
|
||||||
|
));
|
||||||
let config_path = download_resource(&config_resource)?;
|
let config_path = download_resource(&config_resource)?;
|
||||||
let vocab_path = download_resource(&vocab_resource)?;
|
let vocab_path = download_resource(&vocab_resource)?;
|
||||||
let merges_path = download_resource(&merges_resource)?;
|
let merges_path = download_resource(&merges_resource)?;
|
||||||
let weights_path = download_resource(&weights_resource)?;
|
let weights_path = download_resource(&weights_resource)?;
|
||||||
|
|
||||||
// Set-up masked LM model
|
// Set-up masked LM model
|
||||||
let device = Device::Cpu;
|
let device = Device::Cpu;
|
||||||
let mut vs = nn::VarStore::new(device);
|
let mut vs = nn::VarStore::new(device);
|
||||||
let tokenizer = OpenAiGptTokenizer::from_file(vocab_path.to_str().unwrap(), merges_path.to_str().unwrap(), true);
|
let tokenizer = OpenAiGptTokenizer::from_file(
|
||||||
|
vocab_path.to_str().unwrap(),
|
||||||
|
merges_path.to_str().unwrap(),
|
||||||
|
true,
|
||||||
|
);
|
||||||
let config = Gpt2Config::from_file(config_path);
|
let config = Gpt2Config::from_file(config_path);
|
||||||
let openai_gpt = OpenAIGPTLMHeadModel::new(&vs.root(), &config);
|
let openai_gpt = OpenAIGPTLMHeadModel::new(&vs.root(), &config);
|
||||||
vs.load(weights_path)?;
|
vs.load(weights_path)?;
|
||||||
|
|
||||||
// Define input
|
// Define input
|
||||||
let input = ["Wondering what the next word will"];
|
let input = ["Wondering what the next word will"];
|
||||||
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
let tokenized_input =
|
||||||
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
|
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
||||||
let tokenized_input = tokenized_input.
|
let max_len = tokenized_input
|
||||||
iter().
|
.iter()
|
||||||
map(|input| input.token_ids.clone()).
|
.map(|input| input.token_ids.len())
|
||||||
map(|mut input| {
|
.max()
|
||||||
|
.unwrap();
|
||||||
|
let tokenized_input = tokenized_input
|
||||||
|
.iter()
|
||||||
|
.map(|input| input.token_ids.clone())
|
||||||
|
.map(|mut input| {
|
||||||
input.extend(vec![0; max_len - input.len()]);
|
input.extend(vec![0; max_len - input.len()]);
|
||||||
input
|
input
|
||||||
}).
|
})
|
||||||
map(|input|
|
.map(|input| Tensor::of_slice(&(input)))
|
||||||
Tensor::of_slice(&(input))).
|
.collect::<Vec<_>>();
|
||||||
collect::<Vec<_>>();
|
|
||||||
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
|
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
|
||||||
|
|
||||||
// Forward pass
|
// Forward pass
|
||||||
let (output, _, _, _, _) = openai_gpt.forward_t(
|
let (output, _, _, _, _) = openai_gpt
|
||||||
&Some(input_tensor),
|
.forward_t(
|
||||||
Cache::None,
|
&Some(input_tensor),
|
||||||
&None,
|
Cache::None,
|
||||||
&None,
|
&None,
|
||||||
&None,
|
&None,
|
||||||
&None,
|
&None,
|
||||||
None,
|
&None,
|
||||||
&None,
|
None,
|
||||||
false).unwrap();
|
&None,
|
||||||
|
false,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
let next_word_id = output.get(0).get(-1).argmax(-1, true).int64_value(&[0]);
|
let next_word_id = output.get(0).get(-1).argmax(-1, true).int64_value(&[0]);
|
||||||
let next_word = tokenizer.decode(vec!(next_word_id), true, true);
|
let next_word = tokenizer.decode(vec![next_word_id], true, true);
|
||||||
println!("Provided input: {}", input[0]);
|
println!("Provided input: {}", input[0]);
|
||||||
println!("Next word: {}", next_word);
|
println!("Next word: {}", next_word);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -12,23 +12,28 @@
|
|||||||
|
|
||||||
extern crate failure;
|
extern crate failure;
|
||||||
|
|
||||||
use rust_bert::pipelines::question_answering::{QuestionAnsweringModel, QaInput};
|
use rust_bert::pipelines::question_answering::{QaInput, QuestionAnsweringModel};
|
||||||
|
|
||||||
|
|
||||||
fn main() -> failure::Fallible<()> {
|
fn main() -> failure::Fallible<()> {
|
||||||
// Set-up Question Answering model
|
// Set-up Question Answering model
|
||||||
let qa_model = QuestionAnsweringModel::new(Default::default())?;
|
let qa_model = QuestionAnsweringModel::new(Default::default())?;
|
||||||
|
|
||||||
// Define input
|
// Define input
|
||||||
let question_1 = String::from("Where does Amy live ?");
|
let question_1 = String::from("Where does Amy live ?");
|
||||||
let context_1 = String::from("Amy lives in Amsterdam");
|
let context_1 = String::from("Amy lives in Amsterdam");
|
||||||
let question_2 = String::from("Where does Eric live");
|
let question_2 = String::from("Where does Eric live");
|
||||||
let context_2 = String::from("While Amy lives in Amsterdam, Eric is in The Hague.");
|
let context_2 = String::from("While Amy lives in Amsterdam, Eric is in The Hague.");
|
||||||
let qa_input_1 = QaInput { question: question_1, context: context_1 };
|
let qa_input_1 = QaInput {
|
||||||
let qa_input_2 = QaInput { question: question_2, context: context_2 };
|
question: question_1,
|
||||||
|
context: context_1,
|
||||||
|
};
|
||||||
|
let qa_input_2 = QaInput {
|
||||||
|
question: question_2,
|
||||||
|
context: context_2,
|
||||||
|
};
|
||||||
|
|
||||||
// Get answer
|
// Get answer
|
||||||
let answers = qa_model.predict(&vec!(qa_input_1, qa_input_2), 1, 32);
|
let answers = qa_model.predict(&vec![qa_input_1, qa_input_2], 1, 32);
|
||||||
println!("{:?}", answers);
|
println!("{:?}", answers);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -12,77 +12,99 @@
|
|||||||
|
|
||||||
extern crate failure;
|
extern crate failure;
|
||||||
|
|
||||||
use tch::{Device, nn, Tensor, no_grad};
|
|
||||||
use rust_tokenizers::{TruncationStrategy, Tokenizer, Vocab, RobertaTokenizer};
|
|
||||||
use rust_bert::Config;
|
|
||||||
use rust_bert::bert::BertConfig;
|
use rust_bert::bert::BertConfig;
|
||||||
use rust_bert::roberta::{RobertaForMaskedLM, RobertaVocabResources, RobertaConfigResources, RobertaMergesResources, RobertaModelResources};
|
use rust_bert::resources::{download_resource, RemoteResource, Resource};
|
||||||
use rust_bert::resources::{Resource, download_resource, RemoteResource};
|
use rust_bert::roberta::{
|
||||||
|
RobertaConfigResources, RobertaForMaskedLM, RobertaMergesResources, RobertaModelResources,
|
||||||
|
RobertaVocabResources,
|
||||||
|
};
|
||||||
|
use rust_bert::Config;
|
||||||
|
use rust_tokenizers::{RobertaTokenizer, Tokenizer, TruncationStrategy, Vocab};
|
||||||
|
use tch::{nn, no_grad, Device, Tensor};
|
||||||
|
|
||||||
fn main() -> failure::Fallible<()> {
|
fn main() -> failure::Fallible<()> {
|
||||||
// Resources paths
|
// Resources paths
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaConfigResources::ROBERTA));
|
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaVocabResources::ROBERTA));
|
RobertaConfigResources::ROBERTA,
|
||||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaMergesResources::ROBERTA));
|
));
|
||||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaModelResources::ROBERTA));
|
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
|
RobertaVocabResources::ROBERTA,
|
||||||
|
));
|
||||||
|
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
|
RobertaMergesResources::ROBERTA,
|
||||||
|
));
|
||||||
|
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
|
RobertaModelResources::ROBERTA,
|
||||||
|
));
|
||||||
let config_path = download_resource(&config_resource)?;
|
let config_path = download_resource(&config_resource)?;
|
||||||
let vocab_path = download_resource(&vocab_resource)?;
|
let vocab_path = download_resource(&vocab_resource)?;
|
||||||
let merges_path = download_resource(&merges_resource)?;
|
let merges_path = download_resource(&merges_resource)?;
|
||||||
let weights_path = download_resource(&weights_resource)?;
|
let weights_path = download_resource(&weights_resource)?;
|
||||||
|
|
||||||
// Set-up masked LM model
|
// Set-up masked LM model
|
||||||
let device = Device::Cpu;
|
let device = Device::Cpu;
|
||||||
let mut vs = nn::VarStore::new(device);
|
let mut vs = nn::VarStore::new(device);
|
||||||
let tokenizer: RobertaTokenizer = RobertaTokenizer::from_file(vocab_path.to_str().unwrap(), merges_path.to_str().unwrap(), true);
|
let tokenizer: RobertaTokenizer = RobertaTokenizer::from_file(
|
||||||
|
vocab_path.to_str().unwrap(),
|
||||||
|
merges_path.to_str().unwrap(),
|
||||||
|
true,
|
||||||
|
);
|
||||||
let config = BertConfig::from_file(config_path);
|
let config = BertConfig::from_file(config_path);
|
||||||
let bert_model = RobertaForMaskedLM::new(&vs.root(), &config);
|
let bert_model = RobertaForMaskedLM::new(&vs.root(), &config);
|
||||||
vs.load(weights_path)?;
|
vs.load(weights_path)?;
|
||||||
|
|
||||||
// Define input
|
// Define input
|
||||||
let input = ["<pad> Looks like one thing is missing", "It\'s like comparing oranges to apples"];
|
let input = [
|
||||||
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
"<pad> Looks like one thing is missing",
|
||||||
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
|
"It\'s like comparing oranges to apples",
|
||||||
let mut tokenized_input = tokenized_input.
|
];
|
||||||
iter().
|
let tokenized_input =
|
||||||
map(|input| input.token_ids.clone()).
|
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
||||||
map(|mut input| {
|
let max_len = tokenized_input
|
||||||
|
.iter()
|
||||||
|
.map(|input| input.token_ids.len())
|
||||||
|
.max()
|
||||||
|
.unwrap();
|
||||||
|
let mut tokenized_input = tokenized_input
|
||||||
|
.iter()
|
||||||
|
.map(|input| input.token_ids.clone())
|
||||||
|
.map(|mut input| {
|
||||||
input.extend(vec![0; max_len - input.len()]);
|
input.extend(vec![0; max_len - input.len()]);
|
||||||
input
|
input
|
||||||
}).
|
})
|
||||||
collect::<Vec<_>>();
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
// Masking the token [thing] of sentence 1 and [oranges] of sentence 2
|
// Masking the token [thing] of sentence 1 and [oranges] of sentence 2
|
||||||
tokenized_input[0][4] = 103;
|
tokenized_input[0][4] = 103;
|
||||||
tokenized_input[1][5] = 103;
|
tokenized_input[1][5] = 103;
|
||||||
let tokenized_input = tokenized_input.
|
let tokenized_input = tokenized_input
|
||||||
iter().
|
.iter()
|
||||||
map(|input|
|
.map(|input| Tensor::of_slice(&(input)))
|
||||||
Tensor::of_slice(&(input))).
|
.collect::<Vec<_>>();
|
||||||
collect::<Vec<_>>();
|
|
||||||
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
|
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
|
||||||
|
|
||||||
// Forward pass
|
// Forward pass
|
||||||
let (output, _, _) = no_grad(|| {
|
let (output, _, _) = no_grad(|| {
|
||||||
bert_model
|
bert_model.forward_t(
|
||||||
.forward_t(Some(input_tensor),
|
Some(input_tensor),
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
&None,
|
&None,
|
||||||
&None,
|
&None,
|
||||||
false)
|
false,
|
||||||
|
)
|
||||||
});
|
});
|
||||||
|
|
||||||
// Print masked tokens
|
// Print masked tokens
|
||||||
let index_1 = output.get(0).get(4).argmax(0, false);
|
let index_1 = output.get(0).get(4).argmax(0, false);
|
||||||
let index_2 = output.get(1).get(5).argmax(0, false);
|
let index_2 = output.get(1).get(5).argmax(0, false);
|
||||||
let word_1 = tokenizer.vocab().id_to_token(&index_1.int64_value(&[]));
|
let word_1 = tokenizer.vocab().id_to_token(&index_1.int64_value(&[]));
|
||||||
let word_2 = tokenizer.vocab().id_to_token(&index_2.int64_value(&[]));
|
let word_2 = tokenizer.vocab().id_to_token(&index_2.int64_value(&[]));
|
||||||
|
|
||||||
println!("{}", word_1); // Outputs "some" : "Looks like [some] thing is missing"
|
println!("{}", word_1); // Outputs "some" : "Looks like [some] thing is missing"
|
||||||
println!("{}", word_2);// Outputs "apple" : "It\'s like comparing [apple] to apples"
|
println!("{}", word_2); // Outputs "apple" : "It\'s like comparing [apple] to apples"
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -14,23 +14,22 @@ extern crate failure;
|
|||||||
|
|
||||||
use rust_bert::pipelines::sentiment::SentimentModel;
|
use rust_bert::pipelines::sentiment::SentimentModel;
|
||||||
|
|
||||||
|
|
||||||
fn main() -> failure::Fallible<()> {
|
fn main() -> failure::Fallible<()> {
|
||||||
// Set-up classifier
|
// Set-up classifier
|
||||||
let sentiment_classifier = SentimentModel::new(Default::default())?;
|
let sentiment_classifier = SentimentModel::new(Default::default())?;
|
||||||
|
|
||||||
// Define input
|
// Define input
|
||||||
let input = [
|
let input = [
|
||||||
"Probably my all-time favorite movie, a story of selflessness, sacrifice and dedication to a noble cause, but it's not preachy or boring.",
|
"Probably my all-time favorite movie, a story of selflessness, sacrifice and dedication to a noble cause, but it's not preachy or boring.",
|
||||||
"This film tried to be too many things all at once: stinging political satire, Hollywood blockbuster, sappy romantic comedy, family values promo...",
|
"This film tried to be too many things all at once: stinging political satire, Hollywood blockbuster, sappy romantic comedy, family values promo...",
|
||||||
"If you like original gut wrenching laughter you will like this movie. If you are young or old then you will love this movie, hell even my mom liked it.",
|
"If you like original gut wrenching laughter you will like this movie. If you are young or old then you will love this movie, hell even my mom liked it.",
|
||||||
];
|
];
|
||||||
|
|
||||||
// Run model
|
// Run model
|
||||||
let output = sentiment_classifier.predict(&input);
|
let output = sentiment_classifier.predict(&input);
|
||||||
for sentiment in output {
|
for sentiment in output {
|
||||||
println!("{:?}", sentiment);
|
println!("{:?}", sentiment);
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -15,21 +15,21 @@ extern crate failure;
|
|||||||
use rust_bert::pipelines::sequence_classification::SequenceClassificationModel;
|
use rust_bert::pipelines::sequence_classification::SequenceClassificationModel;
|
||||||
|
|
||||||
fn main() -> failure::Fallible<()> {
|
fn main() -> failure::Fallible<()> {
|
||||||
// Set-up model
|
// Set-up model
|
||||||
let sequence_classification_model = SequenceClassificationModel::new(Default::default())?;
|
let sequence_classification_model = SequenceClassificationModel::new(Default::default())?;
|
||||||
|
|
||||||
// Define input
|
// Define input
|
||||||
let input = [
|
let input = [
|
||||||
"Probably my all-time favorite movie, a story of selflessness, sacrifice and dedication to a noble cause, but it's not preachy or boring.",
|
"Probably my all-time favorite movie, a story of selflessness, sacrifice and dedication to a noble cause, but it's not preachy or boring.",
|
||||||
"This film tried to be too many things all at once: stinging political satire, Hollywood blockbuster, sappy romantic comedy, family values promo...",
|
"This film tried to be too many things all at once: stinging political satire, Hollywood blockbuster, sappy romantic comedy, family values promo...",
|
||||||
"If you like original gut wrenching laughter you will like this movie. If you are young or old then you will love this movie, hell even my mom liked it.",
|
"If you like original gut wrenching laughter you will like this movie. If you are young or old then you will love this movie, hell even my mom liked it.",
|
||||||
];
|
];
|
||||||
|
|
||||||
// Run model
|
// Run model
|
||||||
let output = sequence_classification_model.predict(&input);
|
let output = sequence_classification_model.predict(&input);
|
||||||
for label in output {
|
for label in output {
|
||||||
println!("{:?}", label);
|
println!("{:?}", label);
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -15,21 +15,21 @@ extern crate failure;
|
|||||||
use rust_bert::pipelines::sequence_classification::SequenceClassificationModel;
|
use rust_bert::pipelines::sequence_classification::SequenceClassificationModel;
|
||||||
|
|
||||||
fn main() -> failure::Fallible<()> {
|
fn main() -> failure::Fallible<()> {
|
||||||
// Set-up model
|
// Set-up model
|
||||||
let sequence_classification_model = SequenceClassificationModel::new(Default::default())?;
|
let sequence_classification_model = SequenceClassificationModel::new(Default::default())?;
|
||||||
|
|
||||||
// Define input
|
// Define input
|
||||||
let input = [
|
let input = [
|
||||||
"Probably my all-time favorite movie, a story of selflessness, sacrifice and dedication to a noble cause, but it's not preachy or boring.",
|
"Probably my all-time favorite movie, a story of selflessness, sacrifice and dedication to a noble cause, but it's not preachy or boring.",
|
||||||
"This is a neutral sentence.",
|
"This is a neutral sentence.",
|
||||||
"If you like original gut wrenching laughter you will like this movie. If you are young or old then you will love this movie, hell even my mom liked it.",
|
"If you like original gut wrenching laughter you will like this movie. If you are young or old then you will love this movie, hell even my mom liked it.",
|
||||||
];
|
];
|
||||||
|
|
||||||
// Run model
|
// Run model
|
||||||
let output = sequence_classification_model.predict_multilabel(&input, 0.05);
|
let output = sequence_classification_model.predict_multilabel(&input, 0.05);
|
||||||
for label in output {
|
for label in output {
|
||||||
println!("{:?}", label);
|
println!("{:?}", label);
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -12,24 +12,23 @@
|
|||||||
|
|
||||||
extern crate failure;
|
extern crate failure;
|
||||||
|
|
||||||
use std::path::PathBuf;
|
use rust_bert::pipelines::question_answering::{squad_processor, QuestionAnsweringModel};
|
||||||
use std::env;
|
use std::env;
|
||||||
use rust_bert::pipelines::question_answering::{QuestionAnsweringModel, squad_processor};
|
use std::path::PathBuf;
|
||||||
|
|
||||||
|
|
||||||
fn main() -> failure::Fallible<()> {
|
fn main() -> failure::Fallible<()> {
|
||||||
// Set-up Question Answering model
|
// Set-up Question Answering model
|
||||||
let qa_model = QuestionAnsweringModel::new(Default::default())?;
|
let qa_model = QuestionAnsweringModel::new(Default::default())?;
|
||||||
|
|
||||||
// Define input
|
// Define input
|
||||||
let mut squad_path = PathBuf::from(env::var("squad_dataset")
|
let mut squad_path = PathBuf::from(env::var("squad_dataset")
|
||||||
.expect("Please set the \"squad_dataset\" environment variable pointing to the SQuAD dataset folder"));
|
.expect("Please set the \"squad_dataset\" environment variable pointing to the SQuAD dataset folder"));
|
||||||
squad_path.push("dev-v2.0.json");
|
squad_path.push("dev-v2.0.json");
|
||||||
let qa_inputs = squad_processor(squad_path);
|
let qa_inputs = squad_processor(squad_path);
|
||||||
|
|
||||||
// Get answer
|
// Get answer
|
||||||
let answers = qa_model.predict(&qa_inputs, 1, 64);
|
let answers = qa_model.predict(&qa_inputs, 1, 64);
|
||||||
println!("Sample answer: {:?}", answers.first().unwrap());
|
println!("Sample answer: {:?}", answers.first().unwrap());
|
||||||
println!("{}", answers.len());
|
println!("{}", answers.len());
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -10,35 +10,42 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
extern crate failure;
|
|
||||||
extern crate dirs;
|
extern crate dirs;
|
||||||
|
extern crate failure;
|
||||||
|
|
||||||
use std::path::PathBuf;
|
use rust_bert::pipelines::sentiment::{ss2_processor, SentimentModel};
|
||||||
use rust_bert::pipelines::sentiment::{SentimentModel, ss2_processor};
|
|
||||||
use std::env;
|
use std::env;
|
||||||
|
use std::path::PathBuf;
|
||||||
|
|
||||||
fn main() -> failure::Fallible<()> {
|
fn main() -> failure::Fallible<()> {
|
||||||
// Set-up classifier
|
// Set-up classifier
|
||||||
let sentiment_classifier = SentimentModel::new(Default::default())?;
|
let sentiment_classifier = SentimentModel::new(Default::default())?;
|
||||||
|
|
||||||
// Define input
|
// Define input
|
||||||
let mut sst2_path = PathBuf::from(env::var("SST2_PATH")
|
let mut sst2_path = PathBuf::from(env::var("SST2_PATH")
|
||||||
.expect("Please set the \"squad_dataset\" environment variable pointing to the SQuAD dataset folder"));
|
.expect("Please set the \"squad_dataset\" environment variable pointing to the SQuAD dataset folder"));
|
||||||
sst2_path.push("train.tsv");
|
sst2_path.push("train.tsv");
|
||||||
let inputs = ss2_processor(sst2_path).unwrap();
|
let inputs = ss2_processor(sst2_path).unwrap();
|
||||||
|
|
||||||
// Run model
|
// Run model
|
||||||
let batch_size = 64;
|
let batch_size = 64;
|
||||||
let mut output = vec!();
|
let mut output = vec![];
|
||||||
for batch in inputs.chunks(batch_size) {
|
for batch in inputs.chunks(batch_size) {
|
||||||
output.push(sentiment_classifier.predict(batch.iter().map(|v| v.as_str()).collect::<Vec<&str>>().as_slice()));
|
output.push(
|
||||||
|
sentiment_classifier.predict(
|
||||||
|
batch
|
||||||
|
.iter()
|
||||||
|
.map(|v| v.as_str())
|
||||||
|
.collect::<Vec<&str>>()
|
||||||
|
.as_slice(),
|
||||||
|
),
|
||||||
|
);
|
||||||
}
|
}
|
||||||
let mut flat_outputs = vec!();
|
let mut flat_outputs = vec![];
|
||||||
for batch_output in output.iter_mut() {
|
for batch_output in output.iter_mut() {
|
||||||
flat_outputs.append(batch_output);
|
flat_outputs.append(batch_output);
|
||||||
}
|
}
|
||||||
println!("{:?}", flat_outputs.len());
|
println!("{:?}", flat_outputs.len());
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -14,7 +14,6 @@ extern crate failure;
|
|||||||
|
|
||||||
use rust_bert::pipelines::summarization::SummarizationModel;
|
use rust_bert::pipelines::summarization::SummarizationModel;
|
||||||
|
|
||||||
|
|
||||||
fn main() -> failure::Fallible<()> {
|
fn main() -> failure::Fallible<()> {
|
||||||
let summarization_model = SummarizationModel::new(Default::default())?;
|
let summarization_model = SummarizationModel::new(Default::default())?;
|
||||||
|
|
||||||
@ -40,11 +39,11 @@ on K2-18b lasts 33 Earth days. According to The Guardian, astronomers were optim
|
|||||||
telescope — scheduled for launch in 2021 — and the European Space Agency's 2028 ARIEL program, could reveal more \
|
telescope — scheduled for launch in 2021 — and the European Space Agency's 2028 ARIEL program, could reveal more \
|
||||||
about exoplanets like K2-18b."];
|
about exoplanets like K2-18b."];
|
||||||
|
|
||||||
// Credits: WikiNews, CC BY 2.5 license (https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b)
|
// Credits: WikiNews, CC BY 2.5 license (https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b)
|
||||||
let _output = summarization_model.summarize(&input);
|
let _output = summarization_model.summarize(&input);
|
||||||
for sentence in _output {
|
for sentence in _output {
|
||||||
println!("{:?}", sentence);
|
println!("{:?}", sentence);
|
||||||
};
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -10,28 +10,36 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
use rust_bert::pipelines::token_classification::{TokenClassificationModel, TokenClassificationConfig, LabelAggregationOption};
|
use rust_bert::bert::{BertConfigResources, BertModelResources, BertVocabResources};
|
||||||
use rust_bert::resources::{Resource, RemoteResource};
|
|
||||||
use rust_bert::bert::{BertModelResources, BertVocabResources, BertConfigResources};
|
|
||||||
use rust_bert::pipelines::common::ModelType;
|
use rust_bert::pipelines::common::ModelType;
|
||||||
|
use rust_bert::pipelines::token_classification::{
|
||||||
|
LabelAggregationOption, TokenClassificationConfig, TokenClassificationModel,
|
||||||
|
};
|
||||||
|
use rust_bert::resources::{RemoteResource, Resource};
|
||||||
|
|
||||||
fn main() -> failure::Fallible<()> {
|
fn main() -> failure::Fallible<()> {
|
||||||
|
// Load a configuration
|
||||||
// Load a configuration
|
let config = TokenClassificationConfig::new(
|
||||||
let config = TokenClassificationConfig::new(ModelType::Bert,
|
ModelType::Bert,
|
||||||
Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT_NER)),
|
Resource::Remote(RemoteResource::from_pretrained(
|
||||||
Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT_NER)),
|
BertModelResources::BERT_NER,
|
||||||
Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT_NER)),
|
)),
|
||||||
None, //merges resource only relevant with ModelType::Roberta
|
Resource::Remote(RemoteResource::from_pretrained(
|
||||||
false, //lowercase
|
BertConfigResources::BERT_NER,
|
||||||
LabelAggregationOption::Mode,
|
)),
|
||||||
|
Resource::Remote(RemoteResource::from_pretrained(
|
||||||
|
BertVocabResources::BERT_NER,
|
||||||
|
)),
|
||||||
|
None, //merges resource only relevant with ModelType::Roberta
|
||||||
|
false, //lowercase
|
||||||
|
LabelAggregationOption::Mode,
|
||||||
);
|
);
|
||||||
|
|
||||||
// Create the model
|
// Create the model
|
||||||
let token_classification_model = TokenClassificationModel::new(config)?;
|
let token_classification_model = TokenClassificationModel::new(config)?;
|
||||||
let input = [
|
let input = [
|
||||||
"My name is Amélie. I live in Москва.",
|
"My name is Amélie. I live in Москва.",
|
||||||
"Chongqing is a city in China."
|
"Chongqing is a city in China.",
|
||||||
];
|
];
|
||||||
let token_outputs = token_classification_model.predict(&input, true, false); //ignore_first_label = true (only returns the NER parts, ignoring first label O)
|
let token_outputs = token_classification_model.predict(&input, true, false); //ignore_first_label = true (only returns the NER parts, ignoring first label O)
|
||||||
|
|
||||||
@ -40,4 +48,4 @@ fn main() -> failure::Fallible<()> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -13,12 +13,12 @@
|
|||||||
|
|
||||||
extern crate failure;
|
extern crate failure;
|
||||||
|
|
||||||
use rust_bert::pipelines::translation::{TranslationConfig, TranslationModel, Language};
|
use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
|
||||||
use tch::Device;
|
use tch::Device;
|
||||||
|
|
||||||
fn main() -> failure::Fallible<()> {
|
fn main() -> failure::Fallible<()> {
|
||||||
|
let translation_config =
|
||||||
let translation_config = TranslationConfig::new(Language::EnglishToGerman, Device::cuda_if_available());
|
TranslationConfig::new(Language::EnglishToGerman, Device::cuda_if_available());
|
||||||
let model = TranslationModel::new(translation_config)?;
|
let model = TranslationModel::new(translation_config)?;
|
||||||
|
|
||||||
let input_context_1 = "The quick brown fox jumps over the lazy dog";
|
let input_context_1 = "The quick brown fox jumps over the lazy dog";
|
||||||
@ -30,4 +30,4 @@ fn main() -> failure::Fallible<()> {
|
|||||||
println!("{}", sentence);
|
println!("{}", sentence);
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
1
rustfmt.toml
Normal file
1
rustfmt.toml
Normal file
@ -0,0 +1 @@
|
|||||||
|
format_code_in_doc_comments = true
|
@ -11,16 +11,15 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
use std::collections::HashMap;
|
|
||||||
use crate::Config;
|
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
use crate::albert::embeddings::AlbertEmbeddings;
|
use crate::albert::embeddings::AlbertEmbeddings;
|
||||||
use crate::albert::encoder::AlbertTransformer;
|
use crate::albert::encoder::AlbertTransformer;
|
||||||
use tch::{nn, Tensor, Kind};
|
use crate::common::activations::{_gelu, _gelu_new, _mish, _relu, _tanh};
|
||||||
use crate::common::activations::{_tanh, _gelu_new, _gelu, _relu, _mish};
|
|
||||||
use tch::nn::Module;
|
|
||||||
use crate::common::dropout::Dropout;
|
use crate::common::dropout::Dropout;
|
||||||
|
use crate::Config;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use tch::nn::Module;
|
||||||
|
use tch::{nn, Kind, Tensor};
|
||||||
|
|
||||||
/// # ALBERT Pretrained model weight files
|
/// # ALBERT Pretrained model weight files
|
||||||
pub struct AlbertModelResources;
|
pub struct AlbertModelResources;
|
||||||
@ -33,20 +32,28 @@ pub struct AlbertVocabResources;
|
|||||||
|
|
||||||
impl AlbertModelResources {
|
impl AlbertModelResources {
|
||||||
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/ALBERT. Modified with conversion to C-array format.
|
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/ALBERT. Modified with conversion to C-array format.
|
||||||
pub const ALBERT_BASE_V2: (&'static str, &'static str) = ("albert-base-v2/model.ot", "https://cdn.huggingface.co/albert-base-v2/rust_model.ot");
|
pub const ALBERT_BASE_V2: (&'static str, &'static str) = (
|
||||||
|
"albert-base-v2/model.ot",
|
||||||
|
"https://cdn.huggingface.co/albert-base-v2/rust_model.ot",
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
impl AlbertConfigResources {
|
impl AlbertConfigResources {
|
||||||
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/ALBERT. Modified with conversion to C-array format.
|
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/ALBERT. Modified with conversion to C-array format.
|
||||||
pub const ALBERT_BASE_V2: (&'static str, &'static str) = ("albert-base-v2/config.json", "https://cdn.huggingface.co/albert-base-v2-config.json");
|
pub const ALBERT_BASE_V2: (&'static str, &'static str) = (
|
||||||
|
"albert-base-v2/config.json",
|
||||||
|
"https://cdn.huggingface.co/albert-base-v2-config.json",
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
impl AlbertVocabResources {
|
impl AlbertVocabResources {
|
||||||
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/ALBERT. Modified with conversion to C-array format.
|
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/ALBERT. Modified with conversion to C-array format.
|
||||||
pub const ALBERT_BASE_V2: (&'static str, &'static str) = ("albert-base-v2/spiece.model", "https://cdn.huggingface.co/albert-base-v2-spiece.model");
|
pub const ALBERT_BASE_V2: (&'static str, &'static str) = (
|
||||||
|
"albert-base-v2/spiece.model",
|
||||||
|
"https://cdn.huggingface.co/albert-base-v2-spiece.model",
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
#[allow(non_camel_case_types)]
|
#[allow(non_camel_case_types)]
|
||||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||||
/// # Activation function used in the attention layer and masked language model head
|
/// # Activation function used in the attention layer and masked language model head
|
||||||
@ -61,7 +68,6 @@ pub enum Activation {
|
|||||||
mish,
|
mish,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
/// # ALBERT model configuration
|
/// # ALBERT model configuration
|
||||||
/// Defines the ALBERT model architecture (e.g. number of layers, hidden layer size, label mapping...)
|
/// Defines the ALBERT model architecture (e.g. number of layers, hidden layer size, label mapping...)
|
||||||
@ -123,10 +129,10 @@ impl AlbertModel {
|
|||||||
/// # Example
|
/// # Example
|
||||||
///
|
///
|
||||||
/// ```no_run
|
/// ```no_run
|
||||||
/// use tch::{nn, Device};
|
/// use rust_bert::albert::{AlbertConfig, AlbertModel};
|
||||||
/// use rust_bert::Config;
|
/// use rust_bert::Config;
|
||||||
/// use std::path::Path;
|
/// use std::path::Path;
|
||||||
/// use rust_bert::albert::{AlbertConfig, AlbertModel};
|
/// use tch::{nn, Device};
|
||||||
///
|
///
|
||||||
/// let config_path = Path::new("path/to/config.json");
|
/// let config_path = Path::new("path/to/config.json");
|
||||||
/// let device = Device::Cpu;
|
/// let device = Device::Cpu;
|
||||||
@ -134,14 +140,23 @@ impl AlbertModel {
|
|||||||
/// let config = AlbertConfig::from_file(config_path);
|
/// let config = AlbertConfig::from_file(config_path);
|
||||||
/// let albert: AlbertModel = AlbertModel::new(&(&p.root() / "albert"), &config);
|
/// let albert: AlbertModel = AlbertModel::new(&(&p.root() / "albert"), &config);
|
||||||
/// ```
|
/// ```
|
||||||
///
|
|
||||||
pub fn new(p: &nn::Path, config: &AlbertConfig) -> AlbertModel {
|
pub fn new(p: &nn::Path, config: &AlbertConfig) -> AlbertModel {
|
||||||
let embeddings = AlbertEmbeddings::new(&(p / "embeddings"), config);
|
let embeddings = AlbertEmbeddings::new(&(p / "embeddings"), config);
|
||||||
let encoder = AlbertTransformer::new(&(p / "encoder"), config);
|
let encoder = AlbertTransformer::new(&(p / "encoder"), config);
|
||||||
let pooler = nn::linear(&(p / "pooler"), config.hidden_size, config.hidden_size, Default::default());
|
let pooler = nn::linear(
|
||||||
|
&(p / "pooler"),
|
||||||
|
config.hidden_size,
|
||||||
|
config.hidden_size,
|
||||||
|
Default::default(),
|
||||||
|
);
|
||||||
let pooler_activation = Box::new(_tanh);
|
let pooler_activation = Box::new(_tanh);
|
||||||
|
|
||||||
AlbertModel { embeddings, encoder, pooler, pooler_activation }
|
AlbertModel {
|
||||||
|
embeddings,
|
||||||
|
encoder,
|
||||||
|
pooler,
|
||||||
|
pooler_activation,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Forward pass through the model
|
/// Forward pass through the model
|
||||||
@ -165,75 +180,103 @@ impl AlbertModel {
|
|||||||
/// # Example
|
/// # Example
|
||||||
///
|
///
|
||||||
/// ```no_run
|
/// ```no_run
|
||||||
///# use tch::{nn, Device, Tensor, no_grad};
|
/// # use tch::{nn, Device, Tensor, no_grad};
|
||||||
///# use rust_bert::Config;
|
/// # use rust_bert::Config;
|
||||||
///# use std::path::Path;
|
/// # use std::path::Path;
|
||||||
///# use tch::kind::Kind::Int64;
|
/// # use tch::kind::Kind::Int64;
|
||||||
/// use rust_bert::albert::{AlbertConfig, AlbertModel};
|
/// use rust_bert::albert::{AlbertConfig, AlbertModel};
|
||||||
///# let config_path = Path::new("path/to/config.json");
|
/// # let config_path = Path::new("path/to/config.json");
|
||||||
///# let device = Device::Cpu;
|
/// # let device = Device::Cpu;
|
||||||
///# let vs = nn::VarStore::new(device);
|
/// # let vs = nn::VarStore::new(device);
|
||||||
///# let config = AlbertConfig::from_file(config_path);
|
/// # let config = AlbertConfig::from_file(config_path);
|
||||||
///# let albert_model: AlbertModel = AlbertModel::new(&vs.root(), &config);
|
/// # let albert_model: AlbertModel = AlbertModel::new(&vs.root(), &config);
|
||||||
/// let (batch_size, sequence_length) = (64, 128);
|
/// let (batch_size, sequence_length) = (64, 128);
|
||||||
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
|
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
|
||||||
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
||||||
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
||||||
/// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true);
|
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
|
||||||
///
|
/// .expand(&[batch_size, sequence_length], true);
|
||||||
/// let (output, pooled_output, all_hidden_states, all_attentions) = no_grad(|| {
|
|
||||||
/// albert_model
|
|
||||||
/// .forward_t(Some(input_tensor),
|
|
||||||
/// Some(mask),
|
|
||||||
/// Some(token_type_ids),
|
|
||||||
/// Some(position_ids),
|
|
||||||
/// None,
|
|
||||||
/// false).unwrap()
|
|
||||||
/// });
|
|
||||||
///
|
///
|
||||||
|
/// let (output, pooled_output, all_hidden_states, all_attentions) = no_grad(|| {
|
||||||
|
/// albert_model
|
||||||
|
/// .forward_t(
|
||||||
|
/// Some(input_tensor),
|
||||||
|
/// Some(mask),
|
||||||
|
/// Some(token_type_ids),
|
||||||
|
/// Some(position_ids),
|
||||||
|
/// None,
|
||||||
|
/// false,
|
||||||
|
/// )
|
||||||
|
/// .unwrap()
|
||||||
|
/// });
|
||||||
/// ```
|
/// ```
|
||||||
///
|
pub fn forward_t(
|
||||||
pub fn forward_t(&self,
|
&self,
|
||||||
input_ids: Option<Tensor>,
|
input_ids: Option<Tensor>,
|
||||||
mask: Option<Tensor>,
|
mask: Option<Tensor>,
|
||||||
token_type_ids: Option<Tensor>,
|
token_type_ids: Option<Tensor>,
|
||||||
position_ids: Option<Tensor>,
|
position_ids: Option<Tensor>,
|
||||||
input_embeds: Option<Tensor>,
|
input_embeds: Option<Tensor>,
|
||||||
train: bool)
|
train: bool,
|
||||||
-> Result<(Tensor, Tensor, Option<Vec<Tensor>>, Option<Vec<Vec<Tensor>>>), &'static str> {
|
) -> Result<
|
||||||
|
(
|
||||||
|
Tensor,
|
||||||
|
Tensor,
|
||||||
|
Option<Vec<Tensor>>,
|
||||||
|
Option<Vec<Vec<Tensor>>>,
|
||||||
|
),
|
||||||
|
&'static str,
|
||||||
|
> {
|
||||||
let (input_shape, device) = match &input_ids {
|
let (input_shape, device) = match &input_ids {
|
||||||
Some(input_value) => match &input_embeds {
|
Some(input_value) => match &input_embeds {
|
||||||
Some(_) => { return Err("Only one of input ids or input embeddings may be set"); }
|
Some(_) => {
|
||||||
None => (input_value.size(), input_value.device())
|
return Err("Only one of input ids or input embeddings may be set");
|
||||||
}
|
}
|
||||||
|
None => (input_value.size(), input_value.device()),
|
||||||
|
},
|
||||||
None => match &input_embeds {
|
None => match &input_embeds {
|
||||||
Some(embeds) => (vec!(embeds.size()[0], embeds.size()[1]), embeds.device()),
|
Some(embeds) => (vec![embeds.size()[0], embeds.size()[1]], embeds.device()),
|
||||||
None => { return Err("At least one of input ids or input embeddings must be set"); }
|
None => {
|
||||||
}
|
return Err("At least one of input ids or input embeddings must be set");
|
||||||
|
}
|
||||||
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
let mask = match mask {
|
let mask = match mask {
|
||||||
Some(value) => value,
|
Some(value) => value,
|
||||||
None => Tensor::ones(&input_shape, (Kind::Int64, device))
|
None => Tensor::ones(&input_shape, (Kind::Int64, device)),
|
||||||
};
|
};
|
||||||
|
|
||||||
let extended_attention_mask = mask.unsqueeze(1).unsqueeze(2);
|
let extended_attention_mask = mask.unsqueeze(1).unsqueeze(2);
|
||||||
let extended_attention_mask: Tensor = (extended_attention_mask.ones_like() - extended_attention_mask) * -10000.0;
|
let extended_attention_mask: Tensor =
|
||||||
|
(extended_attention_mask.ones_like() - extended_attention_mask) * -10000.0;
|
||||||
|
|
||||||
let embedding_output = match self.embeddings.forward_t(input_ids, token_type_ids, position_ids, input_embeds, train) {
|
let embedding_output = match self.embeddings.forward_t(
|
||||||
|
input_ids,
|
||||||
|
token_type_ids,
|
||||||
|
position_ids,
|
||||||
|
input_embeds,
|
||||||
|
train,
|
||||||
|
) {
|
||||||
Ok(value) => value,
|
Ok(value) => value,
|
||||||
Err(e) => { return Err(e); }
|
Err(e) => {
|
||||||
|
return Err(e);
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let (hidden_state, all_hidden_states, all_attentions) =
|
let (hidden_state, all_hidden_states, all_attentions) =
|
||||||
self.encoder.forward_t(&embedding_output,
|
self.encoder
|
||||||
Some(extended_attention_mask),
|
.forward_t(&embedding_output, Some(extended_attention_mask), train);
|
||||||
train);
|
|
||||||
|
|
||||||
let pooled_output = self.pooler.forward(&hidden_state.select(1, 0));
|
let pooled_output = self.pooler.forward(&hidden_state.select(1, 0));
|
||||||
let pooled_output = (self.pooler_activation)(&pooled_output);
|
let pooled_output = (self.pooler_activation)(&pooled_output);
|
||||||
|
|
||||||
Ok((hidden_state, pooled_output, all_hidden_states, all_attentions))
|
Ok((
|
||||||
|
hidden_state,
|
||||||
|
pooled_output,
|
||||||
|
all_hidden_states,
|
||||||
|
all_attentions,
|
||||||
|
))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -248,21 +291,43 @@ impl AlbertMLMHead {
|
|||||||
pub fn new(p: &nn::Path, config: &AlbertConfig) -> AlbertMLMHead {
|
pub fn new(p: &nn::Path, config: &AlbertConfig) -> AlbertMLMHead {
|
||||||
let layer_norm_eps = match config.layer_norm_eps {
|
let layer_norm_eps = match config.layer_norm_eps {
|
||||||
Some(value) => value,
|
Some(value) => value,
|
||||||
None => 1e-12
|
None => 1e-12,
|
||||||
};
|
};
|
||||||
let layer_norm_config = nn::LayerNormConfig { eps: layer_norm_eps, ..Default::default() };
|
let layer_norm_config = nn::LayerNormConfig {
|
||||||
let layer_norm = nn::layer_norm(&(p / "LayerNorm"), vec![config.embedding_size], layer_norm_config);
|
eps: layer_norm_eps,
|
||||||
let dense = nn::linear(&(p / "dense"), config.hidden_size, config.embedding_size, Default::default());
|
..Default::default()
|
||||||
let decoder = nn::linear(&(p / "decoder"), config.embedding_size, config.vocab_size, Default::default());
|
};
|
||||||
|
let layer_norm = nn::layer_norm(
|
||||||
|
&(p / "LayerNorm"),
|
||||||
|
vec![config.embedding_size],
|
||||||
|
layer_norm_config,
|
||||||
|
);
|
||||||
|
let dense = nn::linear(
|
||||||
|
&(p / "dense"),
|
||||||
|
config.hidden_size,
|
||||||
|
config.embedding_size,
|
||||||
|
Default::default(),
|
||||||
|
);
|
||||||
|
let decoder = nn::linear(
|
||||||
|
&(p / "decoder"),
|
||||||
|
config.embedding_size,
|
||||||
|
config.vocab_size,
|
||||||
|
Default::default(),
|
||||||
|
);
|
||||||
|
|
||||||
let activation = Box::new(match &config.hidden_act {
|
let activation = Box::new(match &config.hidden_act {
|
||||||
Activation::gelu_new => _gelu_new,
|
Activation::gelu_new => _gelu_new,
|
||||||
Activation::gelu => _gelu,
|
Activation::gelu => _gelu,
|
||||||
Activation::relu => _relu,
|
Activation::relu => _relu,
|
||||||
Activation::mish => _mish
|
Activation::mish => _mish,
|
||||||
});
|
});
|
||||||
|
|
||||||
AlbertMLMHead { layer_norm, dense, decoder, activation }
|
AlbertMLMHead {
|
||||||
|
layer_norm,
|
||||||
|
dense,
|
||||||
|
decoder,
|
||||||
|
activation,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn forward(&self, hidden_states: &Tensor) -> Tensor {
|
pub fn forward(&self, hidden_states: &Tensor) -> Tensor {
|
||||||
@ -292,10 +357,10 @@ impl AlbertForMaskedLM {
|
|||||||
/// # Example
|
/// # Example
|
||||||
///
|
///
|
||||||
/// ```no_run
|
/// ```no_run
|
||||||
/// use tch::{nn, Device};
|
/// use rust_bert::albert::{AlbertConfig, AlbertForMaskedLM};
|
||||||
/// use rust_bert::Config;
|
/// use rust_bert::Config;
|
||||||
/// use std::path::Path;
|
/// use std::path::Path;
|
||||||
/// use rust_bert::albert::{AlbertConfig, AlbertForMaskedLM};
|
/// use tch::{nn, Device};
|
||||||
///
|
///
|
||||||
/// let config_path = Path::new("path/to/config.json");
|
/// let config_path = Path::new("path/to/config.json");
|
||||||
/// let device = Device::Cpu;
|
/// let device = Device::Cpu;
|
||||||
@ -303,12 +368,14 @@ impl AlbertForMaskedLM {
|
|||||||
/// let config = AlbertConfig::from_file(config_path);
|
/// let config = AlbertConfig::from_file(config_path);
|
||||||
/// let albert: AlbertForMaskedLM = AlbertForMaskedLM::new(&p.root(), &config);
|
/// let albert: AlbertForMaskedLM = AlbertForMaskedLM::new(&p.root(), &config);
|
||||||
/// ```
|
/// ```
|
||||||
///
|
|
||||||
pub fn new(p: &nn::Path, config: &AlbertConfig) -> AlbertForMaskedLM {
|
pub fn new(p: &nn::Path, config: &AlbertConfig) -> AlbertForMaskedLM {
|
||||||
let albert = AlbertModel::new(&(p / "albert"), config);
|
let albert = AlbertModel::new(&(p / "albert"), config);
|
||||||
let predictions = AlbertMLMHead::new(&(p / "predictions"), config);
|
let predictions = AlbertMLMHead::new(&(p / "predictions"), config);
|
||||||
|
|
||||||
AlbertForMaskedLM { albert, predictions }
|
AlbertForMaskedLM {
|
||||||
|
albert,
|
||||||
|
predictions,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Forward pass through the model
|
/// Forward pass through the model
|
||||||
@ -331,42 +398,54 @@ impl AlbertForMaskedLM {
|
|||||||
/// # Example
|
/// # Example
|
||||||
///
|
///
|
||||||
/// ```no_run
|
/// ```no_run
|
||||||
///# use tch::{nn, Device, Tensor, no_grad};
|
/// # use tch::{nn, Device, Tensor, no_grad};
|
||||||
///# use rust_bert::Config;
|
/// # use rust_bert::Config;
|
||||||
///# use std::path::Path;
|
/// # use std::path::Path;
|
||||||
///# use tch::kind::Kind::Int64;
|
/// # use tch::kind::Kind::Int64;
|
||||||
/// use rust_bert::albert::{AlbertConfig, AlbertForMaskedLM};
|
/// use rust_bert::albert::{AlbertConfig, AlbertForMaskedLM};
|
||||||
///# let config_path = Path::new("path/to/config.json");
|
/// # let config_path = Path::new("path/to/config.json");
|
||||||
///# let device = Device::Cpu;
|
/// # let device = Device::Cpu;
|
||||||
///# let vs = nn::VarStore::new(device);
|
/// # let vs = nn::VarStore::new(device);
|
||||||
///# let config = AlbertConfig::from_file(config_path);
|
/// # let config = AlbertConfig::from_file(config_path);
|
||||||
///# let albert_model: AlbertForMaskedLM = AlbertForMaskedLM::new(&vs.root(), &config);
|
/// # let albert_model: AlbertForMaskedLM = AlbertForMaskedLM::new(&vs.root(), &config);
|
||||||
/// let (batch_size, sequence_length) = (64, 128);
|
/// let (batch_size, sequence_length) = (64, 128);
|
||||||
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
|
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
|
||||||
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
||||||
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
||||||
/// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true);
|
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
|
||||||
///
|
/// .expand(&[batch_size, sequence_length], true);
|
||||||
/// let (output, all_hidden_states, all_attentions) = no_grad(|| {
|
|
||||||
/// albert_model
|
|
||||||
/// .forward_t(Some(input_tensor),
|
|
||||||
/// Some(mask),
|
|
||||||
/// Some(token_type_ids),
|
|
||||||
/// Some(position_ids),
|
|
||||||
/// None,
|
|
||||||
/// false)
|
|
||||||
/// });
|
|
||||||
///
|
///
|
||||||
|
/// let (output, all_hidden_states, all_attentions) = no_grad(|| {
|
||||||
|
/// albert_model.forward_t(
|
||||||
|
/// Some(input_tensor),
|
||||||
|
/// Some(mask),
|
||||||
|
/// Some(token_type_ids),
|
||||||
|
/// Some(position_ids),
|
||||||
|
/// None,
|
||||||
|
/// false,
|
||||||
|
/// )
|
||||||
|
/// });
|
||||||
/// ```
|
/// ```
|
||||||
///
|
pub fn forward_t(
|
||||||
pub fn forward_t(&self,
|
&self,
|
||||||
input_ids: Option<Tensor>,
|
input_ids: Option<Tensor>,
|
||||||
mask: Option<Tensor>,
|
mask: Option<Tensor>,
|
||||||
token_type_ids: Option<Tensor>,
|
token_type_ids: Option<Tensor>,
|
||||||
position_ids: Option<Tensor>,
|
position_ids: Option<Tensor>,
|
||||||
input_embeds: Option<Tensor>,
|
input_embeds: Option<Tensor>,
|
||||||
train: bool) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Vec<Tensor>>>) {
|
train: bool,
|
||||||
let (hidden_state, _, all_hidden_states, all_attentions) = self.albert.forward_t(input_ids, mask, token_type_ids, position_ids, input_embeds, train).unwrap();
|
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Vec<Tensor>>>) {
|
||||||
|
let (hidden_state, _, all_hidden_states, all_attentions) = self
|
||||||
|
.albert
|
||||||
|
.forward_t(
|
||||||
|
input_ids,
|
||||||
|
mask,
|
||||||
|
token_type_ids,
|
||||||
|
position_ids,
|
||||||
|
input_embeds,
|
||||||
|
train,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
let prediction_scores = self.predictions.forward(&hidden_state);
|
let prediction_scores = self.predictions.forward(&hidden_state);
|
||||||
(prediction_scores, all_hidden_states, all_attentions)
|
(prediction_scores, all_hidden_states, all_attentions)
|
||||||
}
|
}
|
||||||
@ -395,29 +474,42 @@ impl AlbertForSequenceClassification {
|
|||||||
/// # Example
|
/// # Example
|
||||||
///
|
///
|
||||||
/// ```no_run
|
/// ```no_run
|
||||||
/// use tch::{nn, Device};
|
/// use rust_bert::albert::{AlbertConfig, AlbertForSequenceClassification};
|
||||||
/// use rust_bert::Config;
|
/// use rust_bert::Config;
|
||||||
/// use std::path::Path;
|
/// use std::path::Path;
|
||||||
/// use rust_bert::albert::{AlbertConfig, AlbertForSequenceClassification};
|
/// use tch::{nn, Device};
|
||||||
///
|
///
|
||||||
/// let config_path = Path::new("path/to/config.json");
|
/// let config_path = Path::new("path/to/config.json");
|
||||||
/// let device = Device::Cpu;
|
/// let device = Device::Cpu;
|
||||||
/// let p = nn::VarStore::new(device);
|
/// let p = nn::VarStore::new(device);
|
||||||
/// let config = AlbertConfig::from_file(config_path);
|
/// let config = AlbertConfig::from_file(config_path);
|
||||||
/// let albert: AlbertForSequenceClassification = AlbertForSequenceClassification::new(&p.root(), &config);
|
/// let albert: AlbertForSequenceClassification =
|
||||||
|
/// AlbertForSequenceClassification::new(&p.root(), &config);
|
||||||
/// ```
|
/// ```
|
||||||
///
|
|
||||||
pub fn new(p: &nn::Path, config: &AlbertConfig) -> AlbertForSequenceClassification {
|
pub fn new(p: &nn::Path, config: &AlbertConfig) -> AlbertForSequenceClassification {
|
||||||
let albert = AlbertModel::new(&(p / "albert"), config);
|
let albert = AlbertModel::new(&(p / "albert"), config);
|
||||||
let classifier_dropout_prob = match config.classifier_dropout_prob {
|
let classifier_dropout_prob = match config.classifier_dropout_prob {
|
||||||
Some(value) => value,
|
Some(value) => value,
|
||||||
None => 0.1
|
None => 0.1,
|
||||||
};
|
};
|
||||||
let dropout = Dropout::new(classifier_dropout_prob);
|
let dropout = Dropout::new(classifier_dropout_prob);
|
||||||
let num_labels = config.id2label.as_ref().expect("num_labels not provided in configuration").len() as i64;
|
let num_labels = config
|
||||||
let classifier = nn::linear(&(p / "classifier"), config.hidden_size, num_labels, Default::default());
|
.id2label
|
||||||
|
.as_ref()
|
||||||
|
.expect("num_labels not provided in configuration")
|
||||||
|
.len() as i64;
|
||||||
|
let classifier = nn::linear(
|
||||||
|
&(p / "classifier"),
|
||||||
|
config.hidden_size,
|
||||||
|
num_labels,
|
||||||
|
Default::default(),
|
||||||
|
);
|
||||||
|
|
||||||
AlbertForSequenceClassification { albert, dropout, classifier }
|
AlbertForSequenceClassification {
|
||||||
|
albert,
|
||||||
|
dropout,
|
||||||
|
classifier,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Forward pass through the model
|
/// Forward pass through the model
|
||||||
@ -440,16 +532,16 @@ impl AlbertForSequenceClassification {
|
|||||||
/// # Example
|
/// # Example
|
||||||
///
|
///
|
||||||
/// ```no_run
|
/// ```no_run
|
||||||
///# use tch::{nn, Device, Tensor, no_grad};
|
/// # use tch::{nn, Device, Tensor, no_grad};
|
||||||
///# use rust_bert::Config;
|
/// # use rust_bert::Config;
|
||||||
///# use std::path::Path;
|
/// # use std::path::Path;
|
||||||
///# use tch::kind::Kind::Int64;
|
/// # use tch::kind::Kind::Int64;
|
||||||
/// use rust_bert::albert::{AlbertConfig, AlbertForSequenceClassification};
|
/// use rust_bert::albert::{AlbertConfig, AlbertForSequenceClassification};
|
||||||
///# let config_path = Path::new("path/to/config.json");
|
/// # let config_path = Path::new("path/to/config.json");
|
||||||
///# let device = Device::Cpu;
|
/// # let device = Device::Cpu;
|
||||||
///# let vs = nn::VarStore::new(device);
|
/// # let vs = nn::VarStore::new(device);
|
||||||
///# let config = AlbertConfig::from_file(config_path);
|
/// # let config = AlbertConfig::from_file(config_path);
|
||||||
///# let albert_model: AlbertForSequenceClassification = AlbertForSequenceClassification::new(&vs.root(), &config);
|
/// # let albert_model: AlbertForSequenceClassification = AlbertForSequenceClassification::new(&vs.root(), &config);
|
||||||
/// let (batch_size, sequence_length) = (64, 128);
|
/// let (batch_size, sequence_length) = (64, 128);
|
||||||
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
|
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
|
||||||
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
||||||
@ -465,18 +557,30 @@ impl AlbertForSequenceClassification {
|
|||||||
/// None,
|
/// None,
|
||||||
/// false)
|
/// false)
|
||||||
/// });
|
/// });
|
||||||
///
|
|
||||||
/// ```
|
/// ```
|
||||||
///
|
pub fn forward_t(
|
||||||
pub fn forward_t(&self,
|
&self,
|
||||||
input_ids: Option<Tensor>,
|
input_ids: Option<Tensor>,
|
||||||
mask: Option<Tensor>,
|
mask: Option<Tensor>,
|
||||||
token_type_ids: Option<Tensor>,
|
token_type_ids: Option<Tensor>,
|
||||||
position_ids: Option<Tensor>,
|
position_ids: Option<Tensor>,
|
||||||
input_embeds: Option<Tensor>,
|
input_embeds: Option<Tensor>,
|
||||||
train: bool) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Vec<Tensor>>>) {
|
train: bool,
|
||||||
let (_, pooled_output, all_hidden_states, all_attentions) = self.albert.forward_t(input_ids, mask, token_type_ids, position_ids, input_embeds, train).unwrap();
|
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Vec<Tensor>>>) {
|
||||||
let logits = pooled_output.apply_t(&self.dropout, train).apply(&self.classifier);
|
let (_, pooled_output, all_hidden_states, all_attentions) = self
|
||||||
|
.albert
|
||||||
|
.forward_t(
|
||||||
|
input_ids,
|
||||||
|
mask,
|
||||||
|
token_type_ids,
|
||||||
|
position_ids,
|
||||||
|
input_embeds,
|
||||||
|
train,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
let logits = pooled_output
|
||||||
|
.apply_t(&self.dropout, train)
|
||||||
|
.apply(&self.classifier);
|
||||||
(logits, all_hidden_states, all_attentions)
|
(logits, all_hidden_states, all_attentions)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -505,25 +609,38 @@ impl AlbertForTokenClassification {
|
|||||||
/// # Example
|
/// # Example
|
||||||
///
|
///
|
||||||
/// ```no_run
|
/// ```no_run
|
||||||
/// use tch::{nn, Device};
|
/// use rust_bert::albert::{AlbertConfig, AlbertForTokenClassification};
|
||||||
/// use rust_bert::Config;
|
/// use rust_bert::Config;
|
||||||
/// use std::path::Path;
|
/// use std::path::Path;
|
||||||
/// use rust_bert::albert::{AlbertConfig, AlbertForTokenClassification};
|
/// use tch::{nn, Device};
|
||||||
///
|
///
|
||||||
/// let config_path = Path::new("path/to/config.json");
|
/// let config_path = Path::new("path/to/config.json");
|
||||||
/// let device = Device::Cpu;
|
/// let device = Device::Cpu;
|
||||||
/// let p = nn::VarStore::new(device);
|
/// let p = nn::VarStore::new(device);
|
||||||
/// let config = AlbertConfig::from_file(config_path);
|
/// let config = AlbertConfig::from_file(config_path);
|
||||||
/// let albert: AlbertForTokenClassification = AlbertForTokenClassification::new(&p.root(), &config);
|
/// let albert: AlbertForTokenClassification =
|
||||||
|
/// AlbertForTokenClassification::new(&p.root(), &config);
|
||||||
/// ```
|
/// ```
|
||||||
///
|
|
||||||
pub fn new(p: &nn::Path, config: &AlbertConfig) -> AlbertForTokenClassification {
|
pub fn new(p: &nn::Path, config: &AlbertConfig) -> AlbertForTokenClassification {
|
||||||
let albert = AlbertModel::new(&(p / "albert"), config);
|
let albert = AlbertModel::new(&(p / "albert"), config);
|
||||||
let dropout = Dropout::new(config.hidden_dropout_prob);
|
let dropout = Dropout::new(config.hidden_dropout_prob);
|
||||||
let num_labels = config.id2label.as_ref().expect("num_labels not provided in configuration").len() as i64;
|
let num_labels = config
|
||||||
let classifier = nn::linear(&(p / "classifier"), config.hidden_size, num_labels, Default::default());
|
.id2label
|
||||||
|
.as_ref()
|
||||||
|
.expect("num_labels not provided in configuration")
|
||||||
|
.len() as i64;
|
||||||
|
let classifier = nn::linear(
|
||||||
|
&(p / "classifier"),
|
||||||
|
config.hidden_size,
|
||||||
|
num_labels,
|
||||||
|
Default::default(),
|
||||||
|
);
|
||||||
|
|
||||||
AlbertForTokenClassification { albert, dropout, classifier }
|
AlbertForTokenClassification {
|
||||||
|
albert,
|
||||||
|
dropout,
|
||||||
|
classifier,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Forward pass through the model
|
/// Forward pass through the model
|
||||||
@ -546,16 +663,16 @@ impl AlbertForTokenClassification {
|
|||||||
/// # Example
|
/// # Example
|
||||||
///
|
///
|
||||||
/// ```no_run
|
/// ```no_run
|
||||||
///# use tch::{nn, Device, Tensor, no_grad};
|
/// # use tch::{nn, Device, Tensor, no_grad};
|
||||||
///# use rust_bert::Config;
|
/// # use rust_bert::Config;
|
||||||
///# use std::path::Path;
|
/// # use std::path::Path;
|
||||||
///# use tch::kind::Kind::Int64;
|
/// # use tch::kind::Kind::Int64;
|
||||||
/// use rust_bert::albert::{AlbertConfig, AlbertForTokenClassification};
|
/// use rust_bert::albert::{AlbertConfig, AlbertForTokenClassification};
|
||||||
///# let config_path = Path::new("path/to/config.json");
|
/// # let config_path = Path::new("path/to/config.json");
|
||||||
///# let device = Device::Cpu;
|
/// # let device = Device::Cpu;
|
||||||
///# let vs = nn::VarStore::new(device);
|
/// # let vs = nn::VarStore::new(device);
|
||||||
///# let config = AlbertConfig::from_file(config_path);
|
/// # let config = AlbertConfig::from_file(config_path);
|
||||||
///# let albert_model: AlbertForTokenClassification = AlbertForTokenClassification::new(&vs.root(), &config);
|
/// # let albert_model: AlbertForTokenClassification = AlbertForTokenClassification::new(&vs.root(), &config);
|
||||||
/// let (batch_size, sequence_length) = (64, 128);
|
/// let (batch_size, sequence_length) = (64, 128);
|
||||||
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
|
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
|
||||||
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
||||||
@ -571,18 +688,30 @@ impl AlbertForTokenClassification {
|
|||||||
/// None,
|
/// None,
|
||||||
/// false)
|
/// false)
|
||||||
/// });
|
/// });
|
||||||
///
|
|
||||||
/// ```
|
/// ```
|
||||||
///
|
pub fn forward_t(
|
||||||
pub fn forward_t(&self,
|
&self,
|
||||||
input_ids: Option<Tensor>,
|
input_ids: Option<Tensor>,
|
||||||
mask: Option<Tensor>,
|
mask: Option<Tensor>,
|
||||||
token_type_ids: Option<Tensor>,
|
token_type_ids: Option<Tensor>,
|
||||||
position_ids: Option<Tensor>,
|
position_ids: Option<Tensor>,
|
||||||
input_embeds: Option<Tensor>,
|
input_embeds: Option<Tensor>,
|
||||||
train: bool) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Vec<Tensor>>>) {
|
train: bool,
|
||||||
let (sequence_output, _, all_hidden_states, all_attentions) = self.albert.forward_t(input_ids, mask, token_type_ids, position_ids, input_embeds, train).unwrap();
|
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Vec<Tensor>>>) {
|
||||||
let logits = sequence_output.apply_t(&self.dropout, train).apply(&self.classifier);
|
let (sequence_output, _, all_hidden_states, all_attentions) = self
|
||||||
|
.albert
|
||||||
|
.forward_t(
|
||||||
|
input_ids,
|
||||||
|
mask,
|
||||||
|
token_type_ids,
|
||||||
|
position_ids,
|
||||||
|
input_embeds,
|
||||||
|
train,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
let logits = sequence_output
|
||||||
|
.apply_t(&self.dropout, train)
|
||||||
|
.apply(&self.classifier);
|
||||||
(logits, all_hidden_states, all_attentions)
|
(logits, all_hidden_states, all_attentions)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -610,10 +739,10 @@ impl AlbertForQuestionAnswering {
|
|||||||
/// # Example
|
/// # Example
|
||||||
///
|
///
|
||||||
/// ```no_run
|
/// ```no_run
|
||||||
/// use tch::{nn, Device};
|
/// use rust_bert::albert::{AlbertConfig, AlbertForQuestionAnswering};
|
||||||
/// use rust_bert::Config;
|
/// use rust_bert::Config;
|
||||||
/// use std::path::Path;
|
/// use std::path::Path;
|
||||||
/// use rust_bert::albert::{AlbertConfig, AlbertForQuestionAnswering};
|
/// use tch::{nn, Device};
|
||||||
///
|
///
|
||||||
/// let config_path = Path::new("path/to/config.json");
|
/// let config_path = Path::new("path/to/config.json");
|
||||||
/// let device = Device::Cpu;
|
/// let device = Device::Cpu;
|
||||||
@ -621,11 +750,15 @@ impl AlbertForQuestionAnswering {
|
|||||||
/// let config = AlbertConfig::from_file(config_path);
|
/// let config = AlbertConfig::from_file(config_path);
|
||||||
/// let albert: AlbertForQuestionAnswering = AlbertForQuestionAnswering::new(&p.root(), &config);
|
/// let albert: AlbertForQuestionAnswering = AlbertForQuestionAnswering::new(&p.root(), &config);
|
||||||
/// ```
|
/// ```
|
||||||
///
|
|
||||||
pub fn new(p: &nn::Path, config: &AlbertConfig) -> AlbertForQuestionAnswering {
|
pub fn new(p: &nn::Path, config: &AlbertConfig) -> AlbertForQuestionAnswering {
|
||||||
let albert = AlbertModel::new(&(p / "albert"), config);
|
let albert = AlbertModel::new(&(p / "albert"), config);
|
||||||
let num_labels = 2;
|
let num_labels = 2;
|
||||||
let qa_outputs = nn::linear(&(p / "qa_outputs"), config.hidden_size, num_labels, Default::default());
|
let qa_outputs = nn::linear(
|
||||||
|
&(p / "qa_outputs"),
|
||||||
|
config.hidden_size,
|
||||||
|
num_labels,
|
||||||
|
Default::default(),
|
||||||
|
);
|
||||||
|
|
||||||
AlbertForQuestionAnswering { albert, qa_outputs }
|
AlbertForQuestionAnswering { albert, qa_outputs }
|
||||||
}
|
}
|
||||||
@ -651,16 +784,16 @@ impl AlbertForQuestionAnswering {
|
|||||||
/// # Example
|
/// # Example
|
||||||
///
|
///
|
||||||
/// ```no_run
|
/// ```no_run
|
||||||
///# use tch::{nn, Device, Tensor, no_grad};
|
/// # use tch::{nn, Device, Tensor, no_grad};
|
||||||
///# use rust_bert::Config;
|
/// # use rust_bert::Config;
|
||||||
///# use std::path::Path;
|
/// # use std::path::Path;
|
||||||
///# use tch::kind::Kind::Int64;
|
/// # use tch::kind::Kind::Int64;
|
||||||
/// use rust_bert::albert::{AlbertConfig, AlbertForQuestionAnswering};
|
/// use rust_bert::albert::{AlbertConfig, AlbertForQuestionAnswering};
|
||||||
///# let config_path = Path::new("path/to/config.json");
|
/// # let config_path = Path::new("path/to/config.json");
|
||||||
///# let device = Device::Cpu;
|
/// # let device = Device::Cpu;
|
||||||
///# let vs = nn::VarStore::new(device);
|
/// # let vs = nn::VarStore::new(device);
|
||||||
///# let config = AlbertConfig::from_file(config_path);
|
/// # let config = AlbertConfig::from_file(config_path);
|
||||||
///# let albert_model: AlbertForQuestionAnswering = AlbertForQuestionAnswering::new(&vs.root(), &config);
|
/// # let albert_model: AlbertForQuestionAnswering = AlbertForQuestionAnswering::new(&vs.root(), &config);
|
||||||
/// let (batch_size, sequence_length) = (64, 128);
|
/// let (batch_size, sequence_length) = (64, 128);
|
||||||
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
|
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
|
||||||
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
||||||
@ -676,17 +809,32 @@ impl AlbertForQuestionAnswering {
|
|||||||
/// None,
|
/// None,
|
||||||
/// false)
|
/// false)
|
||||||
/// });
|
/// });
|
||||||
///
|
|
||||||
/// ```
|
/// ```
|
||||||
///
|
pub fn forward_t(
|
||||||
pub fn forward_t(&self,
|
&self,
|
||||||
input_ids: Option<Tensor>,
|
input_ids: Option<Tensor>,
|
||||||
mask: Option<Tensor>,
|
mask: Option<Tensor>,
|
||||||
token_type_ids: Option<Tensor>,
|
token_type_ids: Option<Tensor>,
|
||||||
position_ids: Option<Tensor>,
|
position_ids: Option<Tensor>,
|
||||||
input_embeds: Option<Tensor>,
|
input_embeds: Option<Tensor>,
|
||||||
train: bool) -> (Tensor, Tensor, Option<Vec<Tensor>>, Option<Vec<Vec<Tensor>>>) {
|
train: bool,
|
||||||
let (sequence_output, _, all_hidden_states, all_attentions) = self.albert.forward_t(input_ids, mask, token_type_ids, position_ids, input_embeds, train).unwrap();
|
) -> (
|
||||||
|
Tensor,
|
||||||
|
Tensor,
|
||||||
|
Option<Vec<Tensor>>,
|
||||||
|
Option<Vec<Vec<Tensor>>>,
|
||||||
|
) {
|
||||||
|
let (sequence_output, _, all_hidden_states, all_attentions) = self
|
||||||
|
.albert
|
||||||
|
.forward_t(
|
||||||
|
input_ids,
|
||||||
|
mask,
|
||||||
|
token_type_ids,
|
||||||
|
position_ids,
|
||||||
|
input_embeds,
|
||||||
|
train,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
let logits = sequence_output.apply(&self.qa_outputs).split(1, -1);
|
let logits = sequence_output.apply(&self.qa_outputs).split(1, -1);
|
||||||
let (start_logits, end_logits) = (&logits[0], &logits[1]);
|
let (start_logits, end_logits) = (&logits[0], &logits[1]);
|
||||||
let start_logits = start_logits.squeeze1(-1);
|
let start_logits = start_logits.squeeze1(-1);
|
||||||
@ -721,10 +869,10 @@ impl AlbertForMultipleChoice {
|
|||||||
/// # Example
|
/// # Example
|
||||||
///
|
///
|
||||||
/// ```no_run
|
/// ```no_run
|
||||||
/// use tch::{nn, Device};
|
/// use rust_bert::albert::{AlbertConfig, AlbertForMultipleChoice};
|
||||||
/// use rust_bert::Config;
|
/// use rust_bert::Config;
|
||||||
/// use std::path::Path;
|
/// use std::path::Path;
|
||||||
/// use rust_bert::albert::{AlbertConfig, AlbertForMultipleChoice};
|
/// use tch::{nn, Device};
|
||||||
///
|
///
|
||||||
/// let config_path = Path::new("path/to/config.json");
|
/// let config_path = Path::new("path/to/config.json");
|
||||||
/// let device = Device::Cpu;
|
/// let device = Device::Cpu;
|
||||||
@ -732,14 +880,22 @@ impl AlbertForMultipleChoice {
|
|||||||
/// let config = AlbertConfig::from_file(config_path);
|
/// let config = AlbertConfig::from_file(config_path);
|
||||||
/// let albert: AlbertForMultipleChoice = AlbertForMultipleChoice::new(&p.root(), &config);
|
/// let albert: AlbertForMultipleChoice = AlbertForMultipleChoice::new(&p.root(), &config);
|
||||||
/// ```
|
/// ```
|
||||||
///
|
|
||||||
pub fn new(p: &nn::Path, config: &AlbertConfig) -> AlbertForMultipleChoice {
|
pub fn new(p: &nn::Path, config: &AlbertConfig) -> AlbertForMultipleChoice {
|
||||||
let albert = AlbertModel::new(&(p / "albert"), config);
|
let albert = AlbertModel::new(&(p / "albert"), config);
|
||||||
let dropout = Dropout::new(config.hidden_dropout_prob);
|
let dropout = Dropout::new(config.hidden_dropout_prob);
|
||||||
let num_labels = 1;
|
let num_labels = 1;
|
||||||
let classifier = nn::linear(&(p / "classifier"), config.hidden_size, num_labels, Default::default());
|
let classifier = nn::linear(
|
||||||
|
&(p / "classifier"),
|
||||||
|
config.hidden_size,
|
||||||
|
num_labels,
|
||||||
|
Default::default(),
|
||||||
|
);
|
||||||
|
|
||||||
AlbertForMultipleChoice { albert, dropout, classifier }
|
AlbertForMultipleChoice {
|
||||||
|
albert,
|
||||||
|
dropout,
|
||||||
|
classifier,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Forward pass through the model
|
/// Forward pass through the model
|
||||||
@ -762,16 +918,16 @@ impl AlbertForMultipleChoice {
|
|||||||
/// # Example
|
/// # Example
|
||||||
///
|
///
|
||||||
/// ```no_run
|
/// ```no_run
|
||||||
///# use tch::{nn, Device, Tensor, no_grad};
|
/// # use tch::{nn, Device, Tensor, no_grad};
|
||||||
///# use rust_bert::Config;
|
/// # use rust_bert::Config;
|
||||||
///# use std::path::Path;
|
/// # use std::path::Path;
|
||||||
///# use tch::kind::Kind::Int64;
|
/// # use tch::kind::Kind::Int64;
|
||||||
/// use rust_bert::albert::{AlbertConfig, AlbertForMultipleChoice};
|
/// use rust_bert::albert::{AlbertConfig, AlbertForMultipleChoice};
|
||||||
///# let config_path = Path::new("path/to/config.json");
|
/// # let config_path = Path::new("path/to/config.json");
|
||||||
///# let device = Device::Cpu;
|
/// # let device = Device::Cpu;
|
||||||
///# let vs = nn::VarStore::new(device);
|
/// # let vs = nn::VarStore::new(device);
|
||||||
///# let config = AlbertConfig::from_file(config_path);
|
/// # let config = AlbertConfig::from_file(config_path);
|
||||||
///# let albert_model: AlbertForMultipleChoice = AlbertForMultipleChoice::new(&vs.root(), &config);
|
/// # let albert_model: AlbertForMultipleChoice = AlbertForMultipleChoice::new(&vs.root(), &config);
|
||||||
/// let (batch_size, sequence_length) = (64, 128);
|
/// let (batch_size, sequence_length) = (64, 128);
|
||||||
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
|
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
|
||||||
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
||||||
@ -787,44 +943,68 @@ impl AlbertForMultipleChoice {
|
|||||||
/// None,
|
/// None,
|
||||||
/// false).unwrap()
|
/// false).unwrap()
|
||||||
/// });
|
/// });
|
||||||
///
|
|
||||||
/// ```
|
/// ```
|
||||||
///
|
pub fn forward_t(
|
||||||
pub fn forward_t(&self,
|
&self,
|
||||||
input_ids: Option<Tensor>,
|
input_ids: Option<Tensor>,
|
||||||
mask: Option<Tensor>,
|
mask: Option<Tensor>,
|
||||||
token_type_ids: Option<Tensor>,
|
token_type_ids: Option<Tensor>,
|
||||||
position_ids: Option<Tensor>,
|
position_ids: Option<Tensor>,
|
||||||
input_embeds: Option<Tensor>,
|
input_embeds: Option<Tensor>,
|
||||||
train: bool) -> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Vec<Tensor>>>), &'static str> {
|
train: bool,
|
||||||
|
) -> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Vec<Tensor>>>), &'static str> {
|
||||||
let (input_ids, input_embeds, num_choices) = match &input_ids {
|
let (input_ids, input_embeds, num_choices) = match &input_ids {
|
||||||
Some(input_value) => match &input_embeds {
|
Some(input_value) => match &input_embeds {
|
||||||
Some(_) => { return Err("Only one of input ids or input embeddings may be set"); }
|
Some(_) => {
|
||||||
None => (Some(input_value.view((-1, *input_value.size().last().unwrap()))), None, input_value.size()[1])
|
return Err("Only one of input ids or input embeddings may be set");
|
||||||
}
|
}
|
||||||
|
None => (
|
||||||
|
Some(input_value.view((-1, *input_value.size().last().unwrap()))),
|
||||||
|
None,
|
||||||
|
input_value.size()[1],
|
||||||
|
),
|
||||||
|
},
|
||||||
None => match &input_embeds {
|
None => match &input_embeds {
|
||||||
Some(embeds) => (None, Some(embeds.view((-1, embeds.size()[1], embeds.size()[2]))), embeds.size()[1]),
|
Some(embeds) => (
|
||||||
None => { return Err("At least one of input ids or input embeddings must be set"); }
|
None,
|
||||||
}
|
Some(embeds.view((-1, embeds.size()[1], embeds.size()[2]))),
|
||||||
|
embeds.size()[1],
|
||||||
|
),
|
||||||
|
None => {
|
||||||
|
return Err("At least one of input ids or input embeddings must be set");
|
||||||
|
}
|
||||||
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
let mask = match mask {
|
let mask = match mask {
|
||||||
Some(value) => Some(value.view((-1, *value.size().last().unwrap()))),
|
Some(value) => Some(value.view((-1, *value.size().last().unwrap()))),
|
||||||
None => None
|
None => None,
|
||||||
};
|
};
|
||||||
let token_type_ids = match token_type_ids {
|
let token_type_ids = match token_type_ids {
|
||||||
Some(value) => Some(value.view((-1, *value.size().last().unwrap()))),
|
Some(value) => Some(value.view((-1, *value.size().last().unwrap()))),
|
||||||
None => None
|
None => None,
|
||||||
};
|
};
|
||||||
let position_ids = match position_ids {
|
let position_ids = match position_ids {
|
||||||
Some(value) => Some(value.view((-1, *value.size().last().unwrap()))),
|
Some(value) => Some(value.view((-1, *value.size().last().unwrap()))),
|
||||||
None => None
|
None => None,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let (_, pooled_output, all_hidden_states, all_attentions) = self
|
||||||
let (_, pooled_output, all_hidden_states, all_attentions) = self.albert.forward_t(input_ids, mask, token_type_ids, position_ids, input_embeds, train).unwrap();
|
.albert
|
||||||
let logits = pooled_output.apply_t(&self.dropout, train).apply(&self.classifier).view((-1, num_choices));
|
.forward_t(
|
||||||
|
input_ids,
|
||||||
|
mask,
|
||||||
|
token_type_ids,
|
||||||
|
position_ids,
|
||||||
|
input_embeds,
|
||||||
|
train,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
let logits = pooled_output
|
||||||
|
.apply_t(&self.dropout, train)
|
||||||
|
.apply(&self.classifier)
|
||||||
|
.view((-1, num_choices));
|
||||||
|
|
||||||
Ok((logits, all_hidden_states, all_attentions))
|
Ok((logits, all_hidden_states, all_attentions))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -11,10 +11,10 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
use crate::common::dropout::Dropout;
|
|
||||||
use tch::{nn, Tensor};
|
|
||||||
use crate::albert::AlbertConfig;
|
use crate::albert::AlbertConfig;
|
||||||
|
use crate::common::dropout::Dropout;
|
||||||
use tch::kind::Kind::Float;
|
use tch::kind::Kind::Float;
|
||||||
|
use tch::{nn, Tensor};
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct AlbertSelfAttention {
|
pub struct AlbertSelfAttention {
|
||||||
@ -32,24 +32,55 @@ pub struct AlbertSelfAttention {
|
|||||||
|
|
||||||
impl AlbertSelfAttention {
|
impl AlbertSelfAttention {
|
||||||
pub fn new(p: nn::Path, config: &AlbertConfig) -> AlbertSelfAttention {
|
pub fn new(p: nn::Path, config: &AlbertConfig) -> AlbertSelfAttention {
|
||||||
assert_eq!(config.hidden_size % config.num_attention_heads, 0, "Hidden size not a multiple of the number of attention heads");
|
assert_eq!(
|
||||||
|
config.hidden_size % config.num_attention_heads,
|
||||||
|
0,
|
||||||
|
"Hidden size not a multiple of the number of attention heads"
|
||||||
|
);
|
||||||
|
|
||||||
let query = nn::linear(&p / "query", config.hidden_size, config.hidden_size, Default::default());
|
let query = nn::linear(
|
||||||
let key = nn::linear(&p / "key", config.hidden_size, config.hidden_size, Default::default());
|
&p / "query",
|
||||||
let value = nn::linear(&p / "value", config.hidden_size, config.hidden_size, Default::default());
|
config.hidden_size,
|
||||||
let dense = nn::linear(&p / "dense", config.hidden_size, config.hidden_size, Default::default());
|
config.hidden_size,
|
||||||
|
Default::default(),
|
||||||
|
);
|
||||||
|
let key = nn::linear(
|
||||||
|
&p / "key",
|
||||||
|
config.hidden_size,
|
||||||
|
config.hidden_size,
|
||||||
|
Default::default(),
|
||||||
|
);
|
||||||
|
let value = nn::linear(
|
||||||
|
&p / "value",
|
||||||
|
config.hidden_size,
|
||||||
|
config.hidden_size,
|
||||||
|
Default::default(),
|
||||||
|
);
|
||||||
|
let dense = nn::linear(
|
||||||
|
&p / "dense",
|
||||||
|
config.hidden_size,
|
||||||
|
config.hidden_size,
|
||||||
|
Default::default(),
|
||||||
|
);
|
||||||
let dropout = Dropout::new(config.attention_probs_dropout_prob);
|
let dropout = Dropout::new(config.attention_probs_dropout_prob);
|
||||||
let attention_head_size = config.hidden_size / config.num_attention_heads;
|
let attention_head_size = config.hidden_size / config.num_attention_heads;
|
||||||
let output_attentions = match config.output_attentions {
|
let output_attentions = match config.output_attentions {
|
||||||
Some(value) => value,
|
Some(value) => value,
|
||||||
None => false
|
None => false,
|
||||||
};
|
};
|
||||||
let layer_norm_eps = match config.layer_norm_eps {
|
let layer_norm_eps = match config.layer_norm_eps {
|
||||||
Some(value) => value,
|
Some(value) => value,
|
||||||
None => 1e-12
|
None => 1e-12,
|
||||||
};
|
};
|
||||||
let layer_norm_config = nn::LayerNormConfig { eps: layer_norm_eps, ..Default::default() };
|
let layer_norm_config = nn::LayerNormConfig {
|
||||||
let layer_norm = nn::layer_norm(&p / "LayerNorm", vec![config.hidden_size], layer_norm_config);
|
eps: layer_norm_eps,
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
let layer_norm = nn::layer_norm(
|
||||||
|
&p / "LayerNorm",
|
||||||
|
vec![config.hidden_size],
|
||||||
|
layer_norm_config,
|
||||||
|
);
|
||||||
|
|
||||||
AlbertSelfAttention {
|
AlbertSelfAttention {
|
||||||
num_attention_heads: config.num_attention_heads,
|
num_attention_heads: config.num_attention_heads,
|
||||||
@ -66,19 +97,23 @@ impl AlbertSelfAttention {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn split_heads(&self, x: Tensor, bs: i64, dim_per_head: i64) -> Tensor {
|
fn split_heads(&self, x: Tensor, bs: i64, dim_per_head: i64) -> Tensor {
|
||||||
x.view((bs, -1, self.num_attention_heads, dim_per_head)).transpose(1, 2)
|
x.view((bs, -1, self.num_attention_heads, dim_per_head))
|
||||||
|
.transpose(1, 2)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn forward_t(&self,
|
pub fn forward_t(
|
||||||
input_ids: &Tensor,
|
&self,
|
||||||
mask: &Option<Tensor>,
|
input_ids: &Tensor,
|
||||||
train: bool) -> (Tensor, Option<Tensor>) {
|
mask: &Option<Tensor>,
|
||||||
|
train: bool,
|
||||||
|
) -> (Tensor, Option<Tensor>) {
|
||||||
let bs = *input_ids.size().first().unwrap();
|
let bs = *input_ids.size().first().unwrap();
|
||||||
|
|
||||||
let key_layer = self.split_heads(input_ids.apply(&self.key), bs, self.attention_head_size);
|
let key_layer = self.split_heads(input_ids.apply(&self.key), bs, self.attention_head_size);
|
||||||
let value_layer = self.split_heads(input_ids.apply(&self.value), bs, self.attention_head_size);
|
let value_layer =
|
||||||
let query_layer = self.split_heads(input_ids.apply(&self.query), bs, self.attention_head_size);
|
self.split_heads(input_ids.apply(&self.value), bs, self.attention_head_size);
|
||||||
|
let query_layer =
|
||||||
|
self.split_heads(input_ids.apply(&self.query), bs, self.attention_head_size);
|
||||||
|
|
||||||
let query_layer: Tensor = query_layer / (self.attention_head_size as f64).sqrt();
|
let query_layer: Tensor = query_layer / (self.attention_head_size as f64).sqrt();
|
||||||
|
|
||||||
@ -91,9 +126,11 @@ impl AlbertSelfAttention {
|
|||||||
let weights = scores.softmax(-1, Float).apply_t(&self.dropout, train);
|
let weights = scores.softmax(-1, Float).apply_t(&self.dropout, train);
|
||||||
let context = weights.matmul(&value_layer).transpose(1, 2).contiguous();
|
let context = weights.matmul(&value_layer).transpose(1, 2).contiguous();
|
||||||
|
|
||||||
let w = self.dense.ws
|
let w = self.dense.ws.transpose(0, 1).view((
|
||||||
.transpose(0, 1)
|
self.num_attention_heads,
|
||||||
.view((self.num_attention_heads, self.attention_head_size, self.hidden_size));
|
self.attention_head_size,
|
||||||
|
self.hidden_size,
|
||||||
|
));
|
||||||
|
|
||||||
let context: Tensor = Tensor::einsum("bfnd,ndh->bfh", &[context, w]) + &self.dense.bs;
|
let context: Tensor = Tensor::einsum("bfnd,ndh->bfh", &[context, w]) + &self.dense.bs;
|
||||||
let context = (input_ids + context.apply_t(&self.dropout, train)).apply(&self.layer_norm);
|
let context = (input_ids + context.apply_t(&self.dropout, train)).apply(&self.layer_norm);
|
||||||
@ -104,4 +141,4 @@ impl AlbertSelfAttention {
|
|||||||
(context, Some(weights))
|
(context, Some(weights))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -11,10 +11,10 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
use tch::{nn, Tensor, Kind};
|
|
||||||
use crate::common::dropout::Dropout;
|
|
||||||
use crate::albert::AlbertConfig;
|
use crate::albert::AlbertConfig;
|
||||||
use tch::nn::{EmbeddingConfig, embedding};
|
use crate::common::dropout::Dropout;
|
||||||
|
use tch::nn::{embedding, EmbeddingConfig};
|
||||||
|
use tch::{nn, Kind, Tensor};
|
||||||
|
|
||||||
/// # Embeddings implementation for Albert model
|
/// # Embeddings implementation for Albert model
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
@ -34,49 +34,77 @@ impl AlbertEmbeddings {
|
|||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
|
|
||||||
let word_embeddings: nn::Embedding = embedding(p / "word_embeddings",
|
let word_embeddings: nn::Embedding = embedding(
|
||||||
config.vocab_size,
|
p / "word_embeddings",
|
||||||
config.embedding_size,
|
config.vocab_size,
|
||||||
embedding_config);
|
config.embedding_size,
|
||||||
|
embedding_config,
|
||||||
|
);
|
||||||
|
|
||||||
let position_embeddings: nn::Embedding = embedding(p / "position_embeddings",
|
let position_embeddings: nn::Embedding = embedding(
|
||||||
config.max_position_embeddings,
|
p / "position_embeddings",
|
||||||
config.embedding_size,
|
config.max_position_embeddings,
|
||||||
Default::default());
|
config.embedding_size,
|
||||||
|
Default::default(),
|
||||||
|
);
|
||||||
|
|
||||||
let token_type_embeddings: nn::Embedding = embedding(p / "token_type_embeddings",
|
let token_type_embeddings: nn::Embedding = embedding(
|
||||||
config.type_vocab_size,
|
p / "token_type_embeddings",
|
||||||
config.embedding_size,
|
config.type_vocab_size,
|
||||||
Default::default());
|
config.embedding_size,
|
||||||
|
Default::default(),
|
||||||
|
);
|
||||||
|
|
||||||
let layer_norm_eps = match config.layer_norm_eps {
|
let layer_norm_eps = match config.layer_norm_eps {
|
||||||
Some(value) => value,
|
Some(value) => value,
|
||||||
None => 1e-12
|
None => 1e-12,
|
||||||
};
|
};
|
||||||
let layer_norm_config = nn::LayerNormConfig { eps: layer_norm_eps, ..Default::default() };
|
let layer_norm_config = nn::LayerNormConfig {
|
||||||
let layer_norm: nn::LayerNorm = nn::layer_norm(p / "LayerNorm", vec![config.embedding_size], layer_norm_config);
|
eps: layer_norm_eps,
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
let layer_norm: nn::LayerNorm = nn::layer_norm(
|
||||||
|
p / "LayerNorm",
|
||||||
|
vec![config.embedding_size],
|
||||||
|
layer_norm_config,
|
||||||
|
);
|
||||||
let dropout: Dropout = Dropout::new(config.hidden_dropout_prob);
|
let dropout: Dropout = Dropout::new(config.hidden_dropout_prob);
|
||||||
AlbertEmbeddings { word_embeddings, position_embeddings, token_type_embeddings, layer_norm, dropout}
|
AlbertEmbeddings {
|
||||||
|
word_embeddings,
|
||||||
|
position_embeddings,
|
||||||
|
token_type_embeddings,
|
||||||
|
layer_norm,
|
||||||
|
dropout,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn forward_t(&self,
|
pub fn forward_t(
|
||||||
input_ids: Option<Tensor>,
|
&self,
|
||||||
token_type_ids: Option<Tensor>,
|
input_ids: Option<Tensor>,
|
||||||
position_ids: Option<Tensor>,
|
token_type_ids: Option<Tensor>,
|
||||||
input_embeds: Option<Tensor>,
|
position_ids: Option<Tensor>,
|
||||||
train: bool) -> Result<Tensor, &'static str> {
|
input_embeds: Option<Tensor>,
|
||||||
|
train: bool,
|
||||||
|
) -> Result<Tensor, &'static str> {
|
||||||
let (input_embeddings, input_shape) = match input_ids {
|
let (input_embeddings, input_shape) = match input_ids {
|
||||||
Some(input_value) => match input_embeds {
|
Some(input_value) => match input_embeds {
|
||||||
Some(_) => { return Err("Only one of input ids or input embeddings may be set"); }
|
Some(_) => {
|
||||||
None => (input_value.apply_t(&self.word_embeddings, train), input_value.size())
|
return Err("Only one of input ids or input embeddings may be set");
|
||||||
}
|
}
|
||||||
|
None => (
|
||||||
|
input_value.apply_t(&self.word_embeddings, train),
|
||||||
|
input_value.size(),
|
||||||
|
),
|
||||||
|
},
|
||||||
None => match input_embeds {
|
None => match input_embeds {
|
||||||
Some(embeds) => {
|
Some(embeds) => {
|
||||||
let size = vec!(embeds.size()[0], embeds.size()[1]);
|
let size = vec![embeds.size()[0], embeds.size()[1]];
|
||||||
(embeds, size)
|
(embeds, size)
|
||||||
},
|
}
|
||||||
None => { return Err("Only one of input ids or input embeddings may be set"); }
|
None => {
|
||||||
}
|
return Err("Only one of input ids or input embeddings may be set");
|
||||||
|
}
|
||||||
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
let seq_length = input_embeddings.as_ref().size()[1].to_owned();
|
let seq_length = input_embeddings.as_ref().size()[1].to_owned();
|
||||||
@ -84,19 +112,22 @@ impl AlbertEmbeddings {
|
|||||||
let position_ids = match position_ids {
|
let position_ids = match position_ids {
|
||||||
Some(value) => value,
|
Some(value) => value,
|
||||||
None => Tensor::arange(seq_length, (Kind::Int64, input_embeddings.device()))
|
None => Tensor::arange(seq_length, (Kind::Int64, input_embeddings.device()))
|
||||||
.unsqueeze(0).
|
.unsqueeze(0)
|
||||||
expand(&input_shape, true)
|
.expand(&input_shape, true),
|
||||||
};
|
};
|
||||||
|
|
||||||
let token_type_ids = match token_type_ids {
|
let token_type_ids = match token_type_ids {
|
||||||
Some(value) => value,
|
Some(value) => value,
|
||||||
None => Tensor::zeros(&input_shape, (Kind::Int64, input_embeddings.device()))
|
None => Tensor::zeros(&input_shape, (Kind::Int64, input_embeddings.device())),
|
||||||
};
|
};
|
||||||
|
|
||||||
let position_embeddings = position_ids.apply(&self.position_embeddings);
|
let position_embeddings = position_ids.apply(&self.position_embeddings);
|
||||||
let token_type_embeddings = token_type_ids.apply(&self.token_type_embeddings);
|
let token_type_embeddings = token_type_ids.apply(&self.token_type_embeddings);
|
||||||
|
|
||||||
let input_embeddings: Tensor = input_embeddings + position_embeddings + token_type_embeddings;
|
let input_embeddings: Tensor =
|
||||||
Ok(input_embeddings.apply(&self.layer_norm).apply_t(&self.dropout, train))
|
input_embeddings + position_embeddings + token_type_embeddings;
|
||||||
|
Ok(input_embeddings
|
||||||
|
.apply(&self.layer_norm)
|
||||||
|
.apply_t(&self.dropout, train))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -11,12 +11,12 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
use crate::albert::attention::AlbertSelfAttention;
|
|
||||||
use tch::{nn, Tensor};
|
|
||||||
use crate::albert::AlbertConfig;
|
|
||||||
use crate::albert::albert::Activation;
|
use crate::albert::albert::Activation;
|
||||||
use crate::common::activations::{_gelu_new, _gelu, _relu, _mish};
|
use crate::albert::attention::AlbertSelfAttention;
|
||||||
|
use crate::albert::AlbertConfig;
|
||||||
|
use crate::common::activations::{_gelu, _gelu_new, _mish, _relu};
|
||||||
use std::borrow::BorrowMut;
|
use std::borrow::BorrowMut;
|
||||||
|
use tch::{nn, Tensor};
|
||||||
|
|
||||||
pub struct AlbertLayer {
|
pub struct AlbertLayer {
|
||||||
attention: AlbertSelfAttention,
|
attention: AlbertSelfAttention,
|
||||||
@ -32,29 +32,55 @@ impl AlbertLayer {
|
|||||||
|
|
||||||
let layer_norm_eps = match config.layer_norm_eps {
|
let layer_norm_eps = match config.layer_norm_eps {
|
||||||
Some(value) => value,
|
Some(value) => value,
|
||||||
None => 1e-12
|
None => 1e-12,
|
||||||
};
|
};
|
||||||
let layer_norm_config = nn::LayerNormConfig { eps: layer_norm_eps, ..Default::default() };
|
let layer_norm_config = nn::LayerNormConfig {
|
||||||
let full_layer_layer_norm = nn::layer_norm(&(p / "full_layer_layer_norm"), vec![config.hidden_size], layer_norm_config);
|
eps: layer_norm_eps,
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
let full_layer_layer_norm = nn::layer_norm(
|
||||||
|
&(p / "full_layer_layer_norm"),
|
||||||
|
vec![config.hidden_size],
|
||||||
|
layer_norm_config,
|
||||||
|
);
|
||||||
|
|
||||||
let ffn = nn::linear(&(p / "ffn"), config.hidden_size, config.intermediate_size, Default::default());
|
let ffn = nn::linear(
|
||||||
let ffn_output = nn::linear(&(p / "ffn_output"), config.intermediate_size, config.hidden_size, Default::default());
|
&(p / "ffn"),
|
||||||
|
config.hidden_size,
|
||||||
|
config.intermediate_size,
|
||||||
|
Default::default(),
|
||||||
|
);
|
||||||
|
let ffn_output = nn::linear(
|
||||||
|
&(p / "ffn_output"),
|
||||||
|
config.intermediate_size,
|
||||||
|
config.hidden_size,
|
||||||
|
Default::default(),
|
||||||
|
);
|
||||||
|
|
||||||
let activation = Box::new(match &config.hidden_act {
|
let activation = Box::new(match &config.hidden_act {
|
||||||
Activation::gelu_new => _gelu_new,
|
Activation::gelu_new => _gelu_new,
|
||||||
Activation::gelu => _gelu,
|
Activation::gelu => _gelu,
|
||||||
Activation::relu => _relu,
|
Activation::relu => _relu,
|
||||||
Activation::mish => _mish
|
Activation::mish => _mish,
|
||||||
});
|
});
|
||||||
|
|
||||||
AlbertLayer { attention, full_layer_layer_norm, ffn, ffn_output, activation }
|
AlbertLayer {
|
||||||
|
attention,
|
||||||
|
full_layer_layer_norm,
|
||||||
|
ffn,
|
||||||
|
ffn_output,
|
||||||
|
activation,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn forward_t(&self,
|
pub fn forward_t(
|
||||||
hidden_states: &Tensor,
|
&self,
|
||||||
mask: &Option<Tensor>,
|
hidden_states: &Tensor,
|
||||||
train: bool) -> (Tensor, Option<Tensor>) {
|
mask: &Option<Tensor>,
|
||||||
let (attention_output, attention_weights) = self.attention.forward_t(hidden_states, mask, train);
|
train: bool,
|
||||||
|
) -> (Tensor, Option<Tensor>) {
|
||||||
|
let (attention_output, attention_weights) =
|
||||||
|
self.attention.forward_t(hidden_states, mask, train);
|
||||||
let ffn_output = attention_output.apply(&self.ffn);
|
let ffn_output = attention_output.apply(&self.ffn);
|
||||||
let ffn_output: Tensor = (self.activation)(&ffn_output);
|
let ffn_output: Tensor = (self.activation)(&ffn_output);
|
||||||
let ffn_output = ffn_output.apply(&self.ffn_output);
|
let ffn_output = ffn_output.apply(&self.ffn_output);
|
||||||
@ -76,29 +102,42 @@ impl AlbertLayerGroup {
|
|||||||
|
|
||||||
let output_attentions = match config.output_attentions {
|
let output_attentions = match config.output_attentions {
|
||||||
Some(value) => value,
|
Some(value) => value,
|
||||||
None => false
|
None => false,
|
||||||
};
|
};
|
||||||
|
|
||||||
let output_hidden_states = match config.output_hidden_states {
|
let output_hidden_states = match config.output_hidden_states {
|
||||||
Some(value) => value,
|
Some(value) => value,
|
||||||
None => false
|
None => false,
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut layers: Vec<AlbertLayer> = vec!();
|
let mut layers: Vec<AlbertLayer> = vec![];
|
||||||
for layer_index in 0..config.inner_group_num {
|
for layer_index in 0..config.inner_group_num {
|
||||||
layers.push(AlbertLayer::new(&(p / layer_index), config));
|
layers.push(AlbertLayer::new(&(p / layer_index), config));
|
||||||
};
|
}
|
||||||
|
|
||||||
AlbertLayerGroup { output_hidden_states, output_attentions, layers }
|
AlbertLayerGroup {
|
||||||
|
output_hidden_states,
|
||||||
|
output_attentions,
|
||||||
|
layers,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn forward_t(&self,
|
pub fn forward_t(
|
||||||
hidden_states: &Tensor,
|
&self,
|
||||||
mask: &Option<Tensor>,
|
hidden_states: &Tensor,
|
||||||
train: bool)
|
mask: &Option<Tensor>,
|
||||||
-> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
|
train: bool,
|
||||||
let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states { Some(vec!()) } else { None };
|
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
|
||||||
let mut all_attentions: Option<Vec<Tensor>> = if self.output_attentions { Some(vec!()) } else { None };
|
let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states {
|
||||||
|
Some(vec![])
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
let mut all_attentions: Option<Vec<Tensor>> = if self.output_attentions {
|
||||||
|
Some(vec![])
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
let mut hidden_state = hidden_states.copy();
|
let mut hidden_state = hidden_states.copy();
|
||||||
let mut attention_weights: Option<Tensor>;
|
let mut attention_weights: Option<Tensor>;
|
||||||
@ -117,9 +156,9 @@ impl AlbertLayerGroup {
|
|||||||
attentions.push(attention_weights.as_ref().unwrap().copy());
|
attentions.push(attention_weights.as_ref().unwrap().copy());
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
None => break
|
None => break,
|
||||||
};
|
};
|
||||||
};
|
}
|
||||||
|
|
||||||
(hidden_state, all_hidden_states, all_attentions)
|
(hidden_state, all_hidden_states, all_attentions)
|
||||||
}
|
}
|
||||||
@ -140,20 +179,25 @@ impl AlbertTransformer {
|
|||||||
|
|
||||||
let output_attentions = match config.output_attentions {
|
let output_attentions = match config.output_attentions {
|
||||||
Some(value) => value,
|
Some(value) => value,
|
||||||
None => false
|
None => false,
|
||||||
};
|
};
|
||||||
|
|
||||||
let output_hidden_states = match config.output_hidden_states {
|
let output_hidden_states = match config.output_hidden_states {
|
||||||
Some(value) => value,
|
Some(value) => value,
|
||||||
None => false
|
None => false,
|
||||||
};
|
};
|
||||||
|
|
||||||
let embedding_hidden_mapping_in = nn::linear(&(p / "embedding_hidden_mapping_in"), config.embedding_size, config.hidden_size, Default::default());
|
let embedding_hidden_mapping_in = nn::linear(
|
||||||
|
&(p / "embedding_hidden_mapping_in"),
|
||||||
|
config.embedding_size,
|
||||||
|
config.hidden_size,
|
||||||
|
Default::default(),
|
||||||
|
);
|
||||||
|
|
||||||
let mut layers: Vec<AlbertLayerGroup> = vec!();
|
let mut layers: Vec<AlbertLayerGroup> = vec![];
|
||||||
for layer_index in 0..config.inner_group_num {
|
for layer_index in 0..config.inner_group_num {
|
||||||
layers.push(AlbertLayerGroup::new(&(p_layers / layer_index), config));
|
layers.push(AlbertLayerGroup::new(&(p_layers / layer_index), config));
|
||||||
};
|
}
|
||||||
|
|
||||||
AlbertTransformer {
|
AlbertTransformer {
|
||||||
output_hidden_states,
|
output_hidden_states,
|
||||||
@ -165,16 +209,24 @@ impl AlbertTransformer {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn forward_t(&self,
|
pub fn forward_t(
|
||||||
hidden_states: &Tensor,
|
&self,
|
||||||
mask: Option<Tensor>,
|
hidden_states: &Tensor,
|
||||||
train: bool)
|
mask: Option<Tensor>,
|
||||||
-> (Tensor, Option<Vec<Tensor>>, Option<Vec<Vec<Tensor>>>) {
|
train: bool,
|
||||||
|
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Vec<Tensor>>>) {
|
||||||
let mut hidden_state = hidden_states.apply(&self.embedding_hidden_mapping_in);
|
let mut hidden_state = hidden_states.apply(&self.embedding_hidden_mapping_in);
|
||||||
|
|
||||||
let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states { Some(vec!()) } else { None };
|
let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states {
|
||||||
let mut all_attentions: Option<Vec<Vec<Tensor>>> = if self.output_attentions { Some(vec!()) } else { None };
|
Some(vec![])
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
let mut all_attentions: Option<Vec<Vec<Tensor>>> = if self.output_attentions {
|
||||||
|
Some(vec![])
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
for i in 0..self.num_hidden_layers {
|
for i in 0..self.num_hidden_layers {
|
||||||
let group_idx = i / (self.num_hidden_layers / self.num_hidden_groups);
|
let group_idx = i / (self.num_hidden_layers / self.num_hidden_groups);
|
||||||
@ -190,9 +242,8 @@ impl AlbertTransformer {
|
|||||||
if let Some(attentions) = all_attentions.borrow_mut() {
|
if let Some(attentions) = all_attentions.borrow_mut() {
|
||||||
attentions.push(attention_weights.unwrap());
|
attentions.push(attention_weights.unwrap());
|
||||||
};
|
};
|
||||||
};
|
}
|
||||||
|
|
||||||
(hidden_state, all_hidden_states, all_attentions)
|
(hidden_state, all_hidden_states, all_attentions)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -20,37 +20,46 @@
|
|||||||
//! Pretrained models are available and can be downloaded using RemoteResources.
|
//! Pretrained models are available and can be downloaded using RemoteResources.
|
||||||
//!
|
//!
|
||||||
//! ```no_run
|
//! ```no_run
|
||||||
//!# fn main() -> failure::Fallible<()> {
|
//! # fn main() -> failure::Fallible<()> {
|
||||||
//!#
|
//! #
|
||||||
//! use rust_tokenizers::AlbertTokenizer;
|
//! use rust_tokenizers::AlbertTokenizer;
|
||||||
//! use tch::{nn, Device};
|
//! use tch::{nn, Device};
|
||||||
//!# use std::path::PathBuf;
|
//! # use std::path::PathBuf;
|
||||||
//! use rust_bert::albert::{AlbertForMaskedLM, AlbertConfig};
|
//! use rust_bert::albert::{AlbertConfig, AlbertForMaskedLM};
|
||||||
|
//! use rust_bert::resources::{download_resource, LocalResource, Resource};
|
||||||
//! use rust_bert::Config;
|
//! use rust_bert::Config;
|
||||||
//! use rust_bert::resources::{Resource, download_resource, LocalResource};
|
|
||||||
//!
|
//!
|
||||||
//! let config_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/config.json")});
|
//! let config_resource = Resource::Local(LocalResource {
|
||||||
//! let vocab_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/vocab.txt")});
|
//! local_path: PathBuf::from("path/to/config.json"),
|
||||||
//! let weights_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/model.ot")});
|
//! });
|
||||||
|
//! let vocab_resource = Resource::Local(LocalResource {
|
||||||
|
//! local_path: PathBuf::from("path/to/vocab.txt"),
|
||||||
|
//! });
|
||||||
|
//! let weights_resource = Resource::Local(LocalResource {
|
||||||
|
//! local_path: PathBuf::from("path/to/model.ot"),
|
||||||
|
//! });
|
||||||
//! let config_path = download_resource(&config_resource)?;
|
//! let config_path = download_resource(&config_resource)?;
|
||||||
//! let vocab_path = download_resource(&vocab_resource)?;
|
//! let vocab_path = download_resource(&vocab_resource)?;
|
||||||
//! let weights_path = download_resource(&weights_resource)?;
|
//! let weights_path = download_resource(&weights_resource)?;
|
||||||
//! let device = Device::cuda_if_available();
|
//! let device = Device::cuda_if_available();
|
||||||
//! let mut vs = nn::VarStore::new(device);
|
//! let mut vs = nn::VarStore::new(device);
|
||||||
//! let tokenizer: AlbertTokenizer = AlbertTokenizer::from_file(vocab_path.to_str().unwrap(), true, true);
|
//! let tokenizer: AlbertTokenizer =
|
||||||
|
//! AlbertTokenizer::from_file(vocab_path.to_str().unwrap(), true, true);
|
||||||
//! let config = AlbertConfig::from_file(config_path);
|
//! let config = AlbertConfig::from_file(config_path);
|
||||||
//! let bert_model = AlbertForMaskedLM::new(&vs.root(), &config);
|
//! let bert_model = AlbertForMaskedLM::new(&vs.root(), &config);
|
||||||
//! vs.load(weights_path)?;
|
//! vs.load(weights_path)?;
|
||||||
//!
|
//!
|
||||||
//!# Ok(())
|
//! # Ok(())
|
||||||
//!# }
|
//! # }
|
||||||
//! ```
|
//! ```
|
||||||
|
|
||||||
|
mod albert;
|
||||||
|
|
||||||
mod encoder;
|
|
||||||
mod attention;
|
mod attention;
|
||||||
mod embeddings;
|
mod embeddings;
|
||||||
mod albert;
|
mod encoder;
|
||||||
|
|
||||||
pub use albert::{AlbertConfig, AlbertModelResources, AlbertConfigResources, AlbertVocabResources, AlbertModel, AlbertForMaskedLM, AlbertForSequenceClassification, AlbertForTokenClassification, AlbertForQuestionAnswering, AlbertForMultipleChoice};
|
pub use albert::{
|
||||||
|
AlbertConfig, AlbertConfigResources, AlbertForMaskedLM, AlbertForMultipleChoice,
|
||||||
|
AlbertForQuestionAnswering, AlbertForSequenceClassification, AlbertForTokenClassification,
|
||||||
|
AlbertModel, AlbertModelResources, AlbertVocabResources,
|
||||||
|
};
|
||||||
|
@ -12,8 +12,8 @@
|
|||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
use crate::common::dropout::Dropout;
|
use crate::common::dropout::Dropout;
|
||||||
use tch::{nn, Tensor};
|
|
||||||
use tch::kind::Kind::Float;
|
use tch::kind::Kind::Float;
|
||||||
|
use tch::{nn, Tensor};
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
/// # Cache for BART attention layers
|
/// # Cache for BART attention layers
|
||||||
@ -31,7 +31,7 @@ impl Clone for LayerState {
|
|||||||
fn clone(&self) -> Self {
|
fn clone(&self) -> Self {
|
||||||
let prev_key_padding_mask = match &self.prev_key_padding_mask {
|
let prev_key_padding_mask = match &self.prev_key_padding_mask {
|
||||||
Some(key_padding_mask) => Some(key_padding_mask.copy()),
|
Some(key_padding_mask) => Some(key_padding_mask.copy()),
|
||||||
None => None
|
None => None,
|
||||||
};
|
};
|
||||||
LayerState {
|
LayerState {
|
||||||
prev_key: self.prev_key.copy(),
|
prev_key: self.prev_key.copy(),
|
||||||
@ -46,12 +46,16 @@ impl LayerState {
|
|||||||
self.prev_key = self.prev_key.index_select(0, new_indices);
|
self.prev_key = self.prev_key.index_select(0, new_indices);
|
||||||
self.prev_value = self.prev_value.index_select(0, new_indices);
|
self.prev_value = self.prev_value.index_select(0, new_indices);
|
||||||
if self.prev_key_padding_mask.is_some() {
|
if self.prev_key_padding_mask.is_some() {
|
||||||
self.prev_key_padding_mask = Some(self.prev_key_padding_mask.as_ref().unwrap().index_select(0, new_indices));
|
self.prev_key_padding_mask = Some(
|
||||||
|
self.prev_key_padding_mask
|
||||||
|
.as_ref()
|
||||||
|
.unwrap()
|
||||||
|
.index_select(0, new_indices),
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct SelfAttention {
|
pub struct SelfAttention {
|
||||||
num_heads: i64,
|
num_heads: i64,
|
||||||
@ -68,8 +72,15 @@ pub struct SelfAttention {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl SelfAttention {
|
impl SelfAttention {
|
||||||
pub fn new(p: nn::Path, embed_dim: i64, num_heads: i64, dropout: f64,
|
pub fn new(
|
||||||
encoder_decoder_attention: bool, store_cache: bool, output_attentions: bool) -> SelfAttention {
|
p: nn::Path,
|
||||||
|
embed_dim: i64,
|
||||||
|
num_heads: i64,
|
||||||
|
dropout: f64,
|
||||||
|
encoder_decoder_attention: bool,
|
||||||
|
store_cache: bool,
|
||||||
|
output_attentions: bool,
|
||||||
|
) -> SelfAttention {
|
||||||
let k_proj = nn::linear(&p / "k_proj", embed_dim, embed_dim, Default::default());
|
let k_proj = nn::linear(&p / "k_proj", embed_dim, embed_dim, Default::default());
|
||||||
let v_proj = nn::linear(&p / "v_proj", embed_dim, embed_dim, Default::default());
|
let v_proj = nn::linear(&p / "v_proj", embed_dim, embed_dim, Default::default());
|
||||||
let q_proj = nn::linear(&p / "q_proj", embed_dim, embed_dim, Default::default());
|
let q_proj = nn::linear(&p / "q_proj", embed_dim, embed_dim, Default::default());
|
||||||
@ -95,58 +106,90 @@ impl SelfAttention {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn flatten(&self, x: Tensor, dim_0: i64, bs: i64) -> Tensor {
|
fn flatten(&self, x: Tensor, dim_0: i64, bs: i64) -> Tensor {
|
||||||
x.contiguous().view((dim_0, bs * self.num_heads, self.head_dim)).transpose(0, 1)
|
x.contiguous()
|
||||||
|
.view((dim_0, bs * self.num_heads, self.head_dim))
|
||||||
|
.transpose(0, 1)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn forward_t(&self, query: &Tensor,
|
pub fn forward_t(
|
||||||
key: Option<&Tensor>,
|
&self,
|
||||||
key_padding_mask: Option<&Tensor>,
|
query: &Tensor,
|
||||||
attention_mask: Option<&Tensor>,
|
key: Option<&Tensor>,
|
||||||
mut layer_state: Option<LayerState>,
|
key_padding_mask: Option<&Tensor>,
|
||||||
train: bool) -> (Tensor, Option<Tensor>, Option<LayerState>) {
|
attention_mask: Option<&Tensor>,
|
||||||
|
mut layer_state: Option<LayerState>,
|
||||||
|
train: bool,
|
||||||
|
) -> (Tensor, Option<Tensor>, Option<LayerState>) {
|
||||||
let query_size = query.size();
|
let query_size = query.size();
|
||||||
let (target_sequence_length, bs) = (query_size[0], query_size[1]);
|
let (target_sequence_length, bs) = (query_size[0], query_size[1]);
|
||||||
let q: Tensor = self.flatten(query.as_ref().apply(&self.q_proj) * self.scaling, target_sequence_length, bs);
|
let q: Tensor = self.flatten(
|
||||||
|
query.as_ref().apply(&self.q_proj) * self.scaling,
|
||||||
|
target_sequence_length,
|
||||||
|
bs,
|
||||||
|
);
|
||||||
let key = match &layer_state {
|
let key = match &layer_state {
|
||||||
Some(_) => { if self.encoder_decoder_attention { None } else { key } }
|
Some(_) => {
|
||||||
None => key
|
if self.encoder_decoder_attention {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
key
|
||||||
|
}
|
||||||
|
}
|
||||||
|
None => key,
|
||||||
};
|
};
|
||||||
|
|
||||||
let (k, v) = if self.encoder_decoder_attention {
|
let (k, v) = if self.encoder_decoder_attention {
|
||||||
match key {
|
match key {
|
||||||
Some(key) => {
|
Some(key) => (
|
||||||
(Some(self.flatten(key.apply(&self.k_proj), -1, bs)),
|
Some(self.flatten(key.apply(&self.k_proj), -1, bs)),
|
||||||
Some(self.flatten(key.apply(&self.v_proj), -1, bs))
|
Some(self.flatten(key.apply(&self.v_proj), -1, bs)),
|
||||||
)
|
),
|
||||||
}
|
None => (None, None),
|
||||||
None => (None, None)
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
(Some(self.flatten(query.apply(&self.k_proj), -1, bs)),
|
(
|
||||||
Some(self.flatten(query.apply(&self.v_proj), -1, bs))
|
Some(self.flatten(query.apply(&self.k_proj), -1, bs)),
|
||||||
|
Some(self.flatten(query.apply(&self.v_proj), -1, bs)),
|
||||||
)
|
)
|
||||||
};
|
};
|
||||||
|
|
||||||
let (k, v, key_padding_mask) = self.use_saved_state(&layer_state, k, v, key_padding_mask, bs);
|
let (k, v, key_padding_mask) =
|
||||||
|
self.use_saved_state(&layer_state, k, v, key_padding_mask, bs);
|
||||||
|
|
||||||
let source_sequence_length = k.size()[1];
|
let source_sequence_length = k.size()[1];
|
||||||
let attention_weights = q.bmm(&k.transpose(1, 2));
|
let attention_weights = q.bmm(&k.transpose(1, 2));
|
||||||
let attention_weights = match attention_mask {
|
let attention_weights = match attention_mask {
|
||||||
Some(mask) => {
|
Some(mask) => {
|
||||||
let attention_weights = attention_weights.view((bs, self.num_heads, target_sequence_length, source_sequence_length)) + mask;
|
let attention_weights = attention_weights.view((
|
||||||
attention_weights.view((bs * self.num_heads, target_sequence_length, source_sequence_length))
|
bs,
|
||||||
|
self.num_heads,
|
||||||
|
target_sequence_length,
|
||||||
|
source_sequence_length,
|
||||||
|
)) + mask;
|
||||||
|
attention_weights.view((
|
||||||
|
bs * self.num_heads,
|
||||||
|
target_sequence_length,
|
||||||
|
source_sequence_length,
|
||||||
|
))
|
||||||
}
|
}
|
||||||
None => attention_weights
|
None => attention_weights,
|
||||||
};
|
};
|
||||||
|
|
||||||
let attention_weights = match key_padding_mask.as_ref() {
|
let attention_weights = match key_padding_mask.as_ref() {
|
||||||
Some(mask) => {
|
Some(mask) => attention_weights
|
||||||
attention_weights
|
.view((
|
||||||
.view((bs, self.num_heads, target_sequence_length, source_sequence_length))
|
bs,
|
||||||
.masked_fill(&mask.unsqueeze(1).unsqueeze(2), std::f64::NEG_INFINITY)
|
self.num_heads,
|
||||||
.view((bs * self.num_heads, target_sequence_length, source_sequence_length))
|
target_sequence_length,
|
||||||
}
|
source_sequence_length,
|
||||||
None => attention_weights
|
))
|
||||||
|
.masked_fill(&mask.unsqueeze(1).unsqueeze(2), std::f64::NEG_INFINITY)
|
||||||
|
.view((
|
||||||
|
bs * self.num_heads,
|
||||||
|
target_sequence_length,
|
||||||
|
source_sequence_length,
|
||||||
|
)),
|
||||||
|
None => attention_weights,
|
||||||
};
|
};
|
||||||
|
|
||||||
let attention_weights = attention_weights.softmax(-1, Float);
|
let attention_weights = attention_weights.softmax(-1, Float);
|
||||||
@ -159,16 +202,25 @@ impl SelfAttention {
|
|||||||
.apply(&self.out_proj);
|
.apply(&self.out_proj);
|
||||||
|
|
||||||
let attention_weights = if self.output_attentions {
|
let attention_weights = if self.output_attentions {
|
||||||
Some(attention_weights.view((bs, self.num_heads, target_sequence_length, source_sequence_length)))
|
Some(attention_weights.view((
|
||||||
} else { None };
|
bs,
|
||||||
|
self.num_heads,
|
||||||
|
target_sequence_length,
|
||||||
|
source_sequence_length,
|
||||||
|
)))
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
if self.store_cache {
|
if self.store_cache {
|
||||||
if layer_state.is_some() {
|
if layer_state.is_some() {
|
||||||
layer_state.as_mut().unwrap().prev_key = k.view((bs, self.num_heads, -1, self.head_dim));
|
layer_state.as_mut().unwrap().prev_key =
|
||||||
layer_state.as_mut().unwrap().prev_value = v.view((bs, self.num_heads, -1, self.head_dim));
|
k.view((bs, self.num_heads, -1, self.head_dim));
|
||||||
|
layer_state.as_mut().unwrap().prev_value =
|
||||||
|
v.view((bs, self.num_heads, -1, self.head_dim));
|
||||||
layer_state.as_mut().unwrap().prev_key_padding_mask = match key_padding_mask {
|
layer_state.as_mut().unwrap().prev_key_padding_mask = match key_padding_mask {
|
||||||
Some(tensor) => Some(tensor),
|
Some(tensor) => Some(tensor),
|
||||||
None => None
|
None => None,
|
||||||
};
|
};
|
||||||
} else {
|
} else {
|
||||||
layer_state = Some(LayerState {
|
layer_state = Some(LayerState {
|
||||||
@ -176,7 +228,7 @@ impl SelfAttention {
|
|||||||
prev_value: v.view((bs, self.num_heads, -1, self.head_dim)),
|
prev_value: v.view((bs, self.num_heads, -1, self.head_dim)),
|
||||||
prev_key_padding_mask: match key_padding_mask {
|
prev_key_padding_mask: match key_padding_mask {
|
||||||
Some(tensor) => Some(tensor),
|
Some(tensor) => Some(tensor),
|
||||||
None => None
|
None => None,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
};
|
};
|
||||||
@ -185,17 +237,23 @@ impl SelfAttention {
|
|||||||
(output, attention_weights, layer_state)
|
(output, attention_weights, layer_state)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn use_saved_state(&self,
|
fn use_saved_state(
|
||||||
layer_state: &Option<LayerState>,
|
&self,
|
||||||
k: Option<Tensor>,
|
layer_state: &Option<LayerState>,
|
||||||
v: Option<Tensor>,
|
k: Option<Tensor>,
|
||||||
key_padding_mask: Option<&Tensor>,
|
v: Option<Tensor>,
|
||||||
bs: i64)
|
key_padding_mask: Option<&Tensor>,
|
||||||
-> (Tensor, Tensor, Option<Tensor>) {
|
bs: i64,
|
||||||
|
) -> (Tensor, Tensor, Option<Tensor>) {
|
||||||
match &layer_state {
|
match &layer_state {
|
||||||
Some(prev_state) => {
|
Some(prev_state) => {
|
||||||
let prev_key = prev_state.prev_key.view((bs * self.num_heads, -1, self.head_dim));
|
let prev_key = prev_state
|
||||||
let prev_value = prev_state.prev_value.view((bs * self.num_heads, -1, self.head_dim));
|
.prev_key
|
||||||
|
.view((bs * self.num_heads, -1, self.head_dim));
|
||||||
|
let prev_value =
|
||||||
|
prev_state
|
||||||
|
.prev_value
|
||||||
|
.view((bs * self.num_heads, -1, self.head_dim));
|
||||||
let k = if self.encoder_decoder_attention {
|
let k = if self.encoder_decoder_attention {
|
||||||
prev_key
|
prev_key
|
||||||
} else {
|
} else {
|
||||||
@ -207,39 +265,54 @@ impl SelfAttention {
|
|||||||
Tensor::cat(&[prev_value, v.unwrap()], 1)
|
Tensor::cat(&[prev_value, v.unwrap()], 1)
|
||||||
};
|
};
|
||||||
|
|
||||||
let key_padding_mask = self.use_saved_key_padding_mask(key_padding_mask,
|
let key_padding_mask = self.use_saved_key_padding_mask(
|
||||||
&prev_state.prev_key_padding_mask,
|
key_padding_mask,
|
||||||
bs,
|
&prev_state.prev_key_padding_mask,
|
||||||
k.size()[1]);
|
bs,
|
||||||
|
k.size()[1],
|
||||||
|
);
|
||||||
(k, v, key_padding_mask)
|
(k, v, key_padding_mask)
|
||||||
}
|
}
|
||||||
None => {
|
None => {
|
||||||
let key_padding_mask = match key_padding_mask {
|
let key_padding_mask = match key_padding_mask {
|
||||||
Some(value) => Some(value.copy()),
|
Some(value) => Some(value.copy()),
|
||||||
None => None
|
None => None,
|
||||||
};
|
};
|
||||||
(k.unwrap(), v.unwrap(), key_padding_mask)
|
(k.unwrap(), v.unwrap(), key_padding_mask)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn use_saved_key_padding_mask(&self, key_padding_mask: Option<&Tensor>, prev_key_padding_mask: &Option<Tensor>,
|
fn use_saved_key_padding_mask(
|
||||||
bs: i64, sequence_length: i64) -> Option<Tensor> {
|
&self,
|
||||||
|
key_padding_mask: Option<&Tensor>,
|
||||||
|
prev_key_padding_mask: &Option<Tensor>,
|
||||||
|
bs: i64,
|
||||||
|
sequence_length: i64,
|
||||||
|
) -> Option<Tensor> {
|
||||||
if prev_key_padding_mask.is_some() {
|
if prev_key_padding_mask.is_some() {
|
||||||
if self.encoder_decoder_attention {
|
if self.encoder_decoder_attention {
|
||||||
Some(prev_key_padding_mask.as_ref().unwrap().copy())
|
Some(prev_key_padding_mask.as_ref().unwrap().copy())
|
||||||
} else {
|
} else {
|
||||||
Some(Tensor::cat(&[prev_key_padding_mask.as_ref().unwrap(), key_padding_mask.as_ref().unwrap()], 1))
|
Some(Tensor::cat(
|
||||||
|
&[
|
||||||
|
prev_key_padding_mask.as_ref().unwrap(),
|
||||||
|
key_padding_mask.as_ref().unwrap(),
|
||||||
|
],
|
||||||
|
1,
|
||||||
|
))
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
match key_padding_mask {
|
match key_padding_mask {
|
||||||
Some(key_padding_mask) => {
|
Some(key_padding_mask) => {
|
||||||
let filler = Tensor::zeros(&[bs, sequence_length - key_padding_mask.size()[1]],
|
let filler = Tensor::zeros(
|
||||||
(key_padding_mask.kind(), key_padding_mask.device()));
|
&[bs, sequence_length - key_padding_mask.size()[1]],
|
||||||
|
(key_padding_mask.kind(), key_padding_mask.device()),
|
||||||
|
);
|
||||||
Some(Tensor::cat(&[filler, key_padding_mask.copy()], 1))
|
Some(Tensor::cat(&[filler, key_padding_mask.copy()], 1))
|
||||||
}
|
}
|
||||||
None => None
|
None => None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
621
src/bart/bart.rs
621
src/bart/bart.rs
@ -11,18 +11,18 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
use std::collections::HashMap;
|
|
||||||
use crate::Config;
|
|
||||||
use tch::{Tensor, nn};
|
|
||||||
use tch::kind::Kind::{Int64, Float};
|
|
||||||
use crate::bart::encoder::BartEncoder;
|
|
||||||
use crate::bart::decoder::BartDecoder;
|
|
||||||
use tch::nn::{embedding, EmbeddingConfig};
|
|
||||||
use crate::bart::attention::LayerState;
|
use crate::bart::attention::LayerState;
|
||||||
use std::borrow::BorrowMut;
|
use crate::bart::decoder::BartDecoder;
|
||||||
|
use crate::bart::encoder::BartEncoder;
|
||||||
use crate::common::dropout::Dropout;
|
use crate::common::dropout::Dropout;
|
||||||
use crate::pipelines::generation::{Cache, LMHeadModel};
|
use crate::pipelines::generation::{Cache, LMHeadModel};
|
||||||
|
use crate::Config;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::borrow::BorrowMut;
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use tch::kind::Kind::{Float, Int64};
|
||||||
|
use tch::nn::{embedding, EmbeddingConfig};
|
||||||
|
use tch::{nn, Tensor};
|
||||||
|
|
||||||
/// # BART Pretrained model weight files
|
/// # BART Pretrained model weight files
|
||||||
pub struct BartModelResources;
|
pub struct BartModelResources;
|
||||||
@ -38,38 +38,74 @@ pub struct BartMergesResources;
|
|||||||
|
|
||||||
impl BartModelResources {
|
impl BartModelResources {
|
||||||
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
|
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
|
||||||
pub const BART: (&'static str, &'static str) = ("bart/model.ot", "https://cdn.huggingface.co/facebook/bart-large/rust_model.ot");
|
pub const BART: (&'static str, &'static str) = (
|
||||||
|
"bart/model.ot",
|
||||||
|
"https://cdn.huggingface.co/facebook/bart-large/rust_model.ot",
|
||||||
|
);
|
||||||
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
|
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
|
||||||
pub const BART_CNN: (&'static str, &'static str) = ("bart-cnn/model.ot", "https://cdn.huggingface.co/facebook/bart-large-cnn/rust_model.ot");
|
pub const BART_CNN: (&'static str, &'static str) = (
|
||||||
|
"bart-cnn/model.ot",
|
||||||
|
"https://cdn.huggingface.co/facebook/bart-large-cnn/rust_model.ot",
|
||||||
|
);
|
||||||
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
|
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
|
||||||
pub const BART_XSUM: (&'static str, &'static str) = ("bart-xsum/model.ot", "https://cdn.huggingface.co/facebook/bart-large-xsum/rust_model.ot");
|
pub const BART_XSUM: (&'static str, &'static str) = (
|
||||||
|
"bart-xsum/model.ot",
|
||||||
|
"https://cdn.huggingface.co/facebook/bart-large-xsum/rust_model.ot",
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
impl BartConfigResources {
|
impl BartConfigResources {
|
||||||
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
|
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
|
||||||
pub const BART: (&'static str, &'static str) = ("bart/config.json", "https://cdn.huggingface.co/facebook/bart-large/config.json");
|
pub const BART: (&'static str, &'static str) = (
|
||||||
|
"bart/config.json",
|
||||||
|
"https://cdn.huggingface.co/facebook/bart-large/config.json",
|
||||||
|
);
|
||||||
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
|
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
|
||||||
pub const BART_CNN: (&'static str, &'static str) = ("bart-cnn/config.json", "https://cdn.huggingface.co/facebook/bart-large-cnn/config.json");
|
pub const BART_CNN: (&'static str, &'static str) = (
|
||||||
|
"bart-cnn/config.json",
|
||||||
|
"https://cdn.huggingface.co/facebook/bart-large-cnn/config.json",
|
||||||
|
);
|
||||||
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
|
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
|
||||||
pub const BART_XSUM: (&'static str, &'static str) = ("bart-xsum/config.json", "https://cdn.huggingface.co/facebook/bart-large-xsum/config.json");
|
pub const BART_XSUM: (&'static str, &'static str) = (
|
||||||
|
"bart-xsum/config.json",
|
||||||
|
"https://cdn.huggingface.co/facebook/bart-large-xsum/config.json",
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
impl BartVocabResources {
|
impl BartVocabResources {
|
||||||
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
|
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
|
||||||
pub const BART: (&'static str, &'static str) = ("bart/vocab.txt", "https://cdn.huggingface.co/roberta-large-vocab.json");
|
pub const BART: (&'static str, &'static str) = (
|
||||||
|
"bart/vocab.txt",
|
||||||
|
"https://cdn.huggingface.co/roberta-large-vocab.json",
|
||||||
|
);
|
||||||
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
|
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
|
||||||
pub const BART_CNN: (&'static str, &'static str) = ("bart-cnn/vocab.txt", "https://cdn.huggingface.co/roberta-large-vocab.json");
|
pub const BART_CNN: (&'static str, &'static str) = (
|
||||||
|
"bart-cnn/vocab.txt",
|
||||||
|
"https://cdn.huggingface.co/roberta-large-vocab.json",
|
||||||
|
);
|
||||||
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
|
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
|
||||||
pub const BART_XSUM: (&'static str, &'static str) = ("bart-xsum/vocab.txt", "https://cdn.huggingface.co/roberta-large-vocab.json");
|
pub const BART_XSUM: (&'static str, &'static str) = (
|
||||||
|
"bart-xsum/vocab.txt",
|
||||||
|
"https://cdn.huggingface.co/roberta-large-vocab.json",
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
impl BartMergesResources {
|
impl BartMergesResources {
|
||||||
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
|
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
|
||||||
pub const BART: (&'static str, &'static str) = ("bart/merges.txt", "https://cdn.huggingface.co/roberta-large-merges.txt");
|
pub const BART: (&'static str, &'static str) = (
|
||||||
|
"bart/merges.txt",
|
||||||
|
"https://cdn.huggingface.co/roberta-large-merges.txt",
|
||||||
|
);
|
||||||
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
|
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
|
||||||
pub const BART_CNN: (&'static str, &'static str) = ("bart-cnn/merges.txt", "https://cdn.huggingface.co/roberta-large-merges.txt");
|
pub const BART_CNN: (&'static str, &'static str) = (
|
||||||
|
"bart-cnn/merges.txt",
|
||||||
|
"https://cdn.huggingface.co/roberta-large-merges.txt",
|
||||||
|
);
|
||||||
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
|
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
|
||||||
pub const BART_XSUM: (&'static str, &'static str) = ("bart-xsum/merges.txt", "https://cdn.huggingface.co/roberta-large-merges.txt");
|
pub const BART_XSUM: (&'static str, &'static str) = (
|
||||||
|
"bart-xsum/merges.txt",
|
||||||
|
"https://cdn.huggingface.co/roberta-large-merges.txt",
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(non_camel_case_types)]
|
#[allow(non_camel_case_types)]
|
||||||
@ -130,14 +166,15 @@ pub struct BartConfig {
|
|||||||
|
|
||||||
impl Config<BartConfig> for BartConfig {}
|
impl Config<BartConfig> for BartConfig {}
|
||||||
|
|
||||||
fn _prepare_bart_decoder_inputs(pad_token_id: i64,
|
fn _prepare_bart_decoder_inputs(
|
||||||
input_ids: &Tensor,
|
pad_token_id: i64,
|
||||||
decoder_input_ids: Option<&Tensor>,
|
input_ids: &Tensor,
|
||||||
decoder_padding_mask: Option<&Tensor>)
|
decoder_input_ids: Option<&Tensor>,
|
||||||
-> (Tensor, Option<Tensor>, Option<Tensor>) {
|
decoder_padding_mask: Option<&Tensor>,
|
||||||
|
) -> (Tensor, Option<Tensor>, Option<Tensor>) {
|
||||||
let decoder_input_ids = match decoder_input_ids {
|
let decoder_input_ids = match decoder_input_ids {
|
||||||
Some(value) => value.copy(),
|
Some(value) => value.copy(),
|
||||||
None => _shift_tokens_right(input_ids, pad_token_id)
|
None => _shift_tokens_right(input_ids, pad_token_id),
|
||||||
};
|
};
|
||||||
|
|
||||||
let decoder_padding_mask = match decoder_padding_mask {
|
let decoder_padding_mask = match decoder_padding_mask {
|
||||||
@ -159,7 +196,6 @@ fn _prepare_bart_decoder_inputs(pad_token_id: i64,
|
|||||||
(decoder_input_ids, decoder_padding_mask, Some(causal_mask))
|
(decoder_input_ids, decoder_padding_mask, Some(causal_mask))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
fn _shift_tokens_right(input_ids: &Tensor, pad_token_id: i64) -> Tensor {
|
fn _shift_tokens_right(input_ids: &Tensor, pad_token_id: i64) -> Tensor {
|
||||||
let index_eos: Tensor = input_ids.ne(pad_token_id).sum1(&[-1], true, Int64) - 1;
|
let index_eos: Tensor = input_ids.ne(pad_token_id).sum1(&[-1], true, Int64) - 1;
|
||||||
let output = input_ids.empty_like().to_kind(Int64);
|
let output = input_ids.empty_like().to_kind(Int64);
|
||||||
@ -200,10 +236,10 @@ impl BartModel {
|
|||||||
/// # Example
|
/// # Example
|
||||||
///
|
///
|
||||||
/// ```no_run
|
/// ```no_run
|
||||||
/// use tch::{nn, Device};
|
/// use rust_bert::bart::{BartConfig, BartModel};
|
||||||
/// use rust_bert::Config;
|
/// use rust_bert::Config;
|
||||||
/// use std::path::Path;
|
/// use std::path::Path;
|
||||||
/// use rust_bert::bart::{BartConfig, BartModel};
|
/// use tch::{nn, Device};
|
||||||
///
|
///
|
||||||
/// let config_path = Path::new("path/to/config.json");
|
/// let config_path = Path::new("path/to/config.json");
|
||||||
/// let device = Device::Cpu;
|
/// let device = Device::Cpu;
|
||||||
@ -212,22 +248,32 @@ impl BartModel {
|
|||||||
/// let generation_mode = true;
|
/// let generation_mode = true;
|
||||||
/// let bart: BartModel = BartModel::new(&(&p.root() / "bart"), &config, generation_mode);
|
/// let bart: BartModel = BartModel::new(&(&p.root() / "bart"), &config, generation_mode);
|
||||||
/// ```
|
/// ```
|
||||||
///
|
|
||||||
pub fn new(p: &nn::Path, config: &BartConfig, generation_mode: bool) -> BartModel {
|
pub fn new(p: &nn::Path, config: &BartConfig, generation_mode: bool) -> BartModel {
|
||||||
let pad_token_id = match config.pad_token_id {
|
let pad_token_id = match config.pad_token_id {
|
||||||
Some(value) => value,
|
Some(value) => value,
|
||||||
None => 1
|
None => 1,
|
||||||
};
|
};
|
||||||
let embedding_config = EmbeddingConfig { padding_idx: pad_token_id, ..Default::default() };
|
let embedding_config = EmbeddingConfig {
|
||||||
let embeddings: nn::Embedding = embedding(p / "shared",
|
padding_idx: pad_token_id,
|
||||||
config.vocab_size,
|
..Default::default()
|
||||||
config.d_model,
|
};
|
||||||
embedding_config);
|
let embeddings: nn::Embedding = embedding(
|
||||||
|
p / "shared",
|
||||||
|
config.vocab_size,
|
||||||
|
config.d_model,
|
||||||
|
embedding_config,
|
||||||
|
);
|
||||||
|
|
||||||
let encoder = BartEncoder::new(p / "encoder", config);
|
let encoder = BartEncoder::new(p / "encoder", config);
|
||||||
let decoder = BartDecoder::new(p / "decoder", config, generation_mode);
|
let decoder = BartDecoder::new(p / "decoder", config, generation_mode);
|
||||||
|
|
||||||
BartModel { encoder, decoder, generation_mode, pad_token_id, embeddings }
|
BartModel {
|
||||||
|
encoder,
|
||||||
|
decoder,
|
||||||
|
generation_mode,
|
||||||
|
pad_token_id,
|
||||||
|
embeddings,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Forward pass through the model
|
/// Forward pass through the model
|
||||||
@ -257,82 +303,116 @@ impl BartModel {
|
|||||||
/// # Example
|
/// # Example
|
||||||
///
|
///
|
||||||
/// ```no_run
|
/// ```no_run
|
||||||
///# use tch::{nn, Device, Tensor, no_grad};
|
/// # use tch::{nn, Device, Tensor, no_grad};
|
||||||
///# use rust_bert::Config;
|
/// # use rust_bert::Config;
|
||||||
///# use std::path::Path;
|
/// # use std::path::Path;
|
||||||
///# use tch::kind::Kind::{Int64, Double};
|
/// # use tch::kind::Kind::{Int64, Double};
|
||||||
/// use rust_bert::bart::{BartConfig, BartModel};
|
/// use rust_bert::bart::{BartConfig, BartModel};
|
||||||
///# let config_path = Path::new("path/to/config.json");
|
/// # let config_path = Path::new("path/to/config.json");
|
||||||
///# let vocab_path = Path::new("path/to/vocab.txt");
|
/// # let vocab_path = Path::new("path/to/vocab.txt");
|
||||||
///# let device = Device::Cpu;
|
/// # let device = Device::Cpu;
|
||||||
///# let vs = nn::VarStore::new(device);
|
/// # let vs = nn::VarStore::new(device);
|
||||||
///# let config = BartConfig::from_file(config_path);
|
/// # let config = BartConfig::from_file(config_path);
|
||||||
///# let bart_model: BartModel = BartModel::new(&vs.root(), &config, false);
|
/// # let bart_model: BartModel = BartModel::new(&vs.root(), &config, false);
|
||||||
/// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56);
|
/// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56);
|
||||||
/// let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device));
|
/// let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device));
|
||||||
/// let target_tensor = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device));
|
/// let target_tensor = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device));
|
||||||
/// let encoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
|
/// let encoder_attention_mask =
|
||||||
/// let decoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
|
/// Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
|
||||||
///
|
/// let decoder_attention_mask =
|
||||||
/// let (decoder_output, encoder_hidden_states, decoder_cache,
|
/// Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
|
||||||
/// all_encoder_hidden_states, all_encoder_attentions,
|
|
||||||
/// all_decoder_hidden_states, all_decoder_attentions) = no_grad(|| {
|
|
||||||
/// bart_model
|
|
||||||
/// .forward_t(Some(&input_tensor),
|
|
||||||
/// Some(&encoder_attention_mask),
|
|
||||||
/// Some(&target_tensor),
|
|
||||||
/// None,
|
|
||||||
/// Some(&decoder_attention_mask),
|
|
||||||
/// None,
|
|
||||||
/// false)
|
|
||||||
/// });
|
|
||||||
///
|
///
|
||||||
|
/// let (
|
||||||
|
/// decoder_output,
|
||||||
|
/// encoder_hidden_states,
|
||||||
|
/// decoder_cache,
|
||||||
|
/// all_encoder_hidden_states,
|
||||||
|
/// all_encoder_attentions,
|
||||||
|
/// all_decoder_hidden_states,
|
||||||
|
/// all_decoder_attentions,
|
||||||
|
/// ) = no_grad(|| {
|
||||||
|
/// bart_model.forward_t(
|
||||||
|
/// Some(&input_tensor),
|
||||||
|
/// Some(&encoder_attention_mask),
|
||||||
|
/// Some(&target_tensor),
|
||||||
|
/// None,
|
||||||
|
/// Some(&decoder_attention_mask),
|
||||||
|
/// None,
|
||||||
|
/// false,
|
||||||
|
/// )
|
||||||
|
/// });
|
||||||
/// ```
|
/// ```
|
||||||
///
|
pub fn forward_t(
|
||||||
pub fn forward_t(&self,
|
&self,
|
||||||
input_ids: Option<&Tensor>,
|
input_ids: Option<&Tensor>,
|
||||||
attention_mask: Option<&Tensor>,
|
attention_mask: Option<&Tensor>,
|
||||||
decoder_input_ids: Option<&Tensor>,
|
decoder_input_ids: Option<&Tensor>,
|
||||||
encoder_outputs: Option<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>)>,
|
encoder_outputs: Option<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>)>,
|
||||||
decoder_attention_mask: Option<&Tensor>,
|
decoder_attention_mask: Option<&Tensor>,
|
||||||
layer_states: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
|
layer_states: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
|
||||||
train: bool) ->
|
train: bool,
|
||||||
(Tensor, Tensor, Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
|
) -> (
|
||||||
Option<Vec<Tensor>>, Option<Vec<Tensor>>,
|
Tensor,
|
||||||
Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
|
Tensor,
|
||||||
|
Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
|
||||||
|
Option<Vec<Tensor>>,
|
||||||
|
Option<Vec<Tensor>>,
|
||||||
|
Option<Vec<Tensor>>,
|
||||||
|
Option<Vec<Tensor>>,
|
||||||
|
) {
|
||||||
let (decoder_input_ids, decoder_padding_mask, causal_mask) = if self.generation_mode {
|
let (decoder_input_ids, decoder_padding_mask, causal_mask) = if self.generation_mode {
|
||||||
(decoder_input_ids.unwrap().copy(), None, None)
|
(decoder_input_ids.unwrap().copy(), None, None)
|
||||||
} else {
|
} else {
|
||||||
assert!(input_ids.is_some(), "input_ids must be provided when not in generation mode");
|
assert!(
|
||||||
_prepare_bart_decoder_inputs(self.pad_token_id, input_ids.unwrap(), decoder_input_ids, decoder_attention_mask)
|
input_ids.is_some(),
|
||||||
|
"input_ids must be provided when not in generation mode"
|
||||||
|
);
|
||||||
|
_prepare_bart_decoder_inputs(
|
||||||
|
self.pad_token_id,
|
||||||
|
input_ids.unwrap(),
|
||||||
|
decoder_input_ids,
|
||||||
|
decoder_attention_mask,
|
||||||
|
)
|
||||||
};
|
};
|
||||||
|
|
||||||
let (encoder_hidden_states,
|
let (encoder_hidden_states, all_encoder_hidden_states, all_encoder_attentions) =
|
||||||
all_encoder_hidden_states,
|
match encoder_outputs {
|
||||||
all_encoder_attentions) = match encoder_outputs {
|
Some(value) => value,
|
||||||
Some(value) => value,
|
None => {
|
||||||
None => {
|
assert!(
|
||||||
assert!(input_ids.is_some(), "input_ids must be provided when encoder output is not pre-computed");
|
input_ids.is_some(),
|
||||||
self.encoder.forward_t(input_ids.unwrap(), attention_mask, &self.embeddings, train)
|
"input_ids must be provided when encoder output is not pre-computed"
|
||||||
}
|
);
|
||||||
};
|
self.encoder.forward_t(
|
||||||
|
input_ids.unwrap(),
|
||||||
|
attention_mask,
|
||||||
|
&self.embeddings,
|
||||||
|
train,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
let (decoder_outputs,
|
let (decoder_outputs, decoder_cache, all_decoder_hidden_states, all_decoder_attentions) =
|
||||||
decoder_cache,
|
self.decoder.forward_t(
|
||||||
|
&decoder_input_ids,
|
||||||
|
&encoder_hidden_states,
|
||||||
|
attention_mask,
|
||||||
|
decoder_padding_mask.as_ref(),
|
||||||
|
causal_mask.as_ref(),
|
||||||
|
&self.embeddings,
|
||||||
|
layer_states,
|
||||||
|
train,
|
||||||
|
);
|
||||||
|
(
|
||||||
|
decoder_outputs,
|
||||||
|
encoder_hidden_states,
|
||||||
|
decoder_cache.1,
|
||||||
all_decoder_hidden_states,
|
all_decoder_hidden_states,
|
||||||
all_decoder_attentions) = self.decoder.forward_t(&decoder_input_ids,
|
all_decoder_attentions,
|
||||||
&encoder_hidden_states,
|
all_encoder_hidden_states,
|
||||||
attention_mask,
|
all_encoder_attentions,
|
||||||
decoder_padding_mask.as_ref(),
|
)
|
||||||
causal_mask.as_ref(),
|
|
||||||
&self.embeddings,
|
|
||||||
layer_states,
|
|
||||||
train);
|
|
||||||
(decoder_outputs, encoder_hidden_states, decoder_cache.1,
|
|
||||||
all_decoder_hidden_states, all_decoder_attentions,
|
|
||||||
all_encoder_hidden_states, all_encoder_attentions)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// # BART Model for conditional generation
|
/// # BART Model for conditional generation
|
||||||
@ -356,20 +436,24 @@ impl BartForConditionalGeneration {
|
|||||||
/// # Example
|
/// # Example
|
||||||
///
|
///
|
||||||
/// ```no_run
|
/// ```no_run
|
||||||
/// use tch::{nn, Device};
|
/// use rust_bert::bart::{BartConfig, BartForConditionalGeneration};
|
||||||
/// use rust_bert::Config;
|
/// use rust_bert::Config;
|
||||||
/// use std::path::Path;
|
/// use std::path::Path;
|
||||||
/// use rust_bert::bart::{BartConfig, BartForConditionalGeneration};
|
/// use tch::{nn, Device};
|
||||||
///
|
///
|
||||||
/// let config_path = Path::new("path/to/config.json");
|
/// let config_path = Path::new("path/to/config.json");
|
||||||
/// let device = Device::Cpu;
|
/// let device = Device::Cpu;
|
||||||
/// let p = nn::VarStore::new(device);
|
/// let p = nn::VarStore::new(device);
|
||||||
/// let config = BartConfig::from_file(config_path);
|
/// let config = BartConfig::from_file(config_path);
|
||||||
/// let generation_mode = true;
|
/// let generation_mode = true;
|
||||||
/// let bart: BartForConditionalGeneration = BartForConditionalGeneration::new(&(&p.root() / "bart"), &config, generation_mode);
|
/// let bart: BartForConditionalGeneration =
|
||||||
|
/// BartForConditionalGeneration::new(&(&p.root() / "bart"), &config, generation_mode);
|
||||||
/// ```
|
/// ```
|
||||||
///
|
pub fn new(
|
||||||
pub fn new(p: &nn::Path, config: &BartConfig, generation_mode: bool) -> BartForConditionalGeneration {
|
p: &nn::Path,
|
||||||
|
config: &BartConfig,
|
||||||
|
generation_mode: bool,
|
||||||
|
) -> BartForConditionalGeneration {
|
||||||
let base_model = BartModel::new(&(p / "model"), config, generation_mode);
|
let base_model = BartModel::new(&(p / "model"), config, generation_mode);
|
||||||
BartForConditionalGeneration { base_model }
|
BartForConditionalGeneration { base_model }
|
||||||
}
|
}
|
||||||
@ -398,17 +482,17 @@ impl BartForConditionalGeneration {
|
|||||||
/// # Example
|
/// # Example
|
||||||
///
|
///
|
||||||
/// ```no_run
|
/// ```no_run
|
||||||
///# use tch::{nn, Device, Tensor, no_grad};
|
/// # use tch::{nn, Device, Tensor, no_grad};
|
||||||
///# use rust_bert::Config;
|
/// # use rust_bert::Config;
|
||||||
///# use std::path::Path;
|
/// # use std::path::Path;
|
||||||
///# use tch::kind::Kind::{Int64, Double};
|
/// # use tch::kind::Kind::{Int64, Double};
|
||||||
/// use rust_bert::bart::{BartConfig, BartForConditionalGeneration};
|
/// use rust_bert::bart::{BartConfig, BartForConditionalGeneration};
|
||||||
///# let config_path = Path::new("path/to/config.json");
|
/// # let config_path = Path::new("path/to/config.json");
|
||||||
///# let vocab_path = Path::new("path/to/vocab.txt");
|
/// # let vocab_path = Path::new("path/to/vocab.txt");
|
||||||
///# let device = Device::Cpu;
|
/// # let device = Device::Cpu;
|
||||||
///# let vs = nn::VarStore::new(device);
|
/// # let vs = nn::VarStore::new(device);
|
||||||
///# let config = BartConfig::from_file(config_path);
|
/// # let config = BartConfig::from_file(config_path);
|
||||||
///# let bart_model: BartForConditionalGeneration = BartForConditionalGeneration::new(&vs.root(), &config, false);
|
/// # let bart_model: BartForConditionalGeneration = BartForConditionalGeneration::new(&vs.root(), &config, false);
|
||||||
/// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56);
|
/// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56);
|
||||||
/// let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device));
|
/// let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device));
|
||||||
/// let target_tensor = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device));
|
/// let target_tensor = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device));
|
||||||
@ -427,37 +511,64 @@ impl BartForConditionalGeneration {
|
|||||||
/// None,
|
/// None,
|
||||||
/// false)
|
/// false)
|
||||||
/// });
|
/// });
|
||||||
///
|
|
||||||
/// ```
|
/// ```
|
||||||
///
|
pub fn forward_t(
|
||||||
pub fn forward_t(&self,
|
&self,
|
||||||
input_ids: Option<&Tensor>,
|
input_ids: Option<&Tensor>,
|
||||||
attention_mask: Option<&Tensor>,
|
attention_mask: Option<&Tensor>,
|
||||||
encoder_outputs: Option<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>)>,
|
encoder_outputs: Option<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>)>,
|
||||||
decoder_input_ids: Option<&Tensor>,
|
decoder_input_ids: Option<&Tensor>,
|
||||||
decoder_attention_mask: Option<&Tensor>,
|
decoder_attention_mask: Option<&Tensor>,
|
||||||
old_layer_states: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
|
old_layer_states: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
|
||||||
train: bool)
|
train: bool,
|
||||||
-> (Tensor, Tensor, Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
|
) -> (
|
||||||
Option<Vec<Tensor>>, Option<Vec<Tensor>>,
|
Tensor,
|
||||||
Option<Vec<Tensor>>, Option<Vec<Tensor>>)
|
Tensor,
|
||||||
{
|
Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
|
||||||
let (decoder_outputs, encoder_hidden_states, decoder_cache,
|
Option<Vec<Tensor>>,
|
||||||
all_decoder_hidden_states, all_decoder_attentions,
|
Option<Vec<Tensor>>,
|
||||||
all_encoder_hidden_states, all_encoder_attentions) =
|
Option<Vec<Tensor>>,
|
||||||
self.base_model.forward_t(input_ids, attention_mask, decoder_input_ids, encoder_outputs, decoder_attention_mask, old_layer_states, train);
|
Option<Vec<Tensor>>,
|
||||||
|
) {
|
||||||
|
let (
|
||||||
|
decoder_outputs,
|
||||||
|
encoder_hidden_states,
|
||||||
|
decoder_cache,
|
||||||
|
all_decoder_hidden_states,
|
||||||
|
all_decoder_attentions,
|
||||||
|
all_encoder_hidden_states,
|
||||||
|
all_encoder_attentions,
|
||||||
|
) = self.base_model.forward_t(
|
||||||
|
input_ids,
|
||||||
|
attention_mask,
|
||||||
|
decoder_input_ids,
|
||||||
|
encoder_outputs,
|
||||||
|
decoder_attention_mask,
|
||||||
|
old_layer_states,
|
||||||
|
train,
|
||||||
|
);
|
||||||
|
|
||||||
let lm_logits = decoder_outputs.linear::<Tensor>(&self.base_model.embeddings.ws, None);
|
let lm_logits = decoder_outputs.linear::<Tensor>(&self.base_model.embeddings.ws, None);
|
||||||
(lm_logits, encoder_hidden_states, decoder_cache,
|
(
|
||||||
all_decoder_hidden_states, all_decoder_attentions,
|
lm_logits,
|
||||||
all_encoder_hidden_states, all_encoder_attentions)
|
encoder_hidden_states,
|
||||||
|
decoder_cache,
|
||||||
|
all_decoder_hidden_states,
|
||||||
|
all_decoder_attentions,
|
||||||
|
all_encoder_hidden_states,
|
||||||
|
all_encoder_attentions,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn encode(&self, input_ids: &Tensor, attention_mask: Option<&Tensor>) -> Tensor {
|
pub fn encode(&self, input_ids: &Tensor, attention_mask: Option<&Tensor>) -> Tensor {
|
||||||
let (encoder_hidden_states, _, _) = self.base_model.encoder.forward_t(input_ids, attention_mask, &self.base_model.embeddings, false);
|
let (encoder_hidden_states, _, _) = self.base_model.encoder.forward_t(
|
||||||
|
input_ids,
|
||||||
|
attention_mask,
|
||||||
|
&self.base_model.embeddings,
|
||||||
|
false,
|
||||||
|
);
|
||||||
encoder_hidden_states
|
encoder_hidden_states
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct BartClassificationHead {
|
pub struct BartClassificationHead {
|
||||||
@ -468,16 +579,29 @@ pub struct BartClassificationHead {
|
|||||||
|
|
||||||
impl BartClassificationHead {
|
impl BartClassificationHead {
|
||||||
pub fn new(p: &nn::Path, config: &BartConfig) -> BartClassificationHead {
|
pub fn new(p: &nn::Path, config: &BartConfig) -> BartClassificationHead {
|
||||||
let dense = nn::linear(&(p / "dense"), config.d_model, config.d_model, Default::default());
|
let dense = nn::linear(
|
||||||
|
&(p / "dense"),
|
||||||
|
config.d_model,
|
||||||
|
config.d_model,
|
||||||
|
Default::default(),
|
||||||
|
);
|
||||||
let dropout = Dropout::new(config.classif_dropout);
|
let dropout = Dropout::new(config.classif_dropout);
|
||||||
let out_proj = nn::linear(&(p / "out_proj"), config.d_model, config.num_labels.unwrap(), Default::default());
|
let out_proj = nn::linear(
|
||||||
|
&(p / "out_proj"),
|
||||||
|
config.d_model,
|
||||||
|
config.num_labels.unwrap(),
|
||||||
|
Default::default(),
|
||||||
|
);
|
||||||
|
|
||||||
BartClassificationHead { dense, dropout, out_proj }
|
BartClassificationHead {
|
||||||
|
dense,
|
||||||
|
dropout,
|
||||||
|
out_proj,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn forward_t(&self, x: &Tensor, train: bool) -> Tensor {
|
pub fn forward_t(&self, x: &Tensor, train: bool) -> Tensor {
|
||||||
x
|
x.apply_t(&self.dropout, train)
|
||||||
.apply_t(&self.dropout, train)
|
|
||||||
.apply(&self.dense)
|
.apply(&self.dense)
|
||||||
.tanh()
|
.tanh()
|
||||||
.apply_t(&self.dropout, train)
|
.apply_t(&self.dropout, train)
|
||||||
@ -497,7 +621,6 @@ pub struct BartForSequenceClassification {
|
|||||||
eos_token_id: i64,
|
eos_token_id: i64,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
impl BartForSequenceClassification {
|
impl BartForSequenceClassification {
|
||||||
/// Build a new `BartForSequenceClassification`
|
/// Build a new `BartForSequenceClassification`
|
||||||
///
|
///
|
||||||
@ -509,27 +632,31 @@ impl BartForSequenceClassification {
|
|||||||
/// # Example
|
/// # Example
|
||||||
///
|
///
|
||||||
/// ```no_run
|
/// ```no_run
|
||||||
/// use tch::{nn, Device};
|
/// use rust_bert::bart::{BartConfig, BartForSequenceClassification};
|
||||||
/// use rust_bert::Config;
|
/// use rust_bert::Config;
|
||||||
/// use std::path::Path;
|
/// use std::path::Path;
|
||||||
/// use rust_bert::bart::{BartConfig, BartForSequenceClassification};
|
/// use tch::{nn, Device};
|
||||||
///
|
///
|
||||||
/// let config_path = Path::new("path/to/config.json");
|
/// let config_path = Path::new("path/to/config.json");
|
||||||
/// let device = Device::Cpu;
|
/// let device = Device::Cpu;
|
||||||
/// let p = nn::VarStore::new(device);
|
/// let p = nn::VarStore::new(device);
|
||||||
/// let config = BartConfig::from_file(config_path);
|
/// let config = BartConfig::from_file(config_path);
|
||||||
/// let generation_mode = true;
|
/// let generation_mode = true;
|
||||||
/// let bart: BartForSequenceClassification = BartForSequenceClassification::new(&(&p.root() / "bart"), &config);
|
/// let bart: BartForSequenceClassification =
|
||||||
|
/// BartForSequenceClassification::new(&(&p.root() / "bart"), &config);
|
||||||
/// ```
|
/// ```
|
||||||
///
|
|
||||||
pub fn new(p: &nn::Path, config: &BartConfig) -> BartForSequenceClassification {
|
pub fn new(p: &nn::Path, config: &BartConfig) -> BartForSequenceClassification {
|
||||||
let base_model = BartModel::new(&(p / "model"), config, false);
|
let base_model = BartModel::new(&(p / "model"), config, false);
|
||||||
let classification_head = BartClassificationHead::new(&(p / "classification_head"), config);
|
let classification_head = BartClassificationHead::new(&(p / "classification_head"), config);
|
||||||
let eos_token_id = match config.eos_token_id {
|
let eos_token_id = match config.eos_token_id {
|
||||||
Some(value) => value,
|
Some(value) => value,
|
||||||
None => 3
|
None => 3,
|
||||||
};
|
};
|
||||||
BartForSequenceClassification { base_model, classification_head, eos_token_id }
|
BartForSequenceClassification {
|
||||||
|
base_model,
|
||||||
|
classification_head,
|
||||||
|
eos_token_id,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Forward pass through the model
|
/// Forward pass through the model
|
||||||
@ -556,17 +683,17 @@ impl BartForSequenceClassification {
|
|||||||
/// # Example
|
/// # Example
|
||||||
///
|
///
|
||||||
/// ```no_run
|
/// ```no_run
|
||||||
///# use tch::{nn, Device, Tensor, no_grad};
|
/// # use tch::{nn, Device, Tensor, no_grad};
|
||||||
///# use rust_bert::Config;
|
/// # use rust_bert::Config;
|
||||||
///# use std::path::Path;
|
/// # use std::path::Path;
|
||||||
///# use tch::kind::Kind::{Int64, Double};
|
/// # use tch::kind::Kind::{Int64, Double};
|
||||||
/// use rust_bert::bart::{BartConfig, BartForConditionalGeneration};
|
/// use rust_bert::bart::{BartConfig, BartForConditionalGeneration};
|
||||||
///# let config_path = Path::new("path/to/config.json");
|
/// # let config_path = Path::new("path/to/config.json");
|
||||||
///# let vocab_path = Path::new("path/to/vocab.txt");
|
/// # let vocab_path = Path::new("path/to/vocab.txt");
|
||||||
///# let device = Device::Cpu;
|
/// # let device = Device::Cpu;
|
||||||
///# let vs = nn::VarStore::new(device);
|
/// # let vs = nn::VarStore::new(device);
|
||||||
///# let config = BartConfig::from_file(config_path);
|
/// # let config = BartConfig::from_file(config_path);
|
||||||
///# let bart_model: BartForConditionalGeneration = BartForConditionalGeneration::new(&vs.root(), &config, false);
|
/// # let bart_model: BartForConditionalGeneration = BartForConditionalGeneration::new(&vs.root(), &config, false);
|
||||||
/// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56);
|
/// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56);
|
||||||
/// let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device));
|
/// let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device));
|
||||||
/// let target_tensor = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device));
|
/// let target_tensor = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device));
|
||||||
@ -585,36 +712,63 @@ impl BartForSequenceClassification {
|
|||||||
/// None,
|
/// None,
|
||||||
/// false)
|
/// false)
|
||||||
/// });
|
/// });
|
||||||
///
|
|
||||||
/// ```
|
/// ```
|
||||||
///
|
pub fn forward_t(
|
||||||
pub fn forward_t(&mut self,
|
&mut self,
|
||||||
input_ids: &Tensor,
|
input_ids: &Tensor,
|
||||||
attention_mask: Option<&Tensor>,
|
attention_mask: Option<&Tensor>,
|
||||||
encoder_outputs: Option<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>)>,
|
encoder_outputs: Option<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>)>,
|
||||||
decoder_input_ids: Option<&Tensor>,
|
decoder_input_ids: Option<&Tensor>,
|
||||||
decoder_attention_mask: Option<&Tensor>,
|
decoder_attention_mask: Option<&Tensor>,
|
||||||
train: bool)
|
train: bool,
|
||||||
-> (Tensor, Tensor,
|
) -> (
|
||||||
Option<Vec<Tensor>>, Option<Vec<Tensor>>,
|
Tensor,
|
||||||
Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
|
Tensor,
|
||||||
let (decoder_outputs, encoder_hidden_states, _,
|
Option<Vec<Tensor>>,
|
||||||
all_decoder_hidden_states, all_decoder_attentions,
|
Option<Vec<Tensor>>,
|
||||||
all_encoder_hidden_states, all_encoder_attentions) =
|
Option<Vec<Tensor>>,
|
||||||
self.borrow_mut().base_model.forward_t(Some(input_ids), attention_mask, decoder_input_ids, encoder_outputs, decoder_attention_mask, None, train);
|
Option<Vec<Tensor>>,
|
||||||
|
) {
|
||||||
|
let (
|
||||||
|
decoder_outputs,
|
||||||
|
encoder_hidden_states,
|
||||||
|
_,
|
||||||
|
all_decoder_hidden_states,
|
||||||
|
all_decoder_attentions,
|
||||||
|
all_encoder_hidden_states,
|
||||||
|
all_encoder_attentions,
|
||||||
|
) = self.borrow_mut().base_model.forward_t(
|
||||||
|
Some(input_ids),
|
||||||
|
attention_mask,
|
||||||
|
decoder_input_ids,
|
||||||
|
encoder_outputs,
|
||||||
|
decoder_attention_mask,
|
||||||
|
None,
|
||||||
|
train,
|
||||||
|
);
|
||||||
|
|
||||||
let eos_mask = input_ids.eq(self.eos_token_id);
|
let eos_mask = input_ids.eq(self.eos_token_id);
|
||||||
let sentence_representation = decoder_outputs
|
let sentence_representation = decoder_outputs
|
||||||
.index_select(0, &eos_mask)
|
.index_select(0, &eos_mask)
|
||||||
.view((decoder_outputs.size()[0], -1, *decoder_outputs.size().last().unwrap()))
|
.view((
|
||||||
|
decoder_outputs.size()[0],
|
||||||
|
-1,
|
||||||
|
*decoder_outputs.size().last().unwrap(),
|
||||||
|
))
|
||||||
.select(1, -1);
|
.select(1, -1);
|
||||||
|
|
||||||
let logits = self.classification_head.forward_t(&sentence_representation, train);
|
let logits = self
|
||||||
(logits, encoder_hidden_states,
|
.classification_head
|
||||||
all_decoder_hidden_states, all_decoder_attentions,
|
.forward_t(&sentence_representation, train);
|
||||||
all_encoder_hidden_states, all_encoder_attentions)
|
(
|
||||||
|
logits,
|
||||||
|
encoder_hidden_states,
|
||||||
|
all_decoder_hidden_states,
|
||||||
|
all_decoder_attentions,
|
||||||
|
all_encoder_hidden_states,
|
||||||
|
all_encoder_attentions,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl LMHeadModel for BartForConditionalGeneration {
|
impl LMHeadModel for BartForConditionalGeneration {
|
||||||
@ -645,18 +799,18 @@ impl LMHeadModel for BartForConditionalGeneration {
|
|||||||
/// # Example
|
/// # Example
|
||||||
///
|
///
|
||||||
/// ```no_run
|
/// ```no_run
|
||||||
///# use tch::{nn, Device, Tensor, no_grad};
|
/// # use tch::{nn, Device, Tensor, no_grad};
|
||||||
///# use rust_bert::Config;
|
/// # use rust_bert::Config;
|
||||||
///# use std::path::Path;
|
/// # use std::path::Path;
|
||||||
///# use tch::kind::Kind::{Int64, Double};
|
/// # use tch::kind::Kind::{Int64, Double};
|
||||||
/// use rust_bert::pipelines::generation::LMHeadModel;
|
/// use rust_bert::pipelines::generation::LMHeadModel;
|
||||||
/// use rust_bert::bart::{BartForConditionalGeneration, BartConfig};
|
/// use rust_bert::bart::{BartForConditionalGeneration, BartConfig};
|
||||||
///# let config_path = Path::new("path/to/config.json");
|
/// # let config_path = Path::new("path/to/config.json");
|
||||||
///# let vocab_path = Path::new("path/to/vocab.txt");
|
/// # let vocab_path = Path::new("path/to/vocab.txt");
|
||||||
///# let device = Device::Cpu;
|
/// # let device = Device::Cpu;
|
||||||
///# let vs = nn::VarStore::new(device);
|
/// # let vs = nn::VarStore::new(device);
|
||||||
///# let config = BartConfig::from_file(config_path);
|
/// # let config = BartConfig::from_file(config_path);
|
||||||
///# let bart_model: BartForConditionalGeneration = BartForConditionalGeneration::new(&vs.root(), &config, false);
|
/// # let bart_model: BartForConditionalGeneration = BartForConditionalGeneration::new(&vs.root(), &config, false);
|
||||||
/// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56);
|
/// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56);
|
||||||
/// let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device));
|
/// let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device));
|
||||||
/// let target_tensor = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device));
|
/// let target_tensor = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device));
|
||||||
@ -675,39 +829,58 @@ impl LMHeadModel for BartForConditionalGeneration {
|
|||||||
/// None,
|
/// None,
|
||||||
/// false)
|
/// false)
|
||||||
/// });
|
/// });
|
||||||
///
|
|
||||||
/// ```
|
/// ```
|
||||||
///
|
fn forward_t(
|
||||||
fn forward_t(&self,
|
&self,
|
||||||
input_ids: &Option<Tensor>,
|
input_ids: &Option<Tensor>,
|
||||||
cache: Cache,
|
cache: Cache,
|
||||||
attention_mask: &Option<Tensor>,
|
attention_mask: &Option<Tensor>,
|
||||||
_token_type_ids: &Option<Tensor>,
|
_token_type_ids: &Option<Tensor>,
|
||||||
_position_ids: &Option<Tensor>,
|
_position_ids: &Option<Tensor>,
|
||||||
_input_embeds: &Option<Tensor>,
|
_input_embeds: &Option<Tensor>,
|
||||||
encoder_outputs: Option<&Tensor>,
|
encoder_outputs: Option<&Tensor>,
|
||||||
decoder_input_ids: &Option<Tensor>,
|
decoder_input_ids: &Option<Tensor>,
|
||||||
train: bool) -> Result<(Tensor, Option<Tensor>, Cache, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
|
train: bool,
|
||||||
|
) -> Result<
|
||||||
|
(
|
||||||
|
Tensor,
|
||||||
|
Option<Tensor>,
|
||||||
|
Cache,
|
||||||
|
Option<Vec<Tensor>>,
|
||||||
|
Option<Vec<Tensor>>,
|
||||||
|
),
|
||||||
|
&'static str,
|
||||||
|
> {
|
||||||
let (decoder_output, encoder_hidden_states, new_cache, _, _, _, _) = match cache {
|
let (decoder_output, encoder_hidden_states, new_cache, _, _, _, _) = match cache {
|
||||||
Cache::BARTCache(cached_layer_states) => self.base_model.forward_t(input_ids.as_ref(),
|
Cache::BARTCache(cached_layer_states) => self.base_model.forward_t(
|
||||||
attention_mask.as_ref(),
|
input_ids.as_ref(),
|
||||||
decoder_input_ids.as_ref(),
|
attention_mask.as_ref(),
|
||||||
Some((encoder_outputs.as_ref().unwrap().copy(), None, None)),
|
decoder_input_ids.as_ref(),
|
||||||
None,
|
Some((encoder_outputs.as_ref().unwrap().copy(), None, None)),
|
||||||
cached_layer_states,
|
None,
|
||||||
train),
|
cached_layer_states,
|
||||||
|
train,
|
||||||
|
),
|
||||||
|
|
||||||
Cache::None => self.base_model.forward_t(input_ids.as_ref(),
|
Cache::None => self.base_model.forward_t(
|
||||||
attention_mask.as_ref(),
|
input_ids.as_ref(),
|
||||||
decoder_input_ids.as_ref(),
|
attention_mask.as_ref(),
|
||||||
Some((encoder_outputs.as_ref().unwrap().copy(), None, None)),
|
decoder_input_ids.as_ref(),
|
||||||
None,
|
Some((encoder_outputs.as_ref().unwrap().copy(), None, None)),
|
||||||
None,
|
None,
|
||||||
train),
|
None,
|
||||||
_ => Err("Cache not compatible with BART Model")?
|
train,
|
||||||
|
),
|
||||||
|
_ => Err("Cache not compatible with BART Model")?,
|
||||||
};
|
};
|
||||||
|
|
||||||
let lm_logits = decoder_output.linear::<Tensor>(&self.base_model.embeddings.ws, None);
|
let lm_logits = decoder_output.linear::<Tensor>(&self.base_model.embeddings.ws, None);
|
||||||
Ok((lm_logits, Some(encoder_hidden_states), Cache::BARTCache(new_cache), None, None))
|
Ok((
|
||||||
|
lm_logits,
|
||||||
|
Some(encoder_hidden_states),
|
||||||
|
Cache::BARTCache(new_cache),
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -11,15 +11,17 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
use crate::bart::attention::{SelfAttention, LayerState};
|
use crate::bart::attention::{LayerState, SelfAttention};
|
||||||
use tch::{nn, Tensor};
|
|
||||||
use crate::common::dropout::Dropout;
|
|
||||||
use crate::bart::BartConfig;
|
|
||||||
use crate::bart::bart::Activation;
|
use crate::bart::bart::Activation;
|
||||||
use crate::common::activations::{_gelu, _relu, _swish, _gelu_new, _tanh};
|
use crate::bart::embeddings::{
|
||||||
use tch::kind::Kind::Int64;
|
EmbeddingOption, LearnedPositionalEmbedding, SinusoidalPositionalEmbedding,
|
||||||
|
};
|
||||||
|
use crate::bart::BartConfig;
|
||||||
|
use crate::common::activations::{_gelu, _gelu_new, _relu, _swish, _tanh};
|
||||||
|
use crate::common::dropout::Dropout;
|
||||||
use std::borrow::BorrowMut;
|
use std::borrow::BorrowMut;
|
||||||
use crate::bart::embeddings::{EmbeddingOption, LearnedPositionalEmbedding, SinusoidalPositionalEmbedding};
|
use tch::kind::Kind::Int64;
|
||||||
|
use tch::{nn, Tensor};
|
||||||
|
|
||||||
pub struct DecoderLayer {
|
pub struct DecoderLayer {
|
||||||
self_attention: SelfAttention,
|
self_attention: SelfAttention,
|
||||||
@ -36,51 +38,74 @@ pub struct DecoderLayer {
|
|||||||
|
|
||||||
impl DecoderLayer {
|
impl DecoderLayer {
|
||||||
pub fn new(p: nn::Path, config: &BartConfig) -> DecoderLayer {
|
pub fn new(p: nn::Path, config: &BartConfig) -> DecoderLayer {
|
||||||
let layer_norm_config = nn::LayerNormConfig { eps: 1e-5, ..Default::default() };
|
let layer_norm_config = nn::LayerNormConfig {
|
||||||
|
eps: 1e-5,
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
let output_attention = match config.output_attentions {
|
let output_attention = match config.output_attentions {
|
||||||
Some(value) => value,
|
Some(value) => value,
|
||||||
None => false
|
None => false,
|
||||||
};
|
};
|
||||||
let self_attention = SelfAttention::new(&p / "self_attn",
|
let self_attention = SelfAttention::new(
|
||||||
config.d_model,
|
&p / "self_attn",
|
||||||
config.decoder_attention_heads,
|
config.d_model,
|
||||||
config.attention_dropout,
|
config.decoder_attention_heads,
|
||||||
false,
|
config.attention_dropout,
|
||||||
true,
|
false,
|
||||||
output_attention);
|
true,
|
||||||
let encoder_attention = SelfAttention::new(&p / "encoder_attn",
|
output_attention,
|
||||||
config.d_model,
|
);
|
||||||
config.decoder_attention_heads,
|
let encoder_attention = SelfAttention::new(
|
||||||
config.attention_dropout,
|
&p / "encoder_attn",
|
||||||
true,
|
config.d_model,
|
||||||
true,
|
config.decoder_attention_heads,
|
||||||
output_attention);
|
config.attention_dropout,
|
||||||
let self_attention_layer_norm = nn::layer_norm(&p / "self_attn_layer_norm",
|
true,
|
||||||
vec![config.d_model],
|
true,
|
||||||
layer_norm_config);
|
output_attention,
|
||||||
let encoder_attention_layer_norm = nn::layer_norm(&p / "encoder_attn_layer_norm",
|
);
|
||||||
vec![config.d_model],
|
let self_attention_layer_norm = nn::layer_norm(
|
||||||
layer_norm_config);
|
&p / "self_attn_layer_norm",
|
||||||
|
vec![config.d_model],
|
||||||
|
layer_norm_config,
|
||||||
|
);
|
||||||
|
let encoder_attention_layer_norm = nn::layer_norm(
|
||||||
|
&p / "encoder_attn_layer_norm",
|
||||||
|
vec![config.d_model],
|
||||||
|
layer_norm_config,
|
||||||
|
);
|
||||||
|
|
||||||
let dropout = Dropout::new(config.dropout);
|
let dropout = Dropout::new(config.dropout);
|
||||||
let activation_dropout = Dropout::new(config.activation_dropout);
|
let activation_dropout = Dropout::new(config.activation_dropout);
|
||||||
let activation_function = match &config.activation_function {
|
let activation_function = match &config.activation_function {
|
||||||
Some(act_function) => act_function,
|
Some(act_function) => act_function,
|
||||||
None => &Activation::gelu
|
None => &Activation::gelu,
|
||||||
};
|
};
|
||||||
let activation = Box::new(match activation_function {
|
let activation = Box::new(match activation_function {
|
||||||
Activation::gelu => _gelu,
|
Activation::gelu => _gelu,
|
||||||
Activation::relu => _relu,
|
Activation::relu => _relu,
|
||||||
Activation::swish => _swish,
|
Activation::swish => _swish,
|
||||||
Activation::gelu_new => _gelu_new,
|
Activation::gelu_new => _gelu_new,
|
||||||
Activation::tanh => _tanh
|
Activation::tanh => _tanh,
|
||||||
});
|
});
|
||||||
let fc1 = nn::linear(&p / "fc1", config.d_model, config.decoder_ffn_dim, Default::default());
|
let fc1 = nn::linear(
|
||||||
let fc2 = nn::linear(&p / "fc2", config.decoder_ffn_dim, config.d_model, Default::default());
|
&p / "fc1",
|
||||||
|
config.d_model,
|
||||||
|
config.decoder_ffn_dim,
|
||||||
|
Default::default(),
|
||||||
|
);
|
||||||
|
let fc2 = nn::linear(
|
||||||
|
&p / "fc2",
|
||||||
|
config.decoder_ffn_dim,
|
||||||
|
config.d_model,
|
||||||
|
Default::default(),
|
||||||
|
);
|
||||||
|
|
||||||
let final_layer_norm = nn::layer_norm(&p / "final_layer_norm",
|
let final_layer_norm = nn::layer_norm(
|
||||||
vec![config.d_model],
|
&p / "final_layer_norm",
|
||||||
layer_norm_config);
|
vec![config.d_model],
|
||||||
|
layer_norm_config,
|
||||||
|
);
|
||||||
|
|
||||||
DecoderLayer {
|
DecoderLayer {
|
||||||
self_attention,
|
self_attention,
|
||||||
@ -96,18 +121,38 @@ impl DecoderLayer {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn forward_t(&self,
|
pub fn forward_t(
|
||||||
x: &Tensor,
|
&self,
|
||||||
encoder_hidden_states: &Tensor,
|
x: &Tensor,
|
||||||
encoder_attn_mask: Option<&Tensor>,
|
encoder_hidden_states: &Tensor,
|
||||||
causal_mask: Option<&Tensor>,
|
encoder_attn_mask: Option<&Tensor>,
|
||||||
decoder_padding_mask: Option<&Tensor>,
|
causal_mask: Option<&Tensor>,
|
||||||
layer_states: (Option<LayerState>, Option<LayerState>),
|
decoder_padding_mask: Option<&Tensor>,
|
||||||
train: bool) -> (Tensor, Option<Tensor>, (Option<LayerState>, Option<LayerState>)) {
|
layer_states: (Option<LayerState>, Option<LayerState>),
|
||||||
let (output, attention_weights, new_self_layer_states) = self.self_attention.forward_t(x, Some(x), decoder_padding_mask, causal_mask, layer_states.0, train);
|
train: bool,
|
||||||
|
) -> (
|
||||||
|
Tensor,
|
||||||
|
Option<Tensor>,
|
||||||
|
(Option<LayerState>, Option<LayerState>),
|
||||||
|
) {
|
||||||
|
let (output, attention_weights, new_self_layer_states) = self.self_attention.forward_t(
|
||||||
|
x,
|
||||||
|
Some(x),
|
||||||
|
decoder_padding_mask,
|
||||||
|
causal_mask,
|
||||||
|
layer_states.0,
|
||||||
|
train,
|
||||||
|
);
|
||||||
let output: Tensor = output.apply_t(&self.dropout, train) + x;
|
let output: Tensor = output.apply_t(&self.dropout, train) + x;
|
||||||
let output = output.apply(&self.self_attention_layer_norm);
|
let output = output.apply(&self.self_attention_layer_norm);
|
||||||
let (output1, _, new_encoder_layer_states) = self.encoder_attention.forward_t(&output, Some(encoder_hidden_states), encoder_attn_mask, None, layer_states.1, train);
|
let (output1, _, new_encoder_layer_states) = self.encoder_attention.forward_t(
|
||||||
|
&output,
|
||||||
|
Some(encoder_hidden_states),
|
||||||
|
encoder_attn_mask,
|
||||||
|
None,
|
||||||
|
layer_states.1,
|
||||||
|
train,
|
||||||
|
);
|
||||||
let output1: Tensor = output1.apply_t(&self.dropout, train) + output;
|
let output1: Tensor = output1.apply_t(&self.dropout, train) + output;
|
||||||
let output1 = output1.apply(&self.encoder_attention_layer_norm);
|
let output1 = output1.apply(&self.encoder_attention_layer_norm);
|
||||||
let output2 = (self.activation)(&output1.apply(&self.fc1));
|
let output2 = (self.activation)(&output1.apply(&self.fc1));
|
||||||
@ -116,7 +161,11 @@ impl DecoderLayer {
|
|||||||
.apply(&self.fc2)
|
.apply(&self.fc2)
|
||||||
.apply_t(&self.dropout, train);
|
.apply_t(&self.dropout, train);
|
||||||
let output2: Tensor = output2 + output1;
|
let output2: Tensor = output2 + output1;
|
||||||
(output2.apply(&self.final_layer_norm), attention_weights, (new_self_layer_states, new_encoder_layer_states))
|
(
|
||||||
|
output2.apply(&self.final_layer_norm),
|
||||||
|
attention_weights,
|
||||||
|
(new_self_layer_states, new_encoder_layer_states),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -136,61 +185,76 @@ impl BartDecoder {
|
|||||||
pub fn new(p: nn::Path, config: &BartConfig, generation_mode: bool) -> BartDecoder {
|
pub fn new(p: nn::Path, config: &BartConfig, generation_mode: bool) -> BartDecoder {
|
||||||
let output_past = match config.output_past {
|
let output_past = match config.output_past {
|
||||||
Some(value) => value,
|
Some(value) => value,
|
||||||
None => true
|
None => true,
|
||||||
};
|
};
|
||||||
let output_attentions = match config.output_attentions {
|
let output_attentions = match config.output_attentions {
|
||||||
Some(value) => value,
|
Some(value) => value,
|
||||||
None => false
|
None => false,
|
||||||
};
|
};
|
||||||
let output_hidden_states = match config.output_hidden_states {
|
let output_hidden_states = match config.output_hidden_states {
|
||||||
Some(value) => value,
|
Some(value) => value,
|
||||||
None => false
|
None => false,
|
||||||
};
|
};
|
||||||
let normalize_embedding = match config.normalize_embedding {
|
let normalize_embedding = match config.normalize_embedding {
|
||||||
Some(value) => value,
|
Some(value) => value,
|
||||||
None => true
|
None => true,
|
||||||
};
|
};
|
||||||
let static_position_embeddings = match config.static_position_embeddings {
|
let static_position_embeddings = match config.static_position_embeddings {
|
||||||
Some(value) => value,
|
Some(value) => value,
|
||||||
None => false
|
None => false,
|
||||||
};
|
};
|
||||||
let scale_embedding = match config.scale_embedding {
|
let scale_embedding = match config.scale_embedding {
|
||||||
Some(value) => if value { (config.d_model as f64).sqrt() } else { 1.0 },
|
Some(value) => {
|
||||||
None => 1.0
|
if value {
|
||||||
|
(config.d_model as f64).sqrt()
|
||||||
|
} else {
|
||||||
|
1.0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
None => 1.0,
|
||||||
};
|
};
|
||||||
|
|
||||||
let dropout = Dropout::new(config.dropout);
|
let dropout = Dropout::new(config.dropout);
|
||||||
|
|
||||||
let layer_norm_embedding = if normalize_embedding {
|
let layer_norm_embedding = if normalize_embedding {
|
||||||
let layer_norm_config = nn::LayerNormConfig { eps: 1e-5, ..Default::default() };
|
let layer_norm_config = nn::LayerNormConfig {
|
||||||
Some(nn::layer_norm(&p / "layernorm_embedding",
|
eps: 1e-5,
|
||||||
vec![config.d_model],
|
..Default::default()
|
||||||
layer_norm_config))
|
};
|
||||||
|
Some(nn::layer_norm(
|
||||||
|
&p / "layernorm_embedding",
|
||||||
|
vec![config.d_model],
|
||||||
|
layer_norm_config,
|
||||||
|
))
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
};
|
};
|
||||||
|
|
||||||
let pad_token_id = match config.pad_token_id {
|
let pad_token_id = match config.pad_token_id {
|
||||||
Some(value) => value,
|
Some(value) => value,
|
||||||
None => 1
|
None => 1,
|
||||||
};
|
};
|
||||||
|
|
||||||
let embed_positions = if static_position_embeddings {
|
let embed_positions = if static_position_embeddings {
|
||||||
EmbeddingOption::SinusoidalPositionalEmbedding(SinusoidalPositionalEmbedding::new(&p / "embed_positions",
|
EmbeddingOption::SinusoidalPositionalEmbedding(SinusoidalPositionalEmbedding::new(
|
||||||
config.max_position_embeddings,
|
&p / "embed_positions",
|
||||||
config.d_model))
|
config.max_position_embeddings,
|
||||||
|
config.d_model,
|
||||||
|
))
|
||||||
} else {
|
} else {
|
||||||
EmbeddingOption::LearnedPositionalEmbedding(LearnedPositionalEmbedding::new(&p / "embed_positions",
|
EmbeddingOption::LearnedPositionalEmbedding(LearnedPositionalEmbedding::new(
|
||||||
config.max_position_embeddings,
|
&p / "embed_positions",
|
||||||
config.d_model,
|
config.max_position_embeddings,
|
||||||
pad_token_id))
|
config.d_model,
|
||||||
|
pad_token_id,
|
||||||
|
))
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut layers: Vec<DecoderLayer> = vec!();
|
let mut layers: Vec<DecoderLayer> = vec![];
|
||||||
let p_layers = &p / "layers";
|
let p_layers = &p / "layers";
|
||||||
for layer_index in 0..config.decoder_layers {
|
for layer_index in 0..config.decoder_layers {
|
||||||
layers.push(DecoderLayer::new(&p_layers / layer_index, config));
|
layers.push(DecoderLayer::new(&p_layers / layer_index, config));
|
||||||
};
|
}
|
||||||
|
|
||||||
BartDecoder {
|
BartDecoder {
|
||||||
dropout,
|
dropout,
|
||||||
@ -205,44 +269,68 @@ impl BartDecoder {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn forward_t(&self,
|
pub fn forward_t(
|
||||||
input_ids: &Tensor,
|
&self,
|
||||||
encoder_hidden_states: &Tensor,
|
input_ids: &Tensor,
|
||||||
encoder_padding_mask: Option<&Tensor>,
|
encoder_hidden_states: &Tensor,
|
||||||
decoder_padding_mask: Option<&Tensor>,
|
encoder_padding_mask: Option<&Tensor>,
|
||||||
decoder_causal_mask: Option<&Tensor>,
|
decoder_padding_mask: Option<&Tensor>,
|
||||||
embeddings: &nn::Embedding,
|
decoder_causal_mask: Option<&Tensor>,
|
||||||
old_layer_states: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
|
embeddings: &nn::Embedding,
|
||||||
train: bool)
|
old_layer_states: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
|
||||||
-> (Tensor,
|
train: bool,
|
||||||
(Option<Tensor>, Option<Vec<(Option<LayerState>, Option<LayerState>)>>),
|
) -> (
|
||||||
Option<Vec<Tensor>>,
|
Tensor,
|
||||||
Option<Vec<Tensor>>) {
|
(
|
||||||
|
Option<Tensor>,
|
||||||
|
Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
|
||||||
|
),
|
||||||
|
Option<Vec<Tensor>>,
|
||||||
|
Option<Vec<Tensor>>,
|
||||||
|
) {
|
||||||
let encoder_padding_mask = match encoder_padding_mask {
|
let encoder_padding_mask = match encoder_padding_mask {
|
||||||
Some(mask) => Some(mask.eq(0).to_kind(Int64)),
|
Some(mask) => Some(mask.eq(0).to_kind(Int64)),
|
||||||
None => None
|
None => None,
|
||||||
};
|
};
|
||||||
|
|
||||||
let positions = self.embed_positions.forward(input_ids, self.generation_mode);
|
let positions = self
|
||||||
|
.embed_positions
|
||||||
|
.forward(input_ids, self.generation_mode);
|
||||||
let x: Tensor = if self.generation_mode {
|
let x: Tensor = if self.generation_mode {
|
||||||
let end_inputs = input_ids.size()[1];
|
let end_inputs = input_ids.size()[1];
|
||||||
let end_positions = positions.size()[1];
|
let end_positions = positions.size()[1];
|
||||||
input_ids.narrow(1, end_inputs - 1, 1).apply(embeddings) * self.scale_embedding + positions.narrow(1, end_positions - 1, 1)
|
input_ids.narrow(1, end_inputs - 1, 1).apply(embeddings) * self.scale_embedding
|
||||||
|
+ positions.narrow(1, end_positions - 1, 1)
|
||||||
} else {
|
} else {
|
||||||
input_ids.apply(embeddings) * self.scale_embedding + positions
|
input_ids.apply(embeddings) * self.scale_embedding + positions
|
||||||
};
|
};
|
||||||
|
|
||||||
let x = if let Some(layer_norm_embedding) = &self.layer_norm_embedding { x.apply(layer_norm_embedding) } else { x };
|
let x = if let Some(layer_norm_embedding) = &self.layer_norm_embedding {
|
||||||
let mut hidden_state = x
|
x.apply(layer_norm_embedding)
|
||||||
.apply_t(&self.dropout, train)
|
} else {
|
||||||
.transpose(0, 1);
|
x
|
||||||
let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states { Some(Vec::with_capacity(self.layers.len())) } else { None };
|
};
|
||||||
let mut all_attentions: Option<Vec<Tensor>> = if self.output_attentions { Some(Vec::with_capacity(self.layers.len())) } else { None };
|
let mut hidden_state = x.apply_t(&self.dropout, train).transpose(0, 1);
|
||||||
let mut next_decoder_cache: Option<Vec<(Option<LayerState>, Option<LayerState>)>> = if self.output_past {
|
let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states {
|
||||||
if old_layer_states.is_some() { old_layer_states } else { Some(vec!((None, None); self.layers.len())) }
|
Some(Vec::with_capacity(self.layers.len()))
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
};
|
};
|
||||||
|
let mut all_attentions: Option<Vec<Tensor>> = if self.output_attentions {
|
||||||
|
Some(Vec::with_capacity(self.layers.len()))
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
let mut next_decoder_cache: Option<Vec<(Option<LayerState>, Option<LayerState>)>> =
|
||||||
|
if self.output_past {
|
||||||
|
if old_layer_states.is_some() {
|
||||||
|
old_layer_states
|
||||||
|
} else {
|
||||||
|
Some(vec![(None, None); self.layers.len()])
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
let encoder_hidden_states = encoder_hidden_states.transpose(0, 1);
|
let encoder_hidden_states = encoder_hidden_states.transpose(0, 1);
|
||||||
let mut attention_weights: Option<Tensor>;
|
let mut attention_weights: Option<Tensor>;
|
||||||
let mut layers = self.layers.iter().enumerate();
|
let mut layers = self.layers.iter().enumerate();
|
||||||
@ -252,15 +340,17 @@ impl BartDecoder {
|
|||||||
Some((layer_idx, layer)) => {
|
Some((layer_idx, layer)) => {
|
||||||
let layer_state = match &next_decoder_cache {
|
let layer_state = match &next_decoder_cache {
|
||||||
Some(values) => values[layer_idx].to_owned(),
|
Some(values) => values[layer_idx].to_owned(),
|
||||||
None => (None, None)
|
None => (None, None),
|
||||||
};
|
};
|
||||||
let temp = layer.forward_t(&hidden_state,
|
let temp = layer.forward_t(
|
||||||
&encoder_hidden_states,
|
&hidden_state,
|
||||||
encoder_padding_mask.as_ref(),
|
&encoder_hidden_states,
|
||||||
decoder_causal_mask,
|
encoder_padding_mask.as_ref(),
|
||||||
decoder_padding_mask,
|
decoder_causal_mask,
|
||||||
layer_state,
|
decoder_padding_mask,
|
||||||
train);
|
layer_state,
|
||||||
|
train,
|
||||||
|
);
|
||||||
hidden_state = temp.0;
|
hidden_state = temp.0;
|
||||||
attention_weights = temp.1;
|
attention_weights = temp.1;
|
||||||
if let Some(hidden_states) = all_hidden_states.borrow_mut() {
|
if let Some(hidden_states) = all_hidden_states.borrow_mut() {
|
||||||
@ -269,15 +359,19 @@ impl BartDecoder {
|
|||||||
if let Some(attentions) = all_attentions.borrow_mut() {
|
if let Some(attentions) = all_attentions.borrow_mut() {
|
||||||
attentions.push(attention_weights.as_ref().unwrap().copy());
|
attentions.push(attention_weights.as_ref().unwrap().copy());
|
||||||
};
|
};
|
||||||
if let Some(value) = &mut next_decoder_cache { value[layer_idx] = temp.2 };
|
if let Some(value) = &mut next_decoder_cache {
|
||||||
|
value[layer_idx] = temp.2
|
||||||
|
};
|
||||||
}
|
}
|
||||||
None => break
|
None => break,
|
||||||
};
|
};
|
||||||
};
|
}
|
||||||
|
|
||||||
(hidden_state.transpose(0, 1),
|
(
|
||||||
(encoder_padding_mask, next_decoder_cache),
|
hidden_state.transpose(0, 1),
|
||||||
all_hidden_states,
|
(encoder_padding_mask, next_decoder_cache),
|
||||||
all_attentions)
|
all_hidden_states,
|
||||||
|
all_attentions,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -11,10 +11,9 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
use tch::{nn, Tensor};
|
|
||||||
use tch::nn::{EmbeddingConfig, embedding};
|
|
||||||
use tch::kind::Kind::Int64;
|
use tch::kind::Kind::Int64;
|
||||||
|
use tch::nn::{embedding, EmbeddingConfig};
|
||||||
|
use tch::{nn, Tensor};
|
||||||
|
|
||||||
/// # Abstraction that holds a embeddings configuration
|
/// # Abstraction that holds a embeddings configuration
|
||||||
pub enum EmbeddingOption {
|
pub enum EmbeddingOption {
|
||||||
@ -27,8 +26,12 @@ impl EmbeddingOption {
|
|||||||
/// Interface method to forward_t() of the particular models.
|
/// Interface method to forward_t() of the particular models.
|
||||||
pub fn forward(&self, input: &Tensor, generation_mode: bool) -> Tensor {
|
pub fn forward(&self, input: &Tensor, generation_mode: bool) -> Tensor {
|
||||||
match *self {
|
match *self {
|
||||||
Self::LearnedPositionalEmbedding(ref embeddings) => embeddings.forward(input, generation_mode),
|
Self::LearnedPositionalEmbedding(ref embeddings) => {
|
||||||
Self::SinusoidalPositionalEmbedding(ref embeddings) => embeddings.forward(input, generation_mode)
|
embeddings.forward(input, generation_mode)
|
||||||
|
}
|
||||||
|
Self::SinusoidalPositionalEmbedding(ref embeddings) => {
|
||||||
|
embeddings.forward(input, generation_mode)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -40,15 +43,24 @@ pub struct LearnedPositionalEmbedding {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl LearnedPositionalEmbedding {
|
impl LearnedPositionalEmbedding {
|
||||||
pub fn new(p: nn::Path, num_embeddings: i64, embedding_dim: i64, padding_index: i64) -> LearnedPositionalEmbedding {
|
pub fn new(
|
||||||
let embedding_config = EmbeddingConfig { padding_idx: padding_index, ..Default::default() };
|
p: nn::Path,
|
||||||
|
num_embeddings: i64,
|
||||||
|
embedding_dim: i64,
|
||||||
|
padding_index: i64,
|
||||||
|
) -> LearnedPositionalEmbedding {
|
||||||
|
let embedding_config = EmbeddingConfig {
|
||||||
|
padding_idx: padding_index,
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
let num_embeddings = num_embeddings + padding_index + 1;
|
let num_embeddings = num_embeddings + padding_index + 1;
|
||||||
|
|
||||||
let embedding: nn::Embedding = embedding(p,
|
let embedding: nn::Embedding =
|
||||||
num_embeddings,
|
embedding(p, num_embeddings, embedding_dim, embedding_config);
|
||||||
embedding_dim,
|
LearnedPositionalEmbedding {
|
||||||
embedding_config);
|
embedding,
|
||||||
LearnedPositionalEmbedding { embedding, padding_index }
|
padding_index,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn forward(&self, input: &Tensor, generation_mode: bool) -> Tensor {
|
pub fn forward(&self, input: &Tensor, generation_mode: bool) -> Tensor {
|
||||||
@ -74,11 +86,13 @@ pub struct SinusoidalPositionalEmbedding {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl SinusoidalPositionalEmbedding {
|
impl SinusoidalPositionalEmbedding {
|
||||||
pub fn new(p: nn::Path, num_embeddings: i64, embedding_dim: i64) -> SinusoidalPositionalEmbedding {
|
pub fn new(
|
||||||
let embedding: nn::Embedding = embedding(p,
|
p: nn::Path,
|
||||||
num_embeddings,
|
num_embeddings: i64,
|
||||||
embedding_dim,
|
embedding_dim: i64,
|
||||||
Default::default());
|
) -> SinusoidalPositionalEmbedding {
|
||||||
|
let embedding: nn::Embedding =
|
||||||
|
embedding(p, num_embeddings, embedding_dim, Default::default());
|
||||||
SinusoidalPositionalEmbedding { embedding }
|
SinusoidalPositionalEmbedding { embedding }
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -86,8 +100,8 @@ impl SinusoidalPositionalEmbedding {
|
|||||||
let positions = if generation_mode {
|
let positions = if generation_mode {
|
||||||
Tensor::full(&[1, 1], input.size()[1] - 1, (Int64, input.device()))
|
Tensor::full(&[1, 1], input.size()[1] - 1, (Int64, input.device()))
|
||||||
} else {
|
} else {
|
||||||
Tensor::arange(input.size()[1],(Int64, input.device()))
|
Tensor::arange(input.size()[1], (Int64, input.device()))
|
||||||
};
|
};
|
||||||
positions.apply(&self.embedding)
|
positions.apply(&self.embedding)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -12,14 +12,16 @@
|
|||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
use crate::bart::attention::SelfAttention;
|
use crate::bart::attention::SelfAttention;
|
||||||
use tch::{nn, Tensor};
|
|
||||||
use crate::common::dropout::Dropout;
|
|
||||||
use crate::bart::BartConfig;
|
|
||||||
use crate::bart::bart::Activation;
|
use crate::bart::bart::Activation;
|
||||||
use crate::common::activations::{_gelu, _relu, _swish, _gelu_new, _tanh};
|
use crate::bart::embeddings::{
|
||||||
use crate::bart::embeddings::{EmbeddingOption, LearnedPositionalEmbedding, SinusoidalPositionalEmbedding};
|
EmbeddingOption, LearnedPositionalEmbedding, SinusoidalPositionalEmbedding,
|
||||||
use tch::kind::Kind::Bool;
|
};
|
||||||
|
use crate::bart::BartConfig;
|
||||||
|
use crate::common::activations::{_gelu, _gelu_new, _relu, _swish, _tanh};
|
||||||
|
use crate::common::dropout::Dropout;
|
||||||
use std::borrow::BorrowMut;
|
use std::borrow::BorrowMut;
|
||||||
|
use tch::kind::Kind::Bool;
|
||||||
|
use tch::{nn, Tensor};
|
||||||
|
|
||||||
pub struct EncoderLayer {
|
pub struct EncoderLayer {
|
||||||
self_attention: SelfAttention,
|
self_attention: SelfAttention,
|
||||||
@ -34,46 +36,81 @@ pub struct EncoderLayer {
|
|||||||
|
|
||||||
impl EncoderLayer {
|
impl EncoderLayer {
|
||||||
pub fn new(p: nn::Path, config: &BartConfig) -> EncoderLayer {
|
pub fn new(p: nn::Path, config: &BartConfig) -> EncoderLayer {
|
||||||
let layer_norm_config = nn::LayerNormConfig { eps: 1e-5, ..Default::default() };
|
let layer_norm_config = nn::LayerNormConfig {
|
||||||
|
eps: 1e-5,
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
let output_attention = match config.output_attentions {
|
let output_attention = match config.output_attentions {
|
||||||
Some(value) => value,
|
Some(value) => value,
|
||||||
None => false
|
None => false,
|
||||||
};
|
};
|
||||||
let self_attention = SelfAttention::new(&p / "self_attn",
|
let self_attention = SelfAttention::new(
|
||||||
config.d_model,
|
&p / "self_attn",
|
||||||
config.encoder_attention_heads,
|
config.d_model,
|
||||||
config.attention_dropout,
|
config.encoder_attention_heads,
|
||||||
false,
|
config.attention_dropout,
|
||||||
false,
|
false,
|
||||||
output_attention);
|
false,
|
||||||
let self_attention_layer_norm = nn::layer_norm(&p / "self_attn_layer_norm",
|
output_attention,
|
||||||
vec![config.d_model],
|
);
|
||||||
layer_norm_config);
|
let self_attention_layer_norm = nn::layer_norm(
|
||||||
|
&p / "self_attn_layer_norm",
|
||||||
|
vec![config.d_model],
|
||||||
|
layer_norm_config,
|
||||||
|
);
|
||||||
let dropout = Dropout::new(config.dropout);
|
let dropout = Dropout::new(config.dropout);
|
||||||
let activation_dropout = Dropout::new(config.activation_dropout);
|
let activation_dropout = Dropout::new(config.activation_dropout);
|
||||||
let activation_function = match &config.activation_function {
|
let activation_function = match &config.activation_function {
|
||||||
Some(act_function) => act_function,
|
Some(act_function) => act_function,
|
||||||
None => &Activation::gelu
|
None => &Activation::gelu,
|
||||||
};
|
};
|
||||||
let activation = Box::new(match activation_function {
|
let activation = Box::new(match activation_function {
|
||||||
Activation::gelu => _gelu,
|
Activation::gelu => _gelu,
|
||||||
Activation::relu => _relu,
|
Activation::relu => _relu,
|
||||||
Activation::swish => _swish,
|
Activation::swish => _swish,
|
||||||
Activation::gelu_new => _gelu_new,
|
Activation::gelu_new => _gelu_new,
|
||||||
Activation::tanh => _tanh
|
Activation::tanh => _tanh,
|
||||||
});
|
});
|
||||||
let fc1 = nn::linear(&p / "fc1", config.d_model, config.encoder_ffn_dim, Default::default());
|
let fc1 = nn::linear(
|
||||||
let fc2 = nn::linear(&p / "fc2", config.encoder_ffn_dim, config.d_model, Default::default());
|
&p / "fc1",
|
||||||
|
config.d_model,
|
||||||
|
config.encoder_ffn_dim,
|
||||||
|
Default::default(),
|
||||||
|
);
|
||||||
|
let fc2 = nn::linear(
|
||||||
|
&p / "fc2",
|
||||||
|
config.encoder_ffn_dim,
|
||||||
|
config.d_model,
|
||||||
|
Default::default(),
|
||||||
|
);
|
||||||
|
|
||||||
let final_layer_norm = nn::layer_norm(&p / "final_layer_norm",
|
let final_layer_norm = nn::layer_norm(
|
||||||
vec![config.d_model],
|
&p / "final_layer_norm",
|
||||||
layer_norm_config);
|
vec![config.d_model],
|
||||||
|
layer_norm_config,
|
||||||
|
);
|
||||||
|
|
||||||
EncoderLayer { self_attention, self_attention_layer_norm, dropout, activation_dropout, activation, fc1, fc2, final_layer_norm }
|
EncoderLayer {
|
||||||
|
self_attention,
|
||||||
|
self_attention_layer_norm,
|
||||||
|
dropout,
|
||||||
|
activation_dropout,
|
||||||
|
activation,
|
||||||
|
fc1,
|
||||||
|
fc2,
|
||||||
|
final_layer_norm,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn forward_t(&self, x: &Tensor, encoder_padding_mask: Option<&Tensor>, train: bool) -> (Tensor, Option<Tensor>) {
|
pub fn forward_t(
|
||||||
let (output, attention_weights, _) = self.self_attention.forward_t(x, None, encoder_padding_mask, None, None, train);
|
&self,
|
||||||
|
x: &Tensor,
|
||||||
|
encoder_padding_mask: Option<&Tensor>,
|
||||||
|
train: bool,
|
||||||
|
) -> (Tensor, Option<Tensor>) {
|
||||||
|
let (output, attention_weights, _) =
|
||||||
|
self.self_attention
|
||||||
|
.forward_t(x, None, encoder_padding_mask, None, None, train);
|
||||||
let output: Tensor = output.apply_t(&self.dropout, train) + x;
|
let output: Tensor = output.apply_t(&self.dropout, train) + x;
|
||||||
let output = output.apply(&self.self_attention_layer_norm);
|
let output = output.apply(&self.self_attention_layer_norm);
|
||||||
|
|
||||||
@ -102,57 +139,72 @@ impl BartEncoder {
|
|||||||
pub fn new(p: nn::Path, config: &BartConfig) -> BartEncoder {
|
pub fn new(p: nn::Path, config: &BartConfig) -> BartEncoder {
|
||||||
let output_attentions = match config.output_attentions {
|
let output_attentions = match config.output_attentions {
|
||||||
Some(value) => value,
|
Some(value) => value,
|
||||||
None => false
|
None => false,
|
||||||
};
|
};
|
||||||
let output_hidden_states = match config.output_hidden_states {
|
let output_hidden_states = match config.output_hidden_states {
|
||||||
Some(value) => value,
|
Some(value) => value,
|
||||||
None => false
|
None => false,
|
||||||
};
|
};
|
||||||
let normalize_embedding = match config.normalize_embedding {
|
let normalize_embedding = match config.normalize_embedding {
|
||||||
Some(value) => value,
|
Some(value) => value,
|
||||||
None => true
|
None => true,
|
||||||
};
|
};
|
||||||
let static_position_embeddings = match config.static_position_embeddings {
|
let static_position_embeddings = match config.static_position_embeddings {
|
||||||
Some(value) => value,
|
Some(value) => value,
|
||||||
None => false
|
None => false,
|
||||||
};
|
};
|
||||||
let scale_embedding = match config.scale_embedding {
|
let scale_embedding = match config.scale_embedding {
|
||||||
Some(value) => if value { (config.d_model as f64).sqrt() } else { 1.0 },
|
Some(value) => {
|
||||||
None => 1.0
|
if value {
|
||||||
|
(config.d_model as f64).sqrt()
|
||||||
|
} else {
|
||||||
|
1.0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
None => 1.0,
|
||||||
};
|
};
|
||||||
|
|
||||||
let dropout = Dropout::new(config.dropout);
|
let dropout = Dropout::new(config.dropout);
|
||||||
|
|
||||||
let layer_norm_embedding = if normalize_embedding {
|
let layer_norm_embedding = if normalize_embedding {
|
||||||
let layer_norm_config = nn::LayerNormConfig { eps: 1e-5, ..Default::default() };
|
let layer_norm_config = nn::LayerNormConfig {
|
||||||
Some(nn::layer_norm(&p / "layernorm_embedding",
|
eps: 1e-5,
|
||||||
vec![config.d_model],
|
..Default::default()
|
||||||
layer_norm_config))
|
};
|
||||||
|
Some(nn::layer_norm(
|
||||||
|
&p / "layernorm_embedding",
|
||||||
|
vec![config.d_model],
|
||||||
|
layer_norm_config,
|
||||||
|
))
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
};
|
};
|
||||||
|
|
||||||
let pad_token_id = match config.pad_token_id {
|
let pad_token_id = match config.pad_token_id {
|
||||||
Some(value) => value,
|
Some(value) => value,
|
||||||
None => 1
|
None => 1,
|
||||||
};
|
};
|
||||||
|
|
||||||
let embed_positions = if static_position_embeddings {
|
let embed_positions = if static_position_embeddings {
|
||||||
EmbeddingOption::SinusoidalPositionalEmbedding(SinusoidalPositionalEmbedding::new(&p / "embed_positions",
|
EmbeddingOption::SinusoidalPositionalEmbedding(SinusoidalPositionalEmbedding::new(
|
||||||
config.max_position_embeddings,
|
&p / "embed_positions",
|
||||||
config.d_model))
|
config.max_position_embeddings,
|
||||||
|
config.d_model,
|
||||||
|
))
|
||||||
} else {
|
} else {
|
||||||
EmbeddingOption::LearnedPositionalEmbedding(LearnedPositionalEmbedding::new(&p / "embed_positions",
|
EmbeddingOption::LearnedPositionalEmbedding(LearnedPositionalEmbedding::new(
|
||||||
config.max_position_embeddings,
|
&p / "embed_positions",
|
||||||
config.d_model,
|
config.max_position_embeddings,
|
||||||
pad_token_id))
|
config.d_model,
|
||||||
|
pad_token_id,
|
||||||
|
))
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut layers: Vec<EncoderLayer> = vec!();
|
let mut layers: Vec<EncoderLayer> = vec![];
|
||||||
let p_layers = &p / "layers";
|
let p_layers = &p / "layers";
|
||||||
for layer_index in 0..config.encoder_layers {
|
for layer_index in 0..config.encoder_layers {
|
||||||
layers.push(EncoderLayer::new(&p_layers / layer_index, config));
|
layers.push(EncoderLayer::new(&p_layers / layer_index, config));
|
||||||
};
|
}
|
||||||
|
|
||||||
BartEncoder {
|
BartEncoder {
|
||||||
dropout,
|
dropout,
|
||||||
@ -165,26 +217,37 @@ impl BartEncoder {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn forward_t(&self,
|
pub fn forward_t(
|
||||||
input_ids: &Tensor,
|
&self,
|
||||||
attention_mask: Option<&Tensor>,
|
input_ids: &Tensor,
|
||||||
embeddings: &nn::Embedding,
|
attention_mask: Option<&Tensor>,
|
||||||
train: bool)
|
embeddings: &nn::Embedding,
|
||||||
-> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
|
train: bool,
|
||||||
|
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
|
||||||
let attention_mask = match attention_mask {
|
let attention_mask = match attention_mask {
|
||||||
Some(mask) => Some(mask.eq(0).to_kind(Bool)),
|
Some(mask) => Some(mask.eq(0).to_kind(Bool)),
|
||||||
None => None
|
None => None,
|
||||||
};
|
};
|
||||||
|
|
||||||
let x = input_ids.apply(embeddings) * self.scale_embedding;
|
let x = input_ids.apply(embeddings) * self.scale_embedding;
|
||||||
let x: Tensor = x + &self.embed_positions.forward(input_ids, false);
|
let x: Tensor = x + &self.embed_positions.forward(input_ids, false);
|
||||||
let x = if let Some(layer_norm_embedding) = &self.layer_norm_embedding { x.apply(layer_norm_embedding) } else { x };
|
let x = if let Some(layer_norm_embedding) = &self.layer_norm_embedding {
|
||||||
let x = x
|
x.apply(layer_norm_embedding)
|
||||||
.apply_t(&self.dropout, train)
|
} else {
|
||||||
.transpose(0, 1);
|
x
|
||||||
|
};
|
||||||
|
let x = x.apply_t(&self.dropout, train).transpose(0, 1);
|
||||||
|
|
||||||
let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states { Some(vec!()) } else { None };
|
let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states {
|
||||||
let mut all_attentions: Option<Vec<Tensor>> = if self.output_attentions { Some(vec!()) } else { None };
|
Some(vec![])
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
let mut all_attentions: Option<Vec<Tensor>> = if self.output_attentions {
|
||||||
|
Some(vec![])
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
let mut hidden_state = x.copy();
|
let mut hidden_state = x.copy();
|
||||||
let mut attention_weights: Option<Tensor>;
|
let mut attention_weights: Option<Tensor>;
|
||||||
@ -204,13 +267,17 @@ impl BartEncoder {
|
|||||||
attentions.push(attention_weights.as_ref().unwrap().copy());
|
attentions.push(attention_weights.as_ref().unwrap().copy());
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
None => break
|
None => break,
|
||||||
};
|
};
|
||||||
};
|
}
|
||||||
if let Some(hidden_states) = all_hidden_states.borrow_mut() {
|
if let Some(hidden_states) = all_hidden_states.borrow_mut() {
|
||||||
hidden_states.push(hidden_state.as_ref().copy().transpose(0, 1));
|
hidden_states.push(hidden_state.as_ref().copy().transpose(0, 1));
|
||||||
};
|
};
|
||||||
|
|
||||||
(hidden_state.transpose(0, 1), all_hidden_states, all_attentions)
|
(
|
||||||
|
hidden_state.transpose(0, 1),
|
||||||
|
all_hidden_states,
|
||||||
|
all_attentions,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -15,19 +15,27 @@
|
|||||||
//! Pretrained models are available and can be downloaded using RemoteResources.
|
//! Pretrained models are available and can be downloaded using RemoteResources.
|
||||||
//!
|
//!
|
||||||
//! ```no_run
|
//! ```no_run
|
||||||
//!# fn main() -> failure::Fallible<()> {
|
//! # fn main() -> failure::Fallible<()> {
|
||||||
//!#
|
//! #
|
||||||
//! use rust_tokenizers::RobertaTokenizer;
|
//! use rust_tokenizers::RobertaTokenizer;
|
||||||
//! use tch::{nn, Device};
|
//! use tch::{nn, Device};
|
||||||
//!# use std::path::PathBuf;
|
//! # use std::path::PathBuf;
|
||||||
//! use rust_bert::Config;
|
|
||||||
//! use rust_bert::bart::{BartConfig, BartModel};
|
//! use rust_bert::bart::{BartConfig, BartModel};
|
||||||
//! use rust_bert::resources::{Resource, download_resource, LocalResource};
|
//! use rust_bert::resources::{download_resource, LocalResource, Resource};
|
||||||
|
//! use rust_bert::Config;
|
||||||
//!
|
//!
|
||||||
//! let config_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/config.json")});
|
//! let config_resource = Resource::Local(LocalResource {
|
||||||
//! let vocab_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/vocab.txt")});
|
//! local_path: PathBuf::from("path/to/config.json"),
|
||||||
//! let merges_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/vocab.txt")});
|
//! });
|
||||||
//! let weights_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/model.ot")});
|
//! let vocab_resource = Resource::Local(LocalResource {
|
||||||
|
//! local_path: PathBuf::from("path/to/vocab.txt"),
|
||||||
|
//! });
|
||||||
|
//! let merges_resource = Resource::Local(LocalResource {
|
||||||
|
//! local_path: PathBuf::from("path/to/vocab.txt"),
|
||||||
|
//! });
|
||||||
|
//! let weights_resource = Resource::Local(LocalResource {
|
||||||
|
//! local_path: PathBuf::from("path/to/model.ot"),
|
||||||
|
//! });
|
||||||
//! let config_path = download_resource(&config_resource)?;
|
//! let config_path = download_resource(&config_resource)?;
|
||||||
//! let vocab_path = download_resource(&vocab_resource)?;
|
//! let vocab_path = download_resource(&vocab_resource)?;
|
||||||
//! let merges_path = download_resource(&merges_resource)?;
|
//! let merges_path = download_resource(&merges_resource)?;
|
||||||
@ -35,21 +43,28 @@
|
|||||||
//!
|
//!
|
||||||
//! let device = Device::cuda_if_available();
|
//! let device = Device::cuda_if_available();
|
||||||
//! let mut vs = nn::VarStore::new(device);
|
//! let mut vs = nn::VarStore::new(device);
|
||||||
//! let tokenizer: RobertaTokenizer = RobertaTokenizer::from_file(vocab_path.to_str().unwrap(), merges_path.to_str().unwrap(), true);
|
//! let tokenizer: RobertaTokenizer = RobertaTokenizer::from_file(
|
||||||
|
//! vocab_path.to_str().unwrap(),
|
||||||
|
//! merges_path.to_str().unwrap(),
|
||||||
|
//! true,
|
||||||
|
//! );
|
||||||
//! let config = BartConfig::from_file(config_path);
|
//! let config = BartConfig::from_file(config_path);
|
||||||
//! let bart_model = BartModel::new(&vs.root(), &config, false);
|
//! let bart_model = BartModel::new(&vs.root(), &config, false);
|
||||||
//! vs.load(weights_path)?;
|
//! vs.load(weights_path)?;
|
||||||
//!
|
//!
|
||||||
//!# Ok(())
|
//! # Ok(())
|
||||||
//!# }
|
//! # }
|
||||||
//! ```
|
//! ```
|
||||||
|
|
||||||
mod bart;
|
|
||||||
mod attention;
|
mod attention;
|
||||||
mod encoder;
|
mod bart;
|
||||||
mod decoder;
|
mod decoder;
|
||||||
mod embeddings;
|
mod embeddings;
|
||||||
|
mod encoder;
|
||||||
|
|
||||||
pub use bart::{BartModelResources, BartConfigResources, BartVocabResources, BartMergesResources,
|
pub use attention::LayerState;
|
||||||
BartConfig, Activation, BartModel, BartForSequenceClassification, BartForConditionalGeneration};
|
pub use bart::{
|
||||||
pub use attention::LayerState;
|
Activation, BartConfig, BartConfigResources, BartForConditionalGeneration,
|
||||||
|
BartForSequenceClassification, BartMergesResources, BartModel, BartModelResources,
|
||||||
|
BartVocabResources,
|
||||||
|
};
|
||||||
|
@ -11,11 +11,11 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
use crate::common::dropout::Dropout;
|
|
||||||
use tch::{nn, Tensor};
|
|
||||||
use tch::kind::Kind::Float;
|
|
||||||
use crate::bert::bert::{Activation, BertConfig};
|
use crate::bert::bert::{Activation, BertConfig};
|
||||||
use crate::common::activations::{_gelu, _relu, _mish};
|
use crate::common::activations::{_gelu, _mish, _relu};
|
||||||
|
use crate::common::dropout::Dropout;
|
||||||
|
use tch::kind::Kind::Float;
|
||||||
|
use tch::{nn, Tensor};
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct BertSelfAttention {
|
pub struct BertSelfAttention {
|
||||||
@ -30,17 +30,36 @@ pub struct BertSelfAttention {
|
|||||||
|
|
||||||
impl BertSelfAttention {
|
impl BertSelfAttention {
|
||||||
pub fn new(p: nn::Path, config: &BertConfig) -> BertSelfAttention {
|
pub fn new(p: nn::Path, config: &BertConfig) -> BertSelfAttention {
|
||||||
assert_eq!(config.hidden_size % config.num_attention_heads, 0, "Hidden size not a multiple of the number of attention heads");
|
assert_eq!(
|
||||||
|
config.hidden_size % config.num_attention_heads,
|
||||||
|
0,
|
||||||
|
"Hidden size not a multiple of the number of attention heads"
|
||||||
|
);
|
||||||
|
|
||||||
let query = nn::linear(&p / "query", config.hidden_size, config.hidden_size, Default::default());
|
let query = nn::linear(
|
||||||
let key = nn::linear(&p / "key", config.hidden_size, config.hidden_size, Default::default());
|
&p / "query",
|
||||||
let value = nn::linear(&p / "value", config.hidden_size, config.hidden_size, Default::default());
|
config.hidden_size,
|
||||||
|
config.hidden_size,
|
||||||
|
Default::default(),
|
||||||
|
);
|
||||||
|
let key = nn::linear(
|
||||||
|
&p / "key",
|
||||||
|
config.hidden_size,
|
||||||
|
config.hidden_size,
|
||||||
|
Default::default(),
|
||||||
|
);
|
||||||
|
let value = nn::linear(
|
||||||
|
&p / "value",
|
||||||
|
config.hidden_size,
|
||||||
|
config.hidden_size,
|
||||||
|
Default::default(),
|
||||||
|
);
|
||||||
|
|
||||||
let dropout = Dropout::new(config.attention_probs_dropout_prob);
|
let dropout = Dropout::new(config.attention_probs_dropout_prob);
|
||||||
let attention_head_size = config.hidden_size / config.num_attention_heads;
|
let attention_head_size = config.hidden_size / config.num_attention_heads;
|
||||||
let output_attentions = match config.output_attentions {
|
let output_attentions = match config.output_attentions {
|
||||||
Some(value) => value,
|
Some(value) => value,
|
||||||
None => false
|
None => false,
|
||||||
};
|
};
|
||||||
|
|
||||||
BertSelfAttention {
|
BertSelfAttention {
|
||||||
@ -55,35 +74,44 @@ impl BertSelfAttention {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn split_heads(&self, x: Tensor, bs: i64, dim_per_head: i64) -> Tensor {
|
fn split_heads(&self, x: Tensor, bs: i64, dim_per_head: i64) -> Tensor {
|
||||||
x.view((bs, -1, self.num_attention_heads, dim_per_head)).transpose(1, 2)
|
x.view((bs, -1, self.num_attention_heads, dim_per_head))
|
||||||
|
.transpose(1, 2)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn flatten(&self, x: Tensor, bs: i64, dim_per_head: i64) -> Tensor {
|
fn flatten(&self, x: Tensor, bs: i64, dim_per_head: i64) -> Tensor {
|
||||||
x.transpose(1, 2).contiguous().view((bs, -1, &self.num_attention_heads * dim_per_head))
|
x.transpose(1, 2)
|
||||||
|
.contiguous()
|
||||||
|
.view((bs, -1, &self.num_attention_heads * dim_per_head))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn forward_t(&self,
|
pub fn forward_t(
|
||||||
hidden_states: &Tensor,
|
&self,
|
||||||
mask: &Option<Tensor>,
|
hidden_states: &Tensor,
|
||||||
encoder_hidden_states: &Option<Tensor>,
|
mask: &Option<Tensor>,
|
||||||
encoder_mask: &Option<Tensor>,
|
encoder_hidden_states: &Option<Tensor>,
|
||||||
train: bool) -> (Tensor, Option<Tensor>) {
|
encoder_mask: &Option<Tensor>,
|
||||||
|
train: bool,
|
||||||
|
) -> (Tensor, Option<Tensor>) {
|
||||||
let (key_layer, value_layer, mask) = match encoder_hidden_states {
|
let (key_layer, value_layer, mask) = match encoder_hidden_states {
|
||||||
Some(encoder_hidden_state_values) => {
|
Some(encoder_hidden_state_values) => (
|
||||||
(encoder_hidden_state_values.apply(&self.key),
|
encoder_hidden_state_values.apply(&self.key),
|
||||||
encoder_hidden_state_values.apply(&self.value),
|
encoder_hidden_state_values.apply(&self.value),
|
||||||
encoder_mask)
|
encoder_mask,
|
||||||
}
|
),
|
||||||
None => {
|
None => (
|
||||||
(hidden_states.apply(&self.key),
|
hidden_states.apply(&self.key),
|
||||||
hidden_states.apply(&self.value),
|
hidden_states.apply(&self.value),
|
||||||
mask)
|
mask,
|
||||||
}
|
),
|
||||||
};
|
};
|
||||||
|
|
||||||
let bs = hidden_states.size()[0];
|
let bs = hidden_states.size()[0];
|
||||||
|
|
||||||
let query_layer = self.split_heads(hidden_states.apply(&self.query), bs, self.attention_head_size);
|
let query_layer = self.split_heads(
|
||||||
|
hidden_states.apply(&self.query),
|
||||||
|
bs,
|
||||||
|
self.attention_head_size,
|
||||||
|
);
|
||||||
let key_layer = self.split_heads(key_layer, bs, self.attention_head_size);
|
let key_layer = self.split_heads(key_layer, bs, self.attention_head_size);
|
||||||
let value_layer = self.split_heads(value_layer, bs, self.attention_head_size);
|
let value_layer = self.split_heads(value_layer, bs, self.attention_head_size);
|
||||||
let query_layer: Tensor = query_layer / (self.attention_head_size as f64).sqrt();
|
let query_layer: Tensor = query_layer / (self.attention_head_size as f64).sqrt();
|
||||||
@ -114,16 +142,32 @@ pub struct BertSelfOutput {
|
|||||||
|
|
||||||
impl BertSelfOutput {
|
impl BertSelfOutput {
|
||||||
pub fn new(p: &nn::Path, config: &BertConfig) -> BertSelfOutput {
|
pub fn new(p: &nn::Path, config: &BertConfig) -> BertSelfOutput {
|
||||||
let linear = nn::linear(p / "dense", config.hidden_size, config.hidden_size, Default::default());
|
let linear = nn::linear(
|
||||||
let layer_norm_config = nn::LayerNormConfig { eps: 1e-12, ..Default::default() };
|
p / "dense",
|
||||||
let layer_norm = nn::layer_norm(p / "LayerNorm", vec![config.hidden_size], layer_norm_config);
|
config.hidden_size,
|
||||||
|
config.hidden_size,
|
||||||
|
Default::default(),
|
||||||
|
);
|
||||||
|
let layer_norm_config = nn::LayerNormConfig {
|
||||||
|
eps: 1e-12,
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
let layer_norm =
|
||||||
|
nn::layer_norm(p / "LayerNorm", vec![config.hidden_size], layer_norm_config);
|
||||||
let dropout = Dropout::new(config.hidden_dropout_prob);
|
let dropout = Dropout::new(config.hidden_dropout_prob);
|
||||||
|
|
||||||
BertSelfOutput { linear, layer_norm, dropout }
|
BertSelfOutput {
|
||||||
|
linear,
|
||||||
|
layer_norm,
|
||||||
|
dropout,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn forward_t(&self, hidden_states: &Tensor, input_tensor: &Tensor, train: bool) -> Tensor {
|
pub fn forward_t(&self, hidden_states: &Tensor, input_tensor: &Tensor, train: bool) -> Tensor {
|
||||||
let hidden_states: Tensor = input_tensor + hidden_states.apply(&self.linear).apply_t(&self.dropout, train);
|
let hidden_states: Tensor = input_tensor
|
||||||
|
+ hidden_states
|
||||||
|
.apply(&self.linear)
|
||||||
|
.apply_t(&self.dropout, train);
|
||||||
hidden_states.apply(&self.layer_norm)
|
hidden_states.apply(&self.layer_norm)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -141,14 +185,21 @@ impl BertAttention {
|
|||||||
BertAttention { _self, output }
|
BertAttention { _self, output }
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn forward_t(&self,
|
pub fn forward_t(
|
||||||
hidden_states: &Tensor,
|
&self,
|
||||||
mask: &Option<Tensor>,
|
hidden_states: &Tensor,
|
||||||
encoder_hidden_states: &Option<Tensor>,
|
mask: &Option<Tensor>,
|
||||||
encoder_mask: &Option<Tensor>,
|
encoder_hidden_states: &Option<Tensor>,
|
||||||
train: bool) -> (Tensor, Option<Tensor>) {
|
encoder_mask: &Option<Tensor>,
|
||||||
let (self_output, attention_weights) = self._self.
|
train: bool,
|
||||||
forward_t(hidden_states, mask, encoder_hidden_states, encoder_mask, train);
|
) -> (Tensor, Option<Tensor>) {
|
||||||
|
let (self_output, attention_weights) = self._self.forward_t(
|
||||||
|
hidden_states,
|
||||||
|
mask,
|
||||||
|
encoder_hidden_states,
|
||||||
|
encoder_mask,
|
||||||
|
train,
|
||||||
|
);
|
||||||
|
|
||||||
let self_output = self.output.forward_t(&self_output, hidden_states, train);
|
let self_output = self.output.forward_t(&self_output, hidden_states, train);
|
||||||
(self_output, attention_weights)
|
(self_output, attention_weights)
|
||||||
@ -162,11 +213,16 @@ pub struct BertIntermediate {
|
|||||||
|
|
||||||
impl BertIntermediate {
|
impl BertIntermediate {
|
||||||
pub fn new(p: &nn::Path, config: &BertConfig) -> BertIntermediate {
|
pub fn new(p: &nn::Path, config: &BertConfig) -> BertIntermediate {
|
||||||
let lin = nn::linear(p / "dense", config.hidden_size, config.intermediate_size, Default::default());
|
let lin = nn::linear(
|
||||||
|
p / "dense",
|
||||||
|
config.hidden_size,
|
||||||
|
config.intermediate_size,
|
||||||
|
Default::default(),
|
||||||
|
);
|
||||||
let activation = Box::new(match &config.hidden_act {
|
let activation = Box::new(match &config.hidden_act {
|
||||||
Activation::gelu => _gelu,
|
Activation::gelu => _gelu,
|
||||||
Activation::relu => _relu,
|
Activation::relu => _relu,
|
||||||
Activation::mish => _mish
|
Activation::mish => _mish,
|
||||||
});
|
});
|
||||||
BertIntermediate { lin, activation }
|
BertIntermediate { lin, activation }
|
||||||
}
|
}
|
||||||
@ -184,17 +240,30 @@ pub struct BertOutput {
|
|||||||
|
|
||||||
impl BertOutput {
|
impl BertOutput {
|
||||||
pub fn new(p: &nn::Path, config: &BertConfig) -> BertOutput {
|
pub fn new(p: &nn::Path, config: &BertConfig) -> BertOutput {
|
||||||
let lin = nn::linear(p / "dense", config.intermediate_size, config.hidden_size, Default::default());
|
let lin = nn::linear(
|
||||||
let layer_norm_config = nn::LayerNormConfig { eps: 1e-12, ..Default::default() };
|
p / "dense",
|
||||||
let layer_norm = nn::layer_norm(p / "LayerNorm", vec![config.hidden_size], layer_norm_config);
|
config.intermediate_size,
|
||||||
|
config.hidden_size,
|
||||||
|
Default::default(),
|
||||||
|
);
|
||||||
|
let layer_norm_config = nn::LayerNormConfig {
|
||||||
|
eps: 1e-12,
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
let layer_norm =
|
||||||
|
nn::layer_norm(p / "LayerNorm", vec![config.hidden_size], layer_norm_config);
|
||||||
let dropout = Dropout::new(config.hidden_dropout_prob);
|
let dropout = Dropout::new(config.hidden_dropout_prob);
|
||||||
|
|
||||||
BertOutput { lin, layer_norm, dropout }
|
BertOutput {
|
||||||
|
lin,
|
||||||
|
layer_norm,
|
||||||
|
dropout,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn forward_t(&self, hidden_states: &Tensor, input_tensor: &Tensor, train: bool) -> Tensor {
|
pub fn forward_t(&self, hidden_states: &Tensor, input_tensor: &Tensor, train: bool) -> Tensor {
|
||||||
let hidden_states: Tensor = input_tensor + hidden_states.apply(&self.lin).apply_t(&self.dropout, train);
|
let hidden_states: Tensor =
|
||||||
|
input_tensor + hidden_states.apply(&self.lin).apply_t(&self.dropout, train);
|
||||||
hidden_states.apply(&self.layer_norm)
|
hidden_states.apply(&self.layer_norm)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
799
src/bert/bert.rs
799
src/bert/bert.rs
File diff suppressed because it is too large
Load Diff
@ -11,22 +11,24 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
use tch::{nn, Tensor, Kind};
|
|
||||||
use tch::nn::{EmbeddingConfig, embedding};
|
|
||||||
use crate::common::dropout::Dropout;
|
|
||||||
use crate::bert::bert::BertConfig;
|
use crate::bert::bert::BertConfig;
|
||||||
|
use crate::common::dropout::Dropout;
|
||||||
|
use tch::nn::{embedding, EmbeddingConfig};
|
||||||
|
use tch::{nn, Kind, Tensor};
|
||||||
|
|
||||||
/// # BertEmbedding trait (for use in BertModel or RoBERTaModel)
|
/// # BertEmbedding trait (for use in BertModel or RoBERTaModel)
|
||||||
/// Defines an interface for the embedding layers in BERT-based models
|
/// Defines an interface for the embedding layers in BERT-based models
|
||||||
pub trait BertEmbedding {
|
pub trait BertEmbedding {
|
||||||
fn new(p: &nn::Path, config: &BertConfig) -> Self;
|
fn new(p: &nn::Path, config: &BertConfig) -> Self;
|
||||||
|
|
||||||
fn forward_t(&self,
|
fn forward_t(
|
||||||
input_ids: Option<Tensor>,
|
&self,
|
||||||
token_type_ids: Option<Tensor>,
|
input_ids: Option<Tensor>,
|
||||||
position_ids: Option<Tensor>,
|
token_type_ids: Option<Tensor>,
|
||||||
input_embeds: Option<Tensor>,
|
position_ids: Option<Tensor>,
|
||||||
train: bool) -> Result<Tensor, &'static str>;
|
input_embeds: Option<Tensor>,
|
||||||
|
train: bool,
|
||||||
|
) -> Result<Tensor, &'static str>;
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
@ -51,10 +53,10 @@ impl BertEmbedding for BertEmbeddings {
|
|||||||
/// # Example
|
/// # Example
|
||||||
///
|
///
|
||||||
/// ```no_run
|
/// ```no_run
|
||||||
/// use rust_bert::bert::{BertConfig, BertEmbeddings, BertEmbedding};
|
/// use rust_bert::bert::{BertConfig, BertEmbedding, BertEmbeddings};
|
||||||
/// use tch::{nn, Device};
|
|
||||||
/// use rust_bert::Config;
|
/// use rust_bert::Config;
|
||||||
/// use std::path::Path;
|
/// use std::path::Path;
|
||||||
|
/// use tch::{nn, Device};
|
||||||
///
|
///
|
||||||
/// let config_path = Path::new("path/to/config.json");
|
/// let config_path = Path::new("path/to/config.json");
|
||||||
/// let device = Device::Cpu;
|
/// let device = Device::Cpu;
|
||||||
@ -62,29 +64,47 @@ impl BertEmbedding for BertEmbeddings {
|
|||||||
/// let config = BertConfig::from_file(config_path);
|
/// let config = BertConfig::from_file(config_path);
|
||||||
/// let bert_embeddings = BertEmbeddings::new(&(&p.root() / "bert_embeddings"), &config);
|
/// let bert_embeddings = BertEmbeddings::new(&(&p.root() / "bert_embeddings"), &config);
|
||||||
/// ```
|
/// ```
|
||||||
///
|
|
||||||
fn new(p: &nn::Path, config: &BertConfig) -> BertEmbeddings {
|
fn new(p: &nn::Path, config: &BertConfig) -> BertEmbeddings {
|
||||||
let embedding_config = EmbeddingConfig { padding_idx: 0, ..Default::default() };
|
let embedding_config = EmbeddingConfig {
|
||||||
|
padding_idx: 0,
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
|
||||||
let word_embeddings: nn::Embedding = embedding(p / "word_embeddings",
|
let word_embeddings: nn::Embedding = embedding(
|
||||||
config.vocab_size,
|
p / "word_embeddings",
|
||||||
config.hidden_size,
|
config.vocab_size,
|
||||||
embedding_config);
|
config.hidden_size,
|
||||||
|
embedding_config,
|
||||||
|
);
|
||||||
|
|
||||||
let position_embeddings: nn::Embedding = embedding(p / "position_embeddings",
|
let position_embeddings: nn::Embedding = embedding(
|
||||||
config.max_position_embeddings,
|
p / "position_embeddings",
|
||||||
config.hidden_size,
|
config.max_position_embeddings,
|
||||||
Default::default());
|
config.hidden_size,
|
||||||
|
Default::default(),
|
||||||
|
);
|
||||||
|
|
||||||
let token_type_embeddings: nn::Embedding = embedding(p / "token_type_embeddings",
|
let token_type_embeddings: nn::Embedding = embedding(
|
||||||
config.type_vocab_size,
|
p / "token_type_embeddings",
|
||||||
config.hidden_size,
|
config.type_vocab_size,
|
||||||
Default::default());
|
config.hidden_size,
|
||||||
|
Default::default(),
|
||||||
|
);
|
||||||
|
|
||||||
let layer_norm_config = nn::LayerNormConfig { eps: 1e-12, ..Default::default() };
|
let layer_norm_config = nn::LayerNormConfig {
|
||||||
let layer_norm: nn::LayerNorm = nn::layer_norm(p / "LayerNorm", vec![config.hidden_size], layer_norm_config);
|
eps: 1e-12,
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
let layer_norm: nn::LayerNorm =
|
||||||
|
nn::layer_norm(p / "LayerNorm", vec![config.hidden_size], layer_norm_config);
|
||||||
let dropout: Dropout = Dropout::new(config.hidden_dropout_prob);
|
let dropout: Dropout = Dropout::new(config.hidden_dropout_prob);
|
||||||
BertEmbeddings { word_embeddings, position_embeddings, token_type_embeddings, layer_norm, dropout }
|
BertEmbeddings {
|
||||||
|
word_embeddings,
|
||||||
|
position_embeddings,
|
||||||
|
token_type_embeddings,
|
||||||
|
layer_norm,
|
||||||
|
dropout,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Forward pass through the embedding layer
|
/// Forward pass through the embedding layer
|
||||||
@ -104,50 +124,62 @@ impl BertEmbedding for BertEmbeddings {
|
|||||||
/// # Example
|
/// # Example
|
||||||
///
|
///
|
||||||
/// ```no_run
|
/// ```no_run
|
||||||
///# use rust_bert::bert::{BertConfig, BertEmbeddings, BertEmbedding};
|
/// # use rust_bert::bert::{BertConfig, BertEmbeddings, BertEmbedding};
|
||||||
///# use tch::{nn, Device, Tensor, no_grad};
|
/// # use tch::{nn, Device, Tensor, no_grad};
|
||||||
///# use rust_bert::Config;
|
/// # use rust_bert::Config;
|
||||||
///# use std::path::Path;
|
/// # use std::path::Path;
|
||||||
///# use tch::kind::Kind::Int64;
|
/// # use tch::kind::Kind::Int64;
|
||||||
///# let config_path = Path::new("path/to/config.json");
|
/// # let config_path = Path::new("path/to/config.json");
|
||||||
///# let vocab_path = Path::new("path/to/vocab.txt");
|
/// # let vocab_path = Path::new("path/to/vocab.txt");
|
||||||
///# let device = Device::Cpu;
|
/// # let device = Device::Cpu;
|
||||||
///# let vs = nn::VarStore::new(device);
|
/// # let vs = nn::VarStore::new(device);
|
||||||
///# let config = BertConfig::from_file(config_path);
|
/// # let config = BertConfig::from_file(config_path);
|
||||||
///# let bert_embeddings = BertEmbeddings::new(&vs.root(), &config);
|
/// # let bert_embeddings = BertEmbeddings::new(&vs.root(), &config);
|
||||||
/// let (batch_size, sequence_length) = (64, 128);
|
/// let (batch_size, sequence_length) = (64, 128);
|
||||||
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
|
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
|
||||||
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
||||||
/// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true);
|
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
|
||||||
|
/// .expand(&[batch_size, sequence_length], true);
|
||||||
///
|
///
|
||||||
/// let embedded_output = no_grad(|| {
|
/// let embedded_output = no_grad(|| {
|
||||||
/// bert_embeddings
|
/// bert_embeddings
|
||||||
/// .forward_t(Some(input_tensor),
|
/// .forward_t(
|
||||||
/// Some(token_type_ids),
|
/// Some(input_tensor),
|
||||||
/// Some(position_ids),
|
/// Some(token_type_ids),
|
||||||
/// None,
|
/// Some(position_ids),
|
||||||
/// false).unwrap()
|
/// None,
|
||||||
/// });
|
/// false,
|
||||||
|
/// )
|
||||||
|
/// .unwrap()
|
||||||
|
/// });
|
||||||
/// ```
|
/// ```
|
||||||
///
|
fn forward_t(
|
||||||
fn forward_t(&self,
|
&self,
|
||||||
input_ids: Option<Tensor>,
|
input_ids: Option<Tensor>,
|
||||||
token_type_ids: Option<Tensor>,
|
token_type_ids: Option<Tensor>,
|
||||||
position_ids: Option<Tensor>,
|
position_ids: Option<Tensor>,
|
||||||
input_embeds: Option<Tensor>,
|
input_embeds: Option<Tensor>,
|
||||||
train: bool) -> Result<Tensor, &'static str> {
|
train: bool,
|
||||||
|
) -> Result<Tensor, &'static str> {
|
||||||
let (input_embeddings, input_shape) = match input_ids {
|
let (input_embeddings, input_shape) = match input_ids {
|
||||||
Some(input_value) => match input_embeds {
|
Some(input_value) => match input_embeds {
|
||||||
Some(_) => { return Err("Only one of input ids or input embeddings may be set"); }
|
Some(_) => {
|
||||||
None => (input_value.apply_t(&self.word_embeddings, train), input_value.size())
|
return Err("Only one of input ids or input embeddings may be set");
|
||||||
}
|
}
|
||||||
|
None => (
|
||||||
|
input_value.apply_t(&self.word_embeddings, train),
|
||||||
|
input_value.size(),
|
||||||
|
),
|
||||||
|
},
|
||||||
None => match input_embeds {
|
None => match input_embeds {
|
||||||
Some(embeds) => {
|
Some(embeds) => {
|
||||||
let size = vec!(embeds.size()[0], embeds.size()[1]);
|
let size = vec![embeds.size()[0], embeds.size()[1]];
|
||||||
(embeds, size)
|
(embeds, size)
|
||||||
},
|
}
|
||||||
None => { return Err("Only one of input ids or input embeddings may be set"); }
|
None => {
|
||||||
}
|
return Err("Only one of input ids or input embeddings may be set");
|
||||||
|
}
|
||||||
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
let seq_length = input_embeddings.as_ref().size()[1].to_owned();
|
let seq_length = input_embeddings.as_ref().size()[1].to_owned();
|
||||||
@ -155,19 +187,22 @@ impl BertEmbedding for BertEmbeddings {
|
|||||||
let position_ids = match position_ids {
|
let position_ids = match position_ids {
|
||||||
Some(value) => value,
|
Some(value) => value,
|
||||||
None => Tensor::arange(seq_length, (Kind::Int64, input_embeddings.device()))
|
None => Tensor::arange(seq_length, (Kind::Int64, input_embeddings.device()))
|
||||||
.unsqueeze(0).
|
.unsqueeze(0)
|
||||||
expand(&input_shape, true)
|
.expand(&input_shape, true),
|
||||||
};
|
};
|
||||||
|
|
||||||
let token_type_ids = match token_type_ids {
|
let token_type_ids = match token_type_ids {
|
||||||
Some(value) => value,
|
Some(value) => value,
|
||||||
None => Tensor::zeros(&input_shape, (Kind::Int64, input_embeddings.device()))
|
None => Tensor::zeros(&input_shape, (Kind::Int64, input_embeddings.device())),
|
||||||
};
|
};
|
||||||
|
|
||||||
let position_embeddings = position_ids.apply(&self.position_embeddings);
|
let position_embeddings = position_ids.apply(&self.position_embeddings);
|
||||||
let token_type_embeddings = token_type_ids.apply(&self.token_type_embeddings);
|
let token_type_embeddings = token_type_ids.apply(&self.token_type_embeddings);
|
||||||
|
|
||||||
let input_embeddings: Tensor = input_embeddings + position_embeddings + token_type_embeddings;
|
let input_embeddings: Tensor =
|
||||||
Ok(input_embeddings.apply(&self.layer_norm).apply_t(&self.dropout, train))
|
input_embeddings + position_embeddings + token_type_embeddings;
|
||||||
|
Ok(input_embeddings
|
||||||
|
.apply(&self.layer_norm)
|
||||||
|
.apply_t(&self.dropout, train))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -11,10 +11,10 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
use tch::{Tensor, nn};
|
|
||||||
use crate::bert::attention::{BertAttention, BertIntermediate, BertOutput};
|
use crate::bert::attention::{BertAttention, BertIntermediate, BertOutput};
|
||||||
use std::borrow::BorrowMut;
|
|
||||||
use crate::bert::bert::BertConfig;
|
use crate::bert::bert::BertConfig;
|
||||||
|
use std::borrow::BorrowMut;
|
||||||
|
use tch::{nn, Tensor};
|
||||||
|
|
||||||
pub struct BertLayer {
|
pub struct BertLayer {
|
||||||
attention: BertAttention,
|
attention: BertAttention,
|
||||||
@ -30,37 +30,57 @@ impl BertLayer {
|
|||||||
let (is_decoder, cross_attention) = match config.is_decoder {
|
let (is_decoder, cross_attention) = match config.is_decoder {
|
||||||
Some(value) => {
|
Some(value) => {
|
||||||
if value == true {
|
if value == true {
|
||||||
(value, Some(BertAttention::new(&(p / "cross_attention"), &config)))
|
(
|
||||||
|
value,
|
||||||
|
Some(BertAttention::new(&(p / "cross_attention"), &config)),
|
||||||
|
)
|
||||||
} else {
|
} else {
|
||||||
(value, None)
|
(value, None)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
None => (false, None)
|
None => (false, None),
|
||||||
};
|
};
|
||||||
|
|
||||||
let intermediate = BertIntermediate::new(&(p / "intermediate"), &config);
|
let intermediate = BertIntermediate::new(&(p / "intermediate"), &config);
|
||||||
let output = BertOutput::new(&(p / "output"), &config);
|
let output = BertOutput::new(&(p / "output"), &config);
|
||||||
|
|
||||||
BertLayer { attention, is_decoder, cross_attention, intermediate, output }
|
BertLayer {
|
||||||
|
attention,
|
||||||
|
is_decoder,
|
||||||
|
cross_attention,
|
||||||
|
intermediate,
|
||||||
|
output,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn forward_t(&self,
|
pub fn forward_t(
|
||||||
hidden_states: &Tensor,
|
&self,
|
||||||
mask: &Option<Tensor>,
|
hidden_states: &Tensor,
|
||||||
encoder_hidden_states: &Option<Tensor>,
|
mask: &Option<Tensor>,
|
||||||
encoder_mask: &Option<Tensor>,
|
encoder_hidden_states: &Option<Tensor>,
|
||||||
train: bool) -> (Tensor, Option<Tensor>, Option<Tensor>) {
|
encoder_mask: &Option<Tensor>,
|
||||||
let (attention_output, attention_weights, cross_attention_weights) = if self.is_decoder & encoder_hidden_states.is_some() {
|
train: bool,
|
||||||
let (attention_output, attention_weights) =
|
) -> (Tensor, Option<Tensor>, Option<Tensor>) {
|
||||||
self.attention.forward_t(hidden_states, mask, &None, &None, train);
|
let (attention_output, attention_weights, cross_attention_weights) =
|
||||||
let (attention_output, cross_attention_weights) =
|
if self.is_decoder & encoder_hidden_states.is_some() {
|
||||||
self.cross_attention.as_ref().unwrap().forward_t(&attention_output, mask, encoder_hidden_states, encoder_mask, train);
|
let (attention_output, attention_weights) =
|
||||||
(attention_output, attention_weights, cross_attention_weights)
|
self.attention
|
||||||
} else {
|
.forward_t(hidden_states, mask, &None, &None, train);
|
||||||
let (attention_output, attention_weights) =
|
let (attention_output, cross_attention_weights) =
|
||||||
self.attention.forward_t(hidden_states, mask, &None, &None, train);
|
self.cross_attention.as_ref().unwrap().forward_t(
|
||||||
(attention_output, attention_weights, None)
|
&attention_output,
|
||||||
};
|
mask,
|
||||||
|
encoder_hidden_states,
|
||||||
|
encoder_mask,
|
||||||
|
train,
|
||||||
|
);
|
||||||
|
(attention_output, attention_weights, cross_attention_weights)
|
||||||
|
} else {
|
||||||
|
let (attention_output, attention_weights) =
|
||||||
|
self.attention
|
||||||
|
.forward_t(hidden_states, mask, &None, &None, train);
|
||||||
|
(attention_output, attention_weights, None)
|
||||||
|
};
|
||||||
|
|
||||||
let output = self.intermediate.forward(&attention_output);
|
let output = self.intermediate.forward(&attention_output);
|
||||||
let output = self.output.forward_t(&output, &attention_output, train);
|
let output = self.output.forward_t(&output, &attention_output, train);
|
||||||
@ -78,26 +98,47 @@ pub struct BertEncoder {
|
|||||||
impl BertEncoder {
|
impl BertEncoder {
|
||||||
pub fn new(p: &nn::Path, config: &BertConfig) -> BertEncoder {
|
pub fn new(p: &nn::Path, config: &BertConfig) -> BertEncoder {
|
||||||
let p = &(p / "layer");
|
let p = &(p / "layer");
|
||||||
let output_attentions = if let Some(value) = config.output_attentions { value } else { false };
|
let output_attentions = if let Some(value) = config.output_attentions {
|
||||||
let output_hidden_states = if let Some(value) = config.output_hidden_states { value } else { false };
|
value
|
||||||
|
} else {
|
||||||
let mut layers: Vec<BertLayer> = vec!();
|
false
|
||||||
for layer_index in 0..config.num_hidden_layers {
|
};
|
||||||
layers.push(BertLayer::new(&(p / layer_index), config));
|
let output_hidden_states = if let Some(value) = config.output_hidden_states {
|
||||||
|
value
|
||||||
|
} else {
|
||||||
|
false
|
||||||
};
|
};
|
||||||
|
|
||||||
BertEncoder { output_attentions, output_hidden_states, layers }
|
let mut layers: Vec<BertLayer> = vec![];
|
||||||
|
for layer_index in 0..config.num_hidden_layers {
|
||||||
|
layers.push(BertLayer::new(&(p / layer_index), config));
|
||||||
|
}
|
||||||
|
|
||||||
|
BertEncoder {
|
||||||
|
output_attentions,
|
||||||
|
output_hidden_states,
|
||||||
|
layers,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn forward_t(&self,
|
pub fn forward_t(
|
||||||
hidden_states: &Tensor,
|
&self,
|
||||||
mask: &Option<Tensor>,
|
hidden_states: &Tensor,
|
||||||
encoder_hidden_states: &Option<Tensor>,
|
mask: &Option<Tensor>,
|
||||||
encoder_mask: &Option<Tensor>,
|
encoder_hidden_states: &Option<Tensor>,
|
||||||
train: bool)
|
encoder_mask: &Option<Tensor>,
|
||||||
-> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
|
train: bool,
|
||||||
let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states { Some(vec!()) } else { None };
|
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
|
||||||
let mut all_attentions: Option<Vec<Tensor>> = if self.output_attentions { Some(vec!()) } else { None };
|
let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states {
|
||||||
|
Some(vec![])
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
let mut all_attentions: Option<Vec<Tensor>> = if self.output_attentions {
|
||||||
|
Some(vec![])
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
let mut hidden_state = hidden_states.copy();
|
let mut hidden_state = hidden_states.copy();
|
||||||
let mut attention_weights: Option<Tensor>;
|
let mut attention_weights: Option<Tensor>;
|
||||||
@ -109,16 +150,22 @@ impl BertEncoder {
|
|||||||
hidden_states.push(hidden_state.as_ref().copy());
|
hidden_states.push(hidden_state.as_ref().copy());
|
||||||
};
|
};
|
||||||
|
|
||||||
let temp = layer.forward_t(&hidden_state, &mask, encoder_hidden_states, encoder_mask, train);
|
let temp = layer.forward_t(
|
||||||
|
&hidden_state,
|
||||||
|
&mask,
|
||||||
|
encoder_hidden_states,
|
||||||
|
encoder_mask,
|
||||||
|
train,
|
||||||
|
);
|
||||||
hidden_state = temp.0;
|
hidden_state = temp.0;
|
||||||
attention_weights = temp.1;
|
attention_weights = temp.1;
|
||||||
if let Some(attentions) = all_attentions.borrow_mut() {
|
if let Some(attentions) = all_attentions.borrow_mut() {
|
||||||
attentions.push(attention_weights.as_ref().unwrap().copy());
|
attentions.push(attention_weights.as_ref().unwrap().copy());
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
None => break
|
None => break,
|
||||||
};
|
};
|
||||||
};
|
}
|
||||||
|
|
||||||
(hidden_state, all_hidden_states, all_attentions)
|
(hidden_state, all_hidden_states, all_attentions)
|
||||||
}
|
}
|
||||||
@ -130,14 +177,16 @@ pub struct BertPooler {
|
|||||||
|
|
||||||
impl BertPooler {
|
impl BertPooler {
|
||||||
pub fn new(p: &nn::Path, config: &BertConfig) -> BertPooler {
|
pub fn new(p: &nn::Path, config: &BertConfig) -> BertPooler {
|
||||||
let lin = nn::linear(&(p / "dense"), config.hidden_size, config.hidden_size, Default::default());
|
let lin = nn::linear(
|
||||||
|
&(p / "dense"),
|
||||||
|
config.hidden_size,
|
||||||
|
config.hidden_size,
|
||||||
|
Default::default(),
|
||||||
|
);
|
||||||
BertPooler { lin }
|
BertPooler { lin }
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn forward(&self, hidden_states: &Tensor) -> Tensor {
|
pub fn forward(&self, hidden_states: &Tensor) -> Tensor {
|
||||||
hidden_states
|
hidden_states.select(1, 0).apply(&self.lin).tanh()
|
||||||
.select(1, 0)
|
|
||||||
.apply(&self.lin)
|
|
||||||
.tanh()
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -19,18 +19,24 @@
|
|||||||
//! Pretrained models are available and can be downloaded using RemoteResources.
|
//! Pretrained models are available and can be downloaded using RemoteResources.
|
||||||
//!
|
//!
|
||||||
//! ```no_run
|
//! ```no_run
|
||||||
//!# fn main() -> failure::Fallible<()> {
|
//! # fn main() -> failure::Fallible<()> {
|
||||||
//!#
|
//! #
|
||||||
//! use rust_tokenizers::BertTokenizer;
|
//! use rust_tokenizers::BertTokenizer;
|
||||||
//! use tch::{nn, Device};
|
//! use tch::{nn, Device};
|
||||||
//!# use std::path::PathBuf;
|
//! # use std::path::PathBuf;
|
||||||
//! use rust_bert::bert::{BertForMaskedLM, BertConfig};
|
//! use rust_bert::bert::{BertConfig, BertForMaskedLM};
|
||||||
|
//! use rust_bert::resources::{download_resource, LocalResource, Resource};
|
||||||
//! use rust_bert::Config;
|
//! use rust_bert::Config;
|
||||||
//! use rust_bert::resources::{Resource, download_resource, LocalResource};
|
|
||||||
//!
|
//!
|
||||||
//! let config_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/config.json")});
|
//! let config_resource = Resource::Local(LocalResource {
|
||||||
//! let vocab_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/vocab.txt")});
|
//! local_path: PathBuf::from("path/to/config.json"),
|
||||||
//! let weights_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/model.ot")});
|
//! });
|
||||||
|
//! let vocab_resource = Resource::Local(LocalResource {
|
||||||
|
//! local_path: PathBuf::from("path/to/vocab.txt"),
|
||||||
|
//! });
|
||||||
|
//! let weights_resource = Resource::Local(LocalResource {
|
||||||
|
//! local_path: PathBuf::from("path/to/model.ot"),
|
||||||
|
//! });
|
||||||
//! let config_path = download_resource(&config_resource)?;
|
//! let config_path = download_resource(&config_resource)?;
|
||||||
//! let vocab_path = download_resource(&vocab_resource)?;
|
//! let vocab_path = download_resource(&vocab_resource)?;
|
||||||
//! let weights_path = download_resource(&weights_resource)?;
|
//! let weights_path = download_resource(&weights_resource)?;
|
||||||
@ -41,17 +47,18 @@
|
|||||||
//! let bert_model = BertForMaskedLM::new(&vs.root(), &config);
|
//! let bert_model = BertForMaskedLM::new(&vs.root(), &config);
|
||||||
//! vs.load(weights_path)?;
|
//! vs.load(weights_path)?;
|
||||||
//!
|
//!
|
||||||
//!# Ok(())
|
//! # Ok(())
|
||||||
//!# }
|
//! # }
|
||||||
//! ```
|
//! ```
|
||||||
|
|
||||||
|
mod attention;
|
||||||
mod bert;
|
mod bert;
|
||||||
mod embeddings;
|
mod embeddings;
|
||||||
mod attention;
|
|
||||||
pub(crate) mod encoder;
|
pub(crate) mod encoder;
|
||||||
|
|
||||||
pub use bert::{BertModelResources, BertConfigResources, BertVocabResources,
|
pub use bert::{
|
||||||
BertConfig, Activation, BertModel, BertForTokenClassification, BertForMultipleChoice,
|
Activation, BertConfig, BertConfigResources, BertForMaskedLM, BertForMultipleChoice,
|
||||||
BertForMaskedLM, BertForSequenceClassification, BertForQuestionAnswering};
|
BertForQuestionAnswering, BertForSequenceClassification, BertForTokenClassification, BertModel,
|
||||||
pub use embeddings::{BertEmbedding, BertEmbeddings};
|
BertModelResources, BertVocabResources,
|
||||||
|
};
|
||||||
|
pub use embeddings::{BertEmbedding, BertEmbeddings};
|
||||||
|
@ -1,14 +1,26 @@
|
|||||||
use tch::Tensor;
|
|
||||||
use std::f64::consts::PI;
|
use std::f64::consts::PI;
|
||||||
|
use tch::Tensor;
|
||||||
|
|
||||||
pub fn _gelu(x: &Tensor) -> Tensor { x * 0.5 * (1.0 + (x / ((2.0 as f64).sqrt())).erf()) }
|
pub fn _gelu(x: &Tensor) -> Tensor {
|
||||||
|
x * 0.5 * (1.0 + (x / ((2.0 as f64).sqrt())).erf())
|
||||||
|
}
|
||||||
|
|
||||||
pub fn _relu(x: &Tensor) -> Tensor { x.relu() }
|
pub fn _relu(x: &Tensor) -> Tensor {
|
||||||
|
x.relu()
|
||||||
|
}
|
||||||
|
|
||||||
pub fn _swish(x: &Tensor) -> Tensor { x * x.sigmoid() }
|
pub fn _swish(x: &Tensor) -> Tensor {
|
||||||
|
x * x.sigmoid()
|
||||||
|
}
|
||||||
|
|
||||||
pub fn _mish(x: &Tensor) -> Tensor { x * (x.softplus().tanh()) }
|
pub fn _mish(x: &Tensor) -> Tensor {
|
||||||
|
x * (x.softplus().tanh())
|
||||||
|
}
|
||||||
|
|
||||||
pub fn _gelu_new(x: &Tensor) -> Tensor { x * 0.5 * (((x.pow(3.0f64) * 0.044715 + x) * ((2f64 / PI).sqrt())).tanh() + 1) }
|
pub fn _gelu_new(x: &Tensor) -> Tensor {
|
||||||
|
x * 0.5 * (((x.pow(3.0f64) * 0.044715 + x) * ((2f64 / PI).sqrt())).tanh() + 1)
|
||||||
|
}
|
||||||
|
|
||||||
pub fn _tanh(x: &Tensor) -> Tensor { x.tanh() }
|
pub fn _tanh(x: &Tensor) -> Tensor {
|
||||||
|
x.tanh()
|
||||||
|
}
|
||||||
|
@ -9,15 +9,16 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
|
use serde::Deserialize;
|
||||||
use std::path::Path;
|
|
||||||
use std::fs::File;
|
use std::fs::File;
|
||||||
use std::io::BufReader;
|
use std::io::BufReader;
|
||||||
use serde::Deserialize;
|
use std::path::Path;
|
||||||
|
|
||||||
/// # Utility to deserialize JSON config files
|
/// # Utility to deserialize JSON config files
|
||||||
pub trait Config<T>
|
pub trait Config<T>
|
||||||
where for<'de> T: Deserialize<'de> {
|
where
|
||||||
|
for<'de> T: Deserialize<'de>,
|
||||||
|
{
|
||||||
/// Loads a `Config` object from a JSON file. The format is expected to be aligned with the [Transformers library](https://github.com/huggingface/transformers) configuration files for each model.
|
/// Loads a `Config` object from a JSON file. The format is expected to be aligned with the [Transformers library](https://github.com/huggingface/transformers) configuration files for each model.
|
||||||
/// The parsing will fail if non-optional keys expected by the model are missing.
|
/// The parsing will fail if non-optional keys expected by the model are missing.
|
||||||
///
|
///
|
||||||
@ -28,18 +29,17 @@ pub trait Config<T>
|
|||||||
/// # Example
|
/// # Example
|
||||||
///
|
///
|
||||||
/// ```no_run
|
/// ```no_run
|
||||||
|
/// use rust_bert::gpt2::Gpt2Config;
|
||||||
/// use rust_bert::Config;
|
/// use rust_bert::Config;
|
||||||
/// use std::path::Path;
|
/// use std::path::Path;
|
||||||
/// use rust_bert::gpt2::Gpt2Config;
|
|
||||||
///
|
///
|
||||||
/// let config_path = Path::new("path/to/config.json");
|
/// let config_path = Path::new("path/to/config.json");
|
||||||
/// let config = Gpt2Config::from_file(config_path);
|
/// let config = Gpt2Config::from_file(config_path);
|
||||||
/// ```
|
/// ```
|
||||||
///
|
|
||||||
fn from_file(path: &Path) -> T {
|
fn from_file(path: &Path) -> T {
|
||||||
let f = File::open(path).expect("Could not open configuration file.");
|
let f = File::open(path).expect("Could not open configuration file.");
|
||||||
let br = BufReader::new(f);
|
let br = BufReader::new(f);
|
||||||
let config: T = serde_json::from_reader(br).expect("could not parse configuration");
|
let config: T = serde_json::from_reader(br).expect("could not parse configuration");
|
||||||
config
|
config
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -27,4 +27,4 @@ impl ModuleT for Dropout {
|
|||||||
fn forward_t(&self, input: &Tensor, train: bool) -> Tensor {
|
fn forward_t(&self, input: &Tensor, train: bool) -> Tensor {
|
||||||
input.dropout(self.dropout_prob, train)
|
input.dropout(self.dropout_prob, train)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -10,9 +10,9 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
use tch::nn::{Init, Path, Module};
|
|
||||||
use tch::Tensor;
|
|
||||||
use std::borrow::Borrow;
|
use std::borrow::Borrow;
|
||||||
|
use tch::nn::{Init, Module, Path};
|
||||||
|
use tch::Tensor;
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy)]
|
#[derive(Debug, Clone, Copy)]
|
||||||
pub struct LinearNoBiasConfig {
|
pub struct LinearNoBiasConfig {
|
||||||
@ -32,7 +32,6 @@ pub struct LinearNoBias {
|
|||||||
pub ws: Tensor,
|
pub ws: Tensor,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
pub fn linear_no_bias<'a, T: Borrow<Path<'a>>>(
|
pub fn linear_no_bias<'a, T: Borrow<Path<'a>>>(
|
||||||
vs: T,
|
vs: T,
|
||||||
in_dim: i64,
|
in_dim: i64,
|
||||||
@ -49,4 +48,4 @@ impl Module for LinearNoBias {
|
|||||||
fn forward(&self, xs: &Tensor) -> Tensor {
|
fn forward(&self, xs: &Tensor) -> Tensor {
|
||||||
xs.matmul(&self.ws.tr())
|
xs.matmul(&self.ws.tr())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
pub mod config;
|
|
||||||
pub mod resources;
|
|
||||||
pub(crate) mod dropout;
|
|
||||||
pub(crate) mod activations;
|
pub(crate) mod activations;
|
||||||
|
pub mod config;
|
||||||
|
pub(crate) mod dropout;
|
||||||
pub(crate) mod linear;
|
pub(crate) mod linear;
|
||||||
|
pub mod resources;
|
||||||
|
|
||||||
pub use config::Config;
|
pub use config::Config;
|
||||||
|
@ -18,9 +18,9 @@
|
|||||||
//! pre-trained models in each model module.
|
//! pre-trained models in each model module.
|
||||||
|
|
||||||
use lazy_static::lazy_static;
|
use lazy_static::lazy_static;
|
||||||
use std::path::PathBuf;
|
|
||||||
use reqwest::Client;
|
use reqwest::Client;
|
||||||
use std::{fs, env};
|
use std::path::PathBuf;
|
||||||
|
use std::{env, fs};
|
||||||
use tokio::prelude::*;
|
use tokio::prelude::*;
|
||||||
use tokio::runtime::Runtime;
|
use tokio::runtime::Runtime;
|
||||||
use tokio::task;
|
use tokio::task;
|
||||||
@ -47,12 +47,13 @@ impl Resource {
|
|||||||
/// # Example
|
/// # Example
|
||||||
///
|
///
|
||||||
/// ```no_run
|
/// ```no_run
|
||||||
/// use rust_bert::resources::{Resource, LocalResource};
|
/// use rust_bert::resources::{LocalResource, Resource};
|
||||||
/// use std::path::PathBuf;
|
/// use std::path::PathBuf;
|
||||||
/// let config_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/config.json")});
|
/// let config_resource = Resource::Local(LocalResource {
|
||||||
|
/// local_path: PathBuf::from("path/to/config.json"),
|
||||||
|
/// });
|
||||||
/// let config_path = config_resource.get_local_path();
|
/// let config_path = config_resource.get_local_path();
|
||||||
/// ```
|
/// ```
|
||||||
///
|
|
||||||
pub fn get_local_path(&self) -> &PathBuf {
|
pub fn get_local_path(&self) -> &PathBuf {
|
||||||
match self {
|
match self {
|
||||||
Resource::Local(resource) => &resource.local_path,
|
Resource::Local(resource) => &resource.local_path,
|
||||||
@ -65,7 +66,7 @@ impl Resource {
|
|||||||
#[derive(PartialEq, Clone)]
|
#[derive(PartialEq, Clone)]
|
||||||
pub struct LocalResource {
|
pub struct LocalResource {
|
||||||
/// Local path for the resource
|
/// Local path for the resource
|
||||||
pub local_path: PathBuf
|
pub local_path: PathBuf,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// # Remote resource
|
/// # Remote resource
|
||||||
@ -93,13 +94,18 @@ impl RemoteResource {
|
|||||||
/// # Example
|
/// # Example
|
||||||
///
|
///
|
||||||
/// ```no_run
|
/// ```no_run
|
||||||
/// use rust_bert::resources::{Resource, RemoteResource};
|
/// use rust_bert::resources::{RemoteResource, Resource};
|
||||||
/// use std::path::PathBuf;
|
/// use std::path::PathBuf;
|
||||||
/// let config_resource = Resource::Remote(RemoteResource::new("http://config_json_location", PathBuf::from("path/to/config.json")));
|
/// let config_resource = Resource::Remote(RemoteResource::new(
|
||||||
|
/// "http://config_json_location",
|
||||||
|
/// PathBuf::from("path/to/config.json"),
|
||||||
|
/// ));
|
||||||
/// ```
|
/// ```
|
||||||
///
|
|
||||||
pub fn new(url: &str, target: PathBuf) -> RemoteResource {
|
pub fn new(url: &str, target: PathBuf) -> RemoteResource {
|
||||||
RemoteResource { url: url.to_string(), local_path: target }
|
RemoteResource {
|
||||||
|
url: url.to_string(),
|
||||||
|
local_path: target,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Creates a new RemoteResource from an URL and local name. Will define a local path pointing to
|
/// Creates a new RemoteResource from an URL and local name. Will define a local path pointing to
|
||||||
@ -117,14 +123,12 @@ impl RemoteResource {
|
|||||||
/// # Example
|
/// # Example
|
||||||
///
|
///
|
||||||
/// ```no_run
|
/// ```no_run
|
||||||
/// use rust_bert::resources::{Resource, RemoteResource};
|
/// use rust_bert::resources::{RemoteResource, Resource};
|
||||||
/// let model_resource = Resource::Remote(RemoteResource::from_pretrained(
|
/// let model_resource = Resource::Remote(RemoteResource::from_pretrained((
|
||||||
/// ("distilbert-sst2/model.ot",
|
/// "distilbert-sst2/model.ot",
|
||||||
/// "https://cdn.huggingface.co/distilbert-base-uncased-finetuned-sst-2-english-rust_model.ot"
|
/// "https://cdn.huggingface.co/distilbert-base-uncased-finetuned-sst-2-english-rust_model.ot",
|
||||||
/// )
|
/// )));
|
||||||
/// ));
|
|
||||||
/// ```
|
/// ```
|
||||||
///
|
|
||||||
pub fn from_pretrained(name_url_tuple: (&str, &str)) -> RemoteResource {
|
pub fn from_pretrained(name_url_tuple: (&str, &str)) -> RemoteResource {
|
||||||
let name = name_url_tuple.0;
|
let name = name_url_tuple.0;
|
||||||
let url = name_url_tuple.1.to_string();
|
let url = name_url_tuple.1.to_string();
|
||||||
@ -171,15 +175,13 @@ fn _get_cache_directory() -> PathBuf {
|
|||||||
/// # Example
|
/// # Example
|
||||||
///
|
///
|
||||||
/// ```no_run
|
/// ```no_run
|
||||||
/// use rust_bert::resources::{Resource, RemoteResource, download_resource};
|
/// use rust_bert::resources::{download_resource, RemoteResource, Resource};
|
||||||
/// let model_resource = Resource::Remote(RemoteResource::from_pretrained(
|
/// let model_resource = Resource::Remote(RemoteResource::from_pretrained((
|
||||||
/// ("distilbert-sst2/model.ot",
|
/// "distilbert-sst2/model.ot",
|
||||||
/// "https://cdn.huggingface.co/distilbert-base-uncased-finetuned-sst-2-english-rust_model.ot"
|
/// "https://cdn.huggingface.co/distilbert-base-uncased-finetuned-sst-2-english-rust_model.ot",
|
||||||
/// )
|
/// )));
|
||||||
/// ));
|
|
||||||
/// let local_path = download_resource(&model_resource);
|
/// let local_path = download_resource(&model_resource);
|
||||||
/// ```
|
/// ```
|
||||||
///
|
|
||||||
pub fn download_resource(resource: &Resource) -> failure::Fallible<&PathBuf> {
|
pub fn download_resource(resource: &Resource) -> failure::Fallible<&PathBuf> {
|
||||||
match resource {
|
match resource {
|
||||||
Resource::Remote(remote_resource) => {
|
Resource::Remote(remote_resource) => {
|
||||||
@ -202,8 +204,6 @@ pub fn download_resource(resource: &Resource) -> failure::Fallible<&PathBuf> {
|
|||||||
|
|
||||||
Ok(resource.get_local_path())
|
Ok(resource.get_local_path())
|
||||||
}
|
}
|
||||||
Resource::Local(_) => {
|
Resource::Local(_) => Ok(resource.get_local_path()),
|
||||||
Ok(resource.get_local_path())
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -16,7 +16,11 @@ extern crate tch;
|
|||||||
|
|
||||||
pub fn main() -> failure::Fallible<()> {
|
pub fn main() -> failure::Fallible<()> {
|
||||||
let args: Vec<_> = std::env::args().collect();
|
let args: Vec<_> = std::env::args().collect();
|
||||||
ensure!(args.len() == 3, "usage: {} source.npz destination.ot", args[0]);
|
ensure!(
|
||||||
|
args.len() == 3,
|
||||||
|
"usage: {} source.npz destination.ot",
|
||||||
|
args[0]
|
||||||
|
);
|
||||||
|
|
||||||
let source_file = &args[1];
|
let source_file = &args[1];
|
||||||
let destination_file = &args[2];
|
let destination_file = &args[2];
|
||||||
@ -24,4 +28,4 @@ pub fn main() -> failure::Fallible<()> {
|
|||||||
tch::Tensor::save_multi(&tensors, destination_file)?;
|
tch::Tensor::save_multi(&tensors, destination_file)?;
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -10,11 +10,10 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
use tch::{nn, Tensor};
|
use crate::common::dropout::Dropout;
|
||||||
use crate::distilbert::distilbert::DistilBertConfig;
|
use crate::distilbert::distilbert::DistilBertConfig;
|
||||||
use tch::kind::Kind::Float;
|
use tch::kind::Kind::Float;
|
||||||
use crate::common::dropout::Dropout;
|
use tch::{nn, Tensor};
|
||||||
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct MultiHeadSelfAttention {
|
pub struct MultiHeadSelfAttention {
|
||||||
@ -39,7 +38,7 @@ impl MultiHeadSelfAttention {
|
|||||||
|
|
||||||
let output_attentions = match config.output_attentions {
|
let output_attentions = match config.output_attentions {
|
||||||
Some(value) => value,
|
Some(value) => value,
|
||||||
None => false
|
None => false,
|
||||||
};
|
};
|
||||||
|
|
||||||
MultiHeadSelfAttention {
|
MultiHeadSelfAttention {
|
||||||
@ -59,10 +58,19 @@ impl MultiHeadSelfAttention {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn flatten(&self, x: Tensor, bs: i64, dim_per_head: i64) -> Tensor {
|
fn flatten(&self, x: Tensor, bs: i64, dim_per_head: i64) -> Tensor {
|
||||||
x.transpose(1, 2).contiguous().view((bs, -1, &self.n_heads * dim_per_head))
|
x.transpose(1, 2)
|
||||||
|
.contiguous()
|
||||||
|
.view((bs, -1, &self.n_heads * dim_per_head))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn forward_t(&self, query: &Tensor, key: &Tensor, value: &Tensor, mask: &Option<Tensor>, train: bool) -> (Tensor, Option<Tensor>) {
|
pub fn forward_t(
|
||||||
|
&self,
|
||||||
|
query: &Tensor,
|
||||||
|
key: &Tensor,
|
||||||
|
value: &Tensor,
|
||||||
|
mask: &Option<Tensor>,
|
||||||
|
train: bool,
|
||||||
|
) -> (Tensor, Option<Tensor>) {
|
||||||
let bs = query.size()[0];
|
let bs = query.size()[0];
|
||||||
let k_length = key.size()[1];
|
let k_length = key.size()[1];
|
||||||
|
|
||||||
@ -73,14 +81,19 @@ impl MultiHeadSelfAttention {
|
|||||||
|
|
||||||
let scores = if let Some(mask) = mask {
|
let scores = if let Some(mask) = mask {
|
||||||
let unmasked_scores = q.matmul(&k.transpose(2, 3));
|
let unmasked_scores = q.matmul(&k.transpose(2, 3));
|
||||||
let mask = mask.le1(&(mask.zeros_like() + 0.1)).view((bs, 1i64, 1i64, k_length)).expand_as(&unmasked_scores);
|
let mask = mask
|
||||||
|
.le1(&(mask.zeros_like() + 0.1))
|
||||||
|
.view((bs, 1i64, 1i64, k_length))
|
||||||
|
.expand_as(&unmasked_scores);
|
||||||
unmasked_scores.masked_fill(&mask, std::f64::NEG_INFINITY)
|
unmasked_scores.masked_fill(&mask, std::f64::NEG_INFINITY)
|
||||||
} else {
|
} else {
|
||||||
q.matmul(&k.transpose(2, 3))
|
q.matmul(&k.transpose(2, 3))
|
||||||
};
|
};
|
||||||
|
|
||||||
let weights = scores.softmax(-1, Float).apply_t(&self.dropout, train);
|
let weights = scores.softmax(-1, Float).apply_t(&self.dropout, train);
|
||||||
let context = self.flatten(weights.matmul(&v), bs, self.dim_per_head).apply(&self.out_lin);
|
let context = self
|
||||||
|
.flatten(weights.matmul(&v), bs, self.dim_per_head)
|
||||||
|
.apply(&self.out_lin);
|
||||||
|
|
||||||
if !self.output_attentions {
|
if !self.output_attentions {
|
||||||
(context, None)
|
(context, None)
|
||||||
|
@ -12,13 +12,13 @@
|
|||||||
|
|
||||||
extern crate tch;
|
extern crate tch;
|
||||||
|
|
||||||
use std::collections::HashMap;
|
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
use crate::distilbert::embeddings::DistilBertEmbedding;
|
|
||||||
use crate::distilbert::transformer::Transformer;
|
|
||||||
use self::tch::{nn, Tensor};
|
use self::tch::{nn, Tensor};
|
||||||
use crate::common::dropout::Dropout;
|
use crate::common::dropout::Dropout;
|
||||||
|
use crate::distilbert::embeddings::DistilBertEmbedding;
|
||||||
|
use crate::distilbert::transformer::Transformer;
|
||||||
use crate::Config;
|
use crate::Config;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
/// # DistilBERT Pretrained model weight files
|
/// # DistilBERT Pretrained model weight files
|
||||||
pub struct DistilBertModelResources;
|
pub struct DistilBertModelResources;
|
||||||
@ -31,29 +31,56 @@ pub struct DistilBertVocabResources;
|
|||||||
|
|
||||||
impl DistilBertModelResources {
|
impl DistilBertModelResources {
|
||||||
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
|
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
|
||||||
pub const DISTIL_BERT_SST2: (&'static str, &'static str) = ("distilbert-sst2/model.ot", "https://cdn.huggingface.co/distilbert-base-uncased-finetuned-sst-2-english-rust_model.ot");
|
pub const DISTIL_BERT_SST2: (&'static str, &'static str) = (
|
||||||
|
"distilbert-sst2/model.ot",
|
||||||
|
"https://cdn.huggingface.co/distilbert-base-uncased-finetuned-sst-2-english-rust_model.ot",
|
||||||
|
);
|
||||||
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
|
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
|
||||||
pub const DISTIL_BERT: (&'static str, &'static str) = ("distilbert/model.ot", "https://cdn.huggingface.co/distilbert-base-uncased-rust_model.ot");
|
pub const DISTIL_BERT: (&'static str, &'static str) = (
|
||||||
|
"distilbert/model.ot",
|
||||||
|
"https://cdn.huggingface.co/distilbert-base-uncased-rust_model.ot",
|
||||||
|
);
|
||||||
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
|
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
|
||||||
pub const DISTIL_BERT_SQUAD: (&'static str, &'static str) = ("distilbert-qa/model.ot", "https://cdn.huggingface.co/distilbert-base-cased-distilled-squad-rust_model.ot");
|
pub const DISTIL_BERT_SQUAD: (&'static str, &'static str) = (
|
||||||
|
"distilbert-qa/model.ot",
|
||||||
|
"https://cdn.huggingface.co/distilbert-base-cased-distilled-squad-rust_model.ot",
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
impl DistilBertConfigResources {
|
impl DistilBertConfigResources {
|
||||||
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
|
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
|
||||||
pub const DISTIL_BERT_SST2: (&'static str, &'static str) = ("distilbert-sst2/config.json", "https://cdn.huggingface.co/distilbert-base-uncased-finetuned-sst-2-english-config.json");
|
pub const DISTIL_BERT_SST2: (&'static str, &'static str) = (
|
||||||
|
"distilbert-sst2/config.json",
|
||||||
|
"https://cdn.huggingface.co/distilbert-base-uncased-finetuned-sst-2-english-config.json",
|
||||||
|
);
|
||||||
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
|
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
|
||||||
pub const DISTIL_BERT: (&'static str, &'static str) = ("distilbert/config.json", "https://cdn.huggingface.co/distilbert-base-uncased-config.json");
|
pub const DISTIL_BERT: (&'static str, &'static str) = (
|
||||||
|
"distilbert/config.json",
|
||||||
|
"https://cdn.huggingface.co/distilbert-base-uncased-config.json",
|
||||||
|
);
|
||||||
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
|
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
|
||||||
pub const DISTIL_BERT_SQUAD: (&'static str, &'static str) = ("distilbert-qa/config.json", "https://cdn.huggingface.co/distilbert-base-cased-distilled-squad-config.json");
|
pub const DISTIL_BERT_SQUAD: (&'static str, &'static str) = (
|
||||||
|
"distilbert-qa/config.json",
|
||||||
|
"https://cdn.huggingface.co/distilbert-base-cased-distilled-squad-config.json",
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
impl DistilBertVocabResources {
|
impl DistilBertVocabResources {
|
||||||
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
|
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
|
||||||
pub const DISTIL_BERT_SST2: (&'static str, &'static str) = ("distilbert-sst2/vocab.txt", "https://cdn.huggingface.co/distilbert-base-uncased-finetuned-sst-2-english-vocab.txt");
|
pub const DISTIL_BERT_SST2: (&'static str, &'static str) = (
|
||||||
|
"distilbert-sst2/vocab.txt",
|
||||||
|
"https://cdn.huggingface.co/distilbert-base-uncased-finetuned-sst-2-english-vocab.txt",
|
||||||
|
);
|
||||||
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
|
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
|
||||||
pub const DISTIL_BERT: (&'static str, &'static str) = ("distilbert/vocab.txt", "https://cdn.huggingface.co/bert-base-uncased-vocab.txt");
|
pub const DISTIL_BERT: (&'static str, &'static str) = (
|
||||||
|
"distilbert/vocab.txt",
|
||||||
|
"https://cdn.huggingface.co/bert-base-uncased-vocab.txt",
|
||||||
|
);
|
||||||
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
|
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
|
||||||
pub const DISTIL_BERT_SQUAD: (&'static str, &'static str) = ("distilbert-qa/vocab.txt", "https://cdn.huggingface.co/bert-large-cased-vocab.txt");
|
pub const DISTIL_BERT_SQUAD: (&'static str, &'static str) = (
|
||||||
|
"distilbert-qa/vocab.txt",
|
||||||
|
"https://cdn.huggingface.co/bert-large-cased-vocab.txt",
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(non_camel_case_types)]
|
#[allow(non_camel_case_types)]
|
||||||
@ -118,10 +145,10 @@ impl DistilBertModel {
|
|||||||
/// # Example
|
/// # Example
|
||||||
///
|
///
|
||||||
/// ```no_run
|
/// ```no_run
|
||||||
/// use tch::{nn, Device};
|
/// use rust_bert::distilbert::{DistilBertConfig, DistilBertModel};
|
||||||
/// use rust_bert::Config;
|
/// use rust_bert::Config;
|
||||||
/// use std::path::Path;
|
/// use std::path::Path;
|
||||||
/// use rust_bert::distilbert::{DistilBertConfig, DistilBertModel};
|
/// use tch::{nn, Device};
|
||||||
///
|
///
|
||||||
/// let config_path = Path::new("path/to/config.json");
|
/// let config_path = Path::new("path/to/config.json");
|
||||||
/// let device = Device::Cpu;
|
/// let device = Device::Cpu;
|
||||||
@ -129,12 +156,14 @@ impl DistilBertModel {
|
|||||||
/// let config = DistilBertConfig::from_file(config_path);
|
/// let config = DistilBertConfig::from_file(config_path);
|
||||||
/// let distil_bert: DistilBertModel = DistilBertModel::new(&(&p.root() / "distilbert"), &config);
|
/// let distil_bert: DistilBertModel = DistilBertModel::new(&(&p.root() / "distilbert"), &config);
|
||||||
/// ```
|
/// ```
|
||||||
///
|
|
||||||
pub fn new(p: &nn::Path, config: &DistilBertConfig) -> DistilBertModel {
|
pub fn new(p: &nn::Path, config: &DistilBertConfig) -> DistilBertModel {
|
||||||
let p = &(p / "distilbert");
|
let p = &(p / "distilbert");
|
||||||
let embeddings = DistilBertEmbedding::new(&(p / "embeddings"), config);
|
let embeddings = DistilBertEmbedding::new(&(p / "embeddings"), config);
|
||||||
let transformer = Transformer::new(&(p / "transformer"), config);
|
let transformer = Transformer::new(&(p / "transformer"), config);
|
||||||
DistilBertModel { embeddings, transformer }
|
DistilBertModel {
|
||||||
|
embeddings,
|
||||||
|
transformer,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Forward pass through the model
|
/// Forward pass through the model
|
||||||
@ -155,45 +184,49 @@ impl DistilBertModel {
|
|||||||
/// # Example
|
/// # Example
|
||||||
///
|
///
|
||||||
/// ```no_run
|
/// ```no_run
|
||||||
///# use tch::{nn, Device, Tensor, no_grad};
|
/// # use tch::{nn, Device, Tensor, no_grad};
|
||||||
///# use rust_bert::Config;
|
/// # use rust_bert::Config;
|
||||||
///# use std::path::Path;
|
/// # use std::path::Path;
|
||||||
///# use tch::kind::Kind::Int64;
|
/// # use tch::kind::Kind::Int64;
|
||||||
/// use rust_bert::distilbert::{DistilBertConfig, DistilBertModel};
|
/// use rust_bert::distilbert::{DistilBertConfig, DistilBertModel};
|
||||||
///# let config_path = Path::new("path/to/config.json");
|
/// # let config_path = Path::new("path/to/config.json");
|
||||||
///# let vocab_path = Path::new("path/to/vocab.txt");
|
/// # let vocab_path = Path::new("path/to/vocab.txt");
|
||||||
///# let device = Device::Cpu;
|
/// # let device = Device::Cpu;
|
||||||
///# let vs = nn::VarStore::new(device);
|
/// # let vs = nn::VarStore::new(device);
|
||||||
///# let config = DistilBertConfig::from_file(config_path);
|
/// # let config = DistilBertConfig::from_file(config_path);
|
||||||
///# let distilbert_model: DistilBertModel = DistilBertModel::new(&vs.root(), &config);
|
/// # let distilbert_model: DistilBertModel = DistilBertModel::new(&vs.root(), &config);
|
||||||
/// let (batch_size, sequence_length) = (64, 128);
|
/// let (batch_size, sequence_length) = (64, 128);
|
||||||
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
|
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
|
||||||
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
||||||
///
|
|
||||||
/// let (output, all_hidden_states, all_attentions) = no_grad(|| {
|
|
||||||
/// distilbert_model
|
|
||||||
/// .forward_t(Some(input_tensor),
|
|
||||||
/// Some(mask),
|
|
||||||
/// None,
|
|
||||||
/// false).unwrap()
|
|
||||||
/// });
|
|
||||||
///
|
///
|
||||||
|
/// let (output, all_hidden_states, all_attentions) = no_grad(|| {
|
||||||
|
/// distilbert_model
|
||||||
|
/// .forward_t(Some(input_tensor), Some(mask), None, false)
|
||||||
|
/// .unwrap()
|
||||||
|
/// });
|
||||||
/// ```
|
/// ```
|
||||||
///
|
pub fn forward_t(
|
||||||
pub fn forward_t(&self, input: Option<Tensor>, mask: Option<Tensor>, input_embeds: Option<Tensor>, train: bool)
|
&self,
|
||||||
-> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
|
input: Option<Tensor>,
|
||||||
|
mask: Option<Tensor>,
|
||||||
|
input_embeds: Option<Tensor>,
|
||||||
|
train: bool,
|
||||||
|
) -> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
|
||||||
let input_embeddings = match input {
|
let input_embeddings = match input {
|
||||||
Some(input_value) => match input_embeds {
|
Some(input_value) => match input_embeds {
|
||||||
Some(_) => { return Err("Only one of input ids or input embeddings may be set"); }
|
Some(_) => {
|
||||||
None => input_value.apply_t(&self.embeddings, train)
|
return Err("Only one of input ids or input embeddings may be set");
|
||||||
}
|
}
|
||||||
|
None => input_value.apply_t(&self.embeddings, train),
|
||||||
|
},
|
||||||
None => match input_embeds {
|
None => match input_embeds {
|
||||||
Some(embeds) => embeds,
|
Some(embeds) => embeds,
|
||||||
None => { return Err("At least one of input ids or input embeddings must be set"); }
|
None => {
|
||||||
}
|
return Err("At least one of input ids or input embeddings must be set");
|
||||||
|
}
|
||||||
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
let transformer_output = (&self.transformer).forward_t(&input_embeddings, mask, train);
|
let transformer_output = (&self.transformer).forward_t(&input_embeddings, mask, train);
|
||||||
Ok(transformer_output)
|
Ok(transformer_output)
|
||||||
}
|
}
|
||||||
@ -223,28 +256,47 @@ impl DistilBertModelClassifier {
|
|||||||
/// # Example
|
/// # Example
|
||||||
///
|
///
|
||||||
/// ```no_run
|
/// ```no_run
|
||||||
/// use tch::{nn, Device};
|
/// use rust_bert::distilbert::{DistilBertConfig, DistilBertModelClassifier};
|
||||||
/// use rust_bert::Config;
|
/// use rust_bert::Config;
|
||||||
/// use std::path::Path;
|
/// use std::path::Path;
|
||||||
/// use rust_bert::distilbert::{DistilBertConfig, DistilBertModelClassifier};
|
/// use tch::{nn, Device};
|
||||||
///
|
///
|
||||||
/// let config_path = Path::new("path/to/config.json");
|
/// let config_path = Path::new("path/to/config.json");
|
||||||
/// let device = Device::Cpu;
|
/// let device = Device::Cpu;
|
||||||
/// let p = nn::VarStore::new(device);
|
/// let p = nn::VarStore::new(device);
|
||||||
/// let config = DistilBertConfig::from_file(config_path);
|
/// let config = DistilBertConfig::from_file(config_path);
|
||||||
/// let distil_bert: DistilBertModelClassifier = DistilBertModelClassifier::new(&(&p.root() / "distilbert"), &config);
|
/// let distil_bert: DistilBertModelClassifier =
|
||||||
|
/// DistilBertModelClassifier::new(&(&p.root() / "distilbert"), &config);
|
||||||
/// ```
|
/// ```
|
||||||
///
|
|
||||||
pub fn new(p: &nn::Path, config: &DistilBertConfig) -> DistilBertModelClassifier {
|
pub fn new(p: &nn::Path, config: &DistilBertConfig) -> DistilBertModelClassifier {
|
||||||
let distil_bert_model = DistilBertModel::new(&p, config);
|
let distil_bert_model = DistilBertModel::new(&p, config);
|
||||||
|
|
||||||
let num_labels = config.id2label.as_ref().expect("id2label must be provided for classifiers").len() as i64;
|
let num_labels = config
|
||||||
|
.id2label
|
||||||
|
.as_ref()
|
||||||
|
.expect("id2label must be provided for classifiers")
|
||||||
|
.len() as i64;
|
||||||
|
|
||||||
let pre_classifier = nn::linear(&(p / "pre_classifier"), config.dim, config.dim, Default::default());
|
let pre_classifier = nn::linear(
|
||||||
let classifier = nn::linear(&(p / "classifier"), config.dim, num_labels, Default::default());
|
&(p / "pre_classifier"),
|
||||||
|
config.dim,
|
||||||
|
config.dim,
|
||||||
|
Default::default(),
|
||||||
|
);
|
||||||
|
let classifier = nn::linear(
|
||||||
|
&(p / "classifier"),
|
||||||
|
config.dim,
|
||||||
|
num_labels,
|
||||||
|
Default::default(),
|
||||||
|
);
|
||||||
let dropout = Dropout::new(config.seq_classif_dropout);
|
let dropout = Dropout::new(config.seq_classif_dropout);
|
||||||
|
|
||||||
DistilBertModelClassifier { distil_bert_model, pre_classifier, classifier, dropout }
|
DistilBertModelClassifier {
|
||||||
|
distil_bert_model,
|
||||||
|
pre_classifier,
|
||||||
|
classifier,
|
||||||
|
dropout,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Forward pass through the model
|
/// Forward pass through the model
|
||||||
@ -265,17 +317,17 @@ impl DistilBertModelClassifier {
|
|||||||
/// # Example
|
/// # Example
|
||||||
///
|
///
|
||||||
/// ```no_run
|
/// ```no_run
|
||||||
///# use tch::{nn, Device, Tensor, no_grad};
|
/// # use tch::{nn, Device, Tensor, no_grad};
|
||||||
///# use rust_bert::Config;
|
/// # use rust_bert::Config;
|
||||||
///# use std::path::Path;
|
/// # use std::path::Path;
|
||||||
///# use tch::kind::Kind::Int64;
|
/// # use tch::kind::Kind::Int64;
|
||||||
/// use rust_bert::distilbert::{DistilBertConfig, DistilBertModelClassifier};
|
/// use rust_bert::distilbert::{DistilBertConfig, DistilBertModelClassifier};
|
||||||
///# let config_path = Path::new("path/to/config.json");
|
/// # let config_path = Path::new("path/to/config.json");
|
||||||
///# let vocab_path = Path::new("path/to/vocab.txt");
|
/// # let vocab_path = Path::new("path/to/vocab.txt");
|
||||||
///# let device = Device::Cpu;
|
/// # let device = Device::Cpu;
|
||||||
///# let vs = nn::VarStore::new(device);
|
/// # let vs = nn::VarStore::new(device);
|
||||||
///# let config = DistilBertConfig::from_file(config_path);
|
/// # let config = DistilBertConfig::from_file(config_path);
|
||||||
///# let distilbert_model: DistilBertModelClassifier = DistilBertModelClassifier::new(&vs.root(), &config);
|
/// # let distilbert_model: DistilBertModelClassifier = DistilBertModelClassifier::new(&vs.root(), &config);
|
||||||
/// let (batch_size, sequence_length) = (64, 128);
|
/// let (batch_size, sequence_length) = (64, 128);
|
||||||
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
|
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
|
||||||
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
||||||
@ -287,15 +339,22 @@ impl DistilBertModelClassifier {
|
|||||||
/// None,
|
/// None,
|
||||||
/// false).unwrap()
|
/// false).unwrap()
|
||||||
/// });
|
/// });
|
||||||
///
|
|
||||||
/// ```
|
/// ```
|
||||||
///
|
pub fn forward_t(
|
||||||
pub fn forward_t(&self, input: Option<Tensor>, mask: Option<Tensor>, input_embeds: Option<Tensor>, train: bool)
|
&self,
|
||||||
-> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
|
input: Option<Tensor>,
|
||||||
let (output, all_hidden_states, all_attentions) = match self.distil_bert_model.forward_t(input, mask, input_embeds, train) {
|
mask: Option<Tensor>,
|
||||||
Ok(value) => value,
|
input_embeds: Option<Tensor>,
|
||||||
Err(err) => return Err(err)
|
train: bool,
|
||||||
};
|
) -> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
|
||||||
|
let (output, all_hidden_states, all_attentions) =
|
||||||
|
match self
|
||||||
|
.distil_bert_model
|
||||||
|
.forward_t(input, mask, input_embeds, train)
|
||||||
|
{
|
||||||
|
Ok(value) => value,
|
||||||
|
Err(err) => return Err(err),
|
||||||
|
};
|
||||||
|
|
||||||
let output = output
|
let output = output
|
||||||
.select(1, 0)
|
.select(1, 0)
|
||||||
@ -322,7 +381,6 @@ pub struct DistilBertModelMaskedLM {
|
|||||||
vocab_projector: nn::Linear,
|
vocab_projector: nn::Linear,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
impl DistilBertModelMaskedLM {
|
impl DistilBertModelMaskedLM {
|
||||||
/// Build a new `DistilBertModelMaskedLM` for sequence classification
|
/// Build a new `DistilBertModelMaskedLM` for sequence classification
|
||||||
///
|
///
|
||||||
@ -334,10 +392,10 @@ impl DistilBertModelMaskedLM {
|
|||||||
/// # Example
|
/// # Example
|
||||||
///
|
///
|
||||||
/// ```no_run
|
/// ```no_run
|
||||||
/// use tch::{nn, Device};
|
/// use rust_bert::distilbert::{DistilBertConfig, DistilBertModelMaskedLM};
|
||||||
/// use rust_bert::Config;
|
/// use rust_bert::Config;
|
||||||
/// use std::path::Path;
|
/// use std::path::Path;
|
||||||
/// use rust_bert::distilbert::{DistilBertConfig, DistilBertModelMaskedLM};
|
/// use tch::{nn, Device};
|
||||||
///
|
///
|
||||||
/// let config_path = Path::new("path/to/config.json");
|
/// let config_path = Path::new("path/to/config.json");
|
||||||
/// let device = Device::Cpu;
|
/// let device = Device::Cpu;
|
||||||
@ -345,15 +403,33 @@ impl DistilBertModelMaskedLM {
|
|||||||
/// let config = DistilBertConfig::from_file(config_path);
|
/// let config = DistilBertConfig::from_file(config_path);
|
||||||
/// let distil_bert = DistilBertModelMaskedLM::new(&(&p.root() / "distilbert"), &config);
|
/// let distil_bert = DistilBertModelMaskedLM::new(&(&p.root() / "distilbert"), &config);
|
||||||
/// ```
|
/// ```
|
||||||
///
|
|
||||||
pub fn new(p: &nn::Path, config: &DistilBertConfig) -> DistilBertModelMaskedLM {
|
pub fn new(p: &nn::Path, config: &DistilBertConfig) -> DistilBertModelMaskedLM {
|
||||||
let distil_bert_model = DistilBertModel::new(&p, config);
|
let distil_bert_model = DistilBertModel::new(&p, config);
|
||||||
let vocab_transform = nn::linear(&(p / "vocab_transform"), config.dim, config.dim, Default::default());
|
let vocab_transform = nn::linear(
|
||||||
let layer_norm_config = nn::LayerNormConfig { eps: 1e-12, ..Default::default() };
|
&(p / "vocab_transform"),
|
||||||
let vocab_layer_norm = nn::layer_norm(p / "vocab_layer_norm", vec![config.dim], layer_norm_config);
|
config.dim,
|
||||||
let vocab_projector = nn::linear(&(p / "vocab_projector"), config.dim, config.vocab_size, Default::default());
|
config.dim,
|
||||||
|
Default::default(),
|
||||||
|
);
|
||||||
|
let layer_norm_config = nn::LayerNormConfig {
|
||||||
|
eps: 1e-12,
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
let vocab_layer_norm =
|
||||||
|
nn::layer_norm(p / "vocab_layer_norm", vec![config.dim], layer_norm_config);
|
||||||
|
let vocab_projector = nn::linear(
|
||||||
|
&(p / "vocab_projector"),
|
||||||
|
config.dim,
|
||||||
|
config.vocab_size,
|
||||||
|
Default::default(),
|
||||||
|
);
|
||||||
|
|
||||||
DistilBertModelMaskedLM { distil_bert_model, vocab_transform, vocab_layer_norm, vocab_projector }
|
DistilBertModelMaskedLM {
|
||||||
|
distil_bert_model,
|
||||||
|
vocab_transform,
|
||||||
|
vocab_layer_norm,
|
||||||
|
vocab_projector,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Forward pass through the model
|
/// Forward pass through the model
|
||||||
@ -374,37 +450,42 @@ impl DistilBertModelMaskedLM {
|
|||||||
/// # Example
|
/// # Example
|
||||||
///
|
///
|
||||||
/// ```no_run
|
/// ```no_run
|
||||||
///# use tch::{nn, Device, Tensor, no_grad};
|
/// # use tch::{nn, Device, Tensor, no_grad};
|
||||||
///# use rust_bert::Config;
|
/// # use rust_bert::Config;
|
||||||
///# use std::path::Path;
|
/// # use std::path::Path;
|
||||||
///# use tch::kind::Kind::Int64;
|
/// # use tch::kind::Kind::Int64;
|
||||||
/// use rust_bert::distilbert::{DistilBertConfig, DistilBertModelMaskedLM};
|
/// use rust_bert::distilbert::{DistilBertConfig, DistilBertModelMaskedLM};
|
||||||
///# let config_path = Path::new("path/to/config.json");
|
/// # let config_path = Path::new("path/to/config.json");
|
||||||
///# let vocab_path = Path::new("path/to/vocab.txt");
|
/// # let vocab_path = Path::new("path/to/vocab.txt");
|
||||||
///# let device = Device::Cpu;
|
/// # let device = Device::Cpu;
|
||||||
///# let vs = nn::VarStore::new(device);
|
/// # let vs = nn::VarStore::new(device);
|
||||||
///# let config = DistilBertConfig::from_file(config_path);
|
/// # let config = DistilBertConfig::from_file(config_path);
|
||||||
///# let distilbert_model = DistilBertModelMaskedLM::new(&vs.root(), &config);
|
/// # let distilbert_model = DistilBertModelMaskedLM::new(&vs.root(), &config);
|
||||||
/// let (batch_size, sequence_length) = (64, 128);
|
/// let (batch_size, sequence_length) = (64, 128);
|
||||||
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
|
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
|
||||||
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
||||||
///
|
|
||||||
/// let (output, all_hidden_states, all_attentions) = no_grad(|| {
|
|
||||||
/// distilbert_model
|
|
||||||
/// .forward_t(Some(input_tensor),
|
|
||||||
/// Some(mask),
|
|
||||||
/// None,
|
|
||||||
/// false).unwrap()
|
|
||||||
/// });
|
|
||||||
///
|
///
|
||||||
|
/// let (output, all_hidden_states, all_attentions) = no_grad(|| {
|
||||||
|
/// distilbert_model
|
||||||
|
/// .forward_t(Some(input_tensor), Some(mask), None, false)
|
||||||
|
/// .unwrap()
|
||||||
|
/// });
|
||||||
/// ```
|
/// ```
|
||||||
///
|
pub fn forward_t(
|
||||||
pub fn forward_t(&self, input: Option<Tensor>, mask: Option<Tensor>, input_embeds: Option<Tensor>, train: bool)
|
&self,
|
||||||
-> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
|
input: Option<Tensor>,
|
||||||
let (output, all_hidden_states, all_attentions) = match self.distil_bert_model.forward_t(input, mask, input_embeds, train) {
|
mask: Option<Tensor>,
|
||||||
Ok(value) => value,
|
input_embeds: Option<Tensor>,
|
||||||
Err(err) => return Err(err)
|
train: bool,
|
||||||
};
|
) -> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
|
||||||
|
let (output, all_hidden_states, all_attentions) =
|
||||||
|
match self
|
||||||
|
.distil_bert_model
|
||||||
|
.forward_t(input, mask, input_embeds, train)
|
||||||
|
{
|
||||||
|
Ok(value) => value,
|
||||||
|
Err(err) => return Err(err),
|
||||||
|
};
|
||||||
|
|
||||||
let output = output
|
let output = output
|
||||||
.apply(&self.vocab_transform)
|
.apply(&self.vocab_transform)
|
||||||
@ -440,10 +521,10 @@ impl DistilBertForQuestionAnswering {
|
|||||||
/// # Example
|
/// # Example
|
||||||
///
|
///
|
||||||
/// ```no_run
|
/// ```no_run
|
||||||
/// use tch::{nn, Device};
|
/// use rust_bert::distilbert::{DistilBertConfig, DistilBertForQuestionAnswering};
|
||||||
/// use rust_bert::Config;
|
/// use rust_bert::Config;
|
||||||
/// use std::path::Path;
|
/// use std::path::Path;
|
||||||
/// use rust_bert::distilbert::{DistilBertConfig, DistilBertForQuestionAnswering};
|
/// use tch::{nn, Device};
|
||||||
///
|
///
|
||||||
/// let config_path = Path::new("path/to/config.json");
|
/// let config_path = Path::new("path/to/config.json");
|
||||||
/// let device = Device::Cpu;
|
/// let device = Device::Cpu;
|
||||||
@ -451,13 +532,16 @@ impl DistilBertForQuestionAnswering {
|
|||||||
/// let config = DistilBertConfig::from_file(config_path);
|
/// let config = DistilBertConfig::from_file(config_path);
|
||||||
/// let distil_bert = DistilBertForQuestionAnswering::new(&(&p.root() / "distilbert"), &config);
|
/// let distil_bert = DistilBertForQuestionAnswering::new(&(&p.root() / "distilbert"), &config);
|
||||||
/// ```
|
/// ```
|
||||||
///
|
|
||||||
pub fn new(p: &nn::Path, config: &DistilBertConfig) -> DistilBertForQuestionAnswering {
|
pub fn new(p: &nn::Path, config: &DistilBertConfig) -> DistilBertForQuestionAnswering {
|
||||||
let distil_bert_model = DistilBertModel::new(&p, config);
|
let distil_bert_model = DistilBertModel::new(&p, config);
|
||||||
let qa_outputs = nn::linear(&(p / "qa_outputs"), config.dim, 2, Default::default());
|
let qa_outputs = nn::linear(&(p / "qa_outputs"), config.dim, 2, Default::default());
|
||||||
let dropout = Dropout::new(config.qa_dropout);
|
let dropout = Dropout::new(config.qa_dropout);
|
||||||
|
|
||||||
DistilBertForQuestionAnswering { distil_bert_model, qa_outputs, dropout }
|
DistilBertForQuestionAnswering {
|
||||||
|
distil_bert_model,
|
||||||
|
qa_outputs,
|
||||||
|
dropout,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Forward pass through the model
|
/// Forward pass through the model
|
||||||
@ -479,52 +563,50 @@ impl DistilBertForQuestionAnswering {
|
|||||||
/// # Example
|
/// # Example
|
||||||
///
|
///
|
||||||
/// ```no_run
|
/// ```no_run
|
||||||
///# use tch::{nn, Device, Tensor, no_grad};
|
/// # use tch::{nn, Device, Tensor, no_grad};
|
||||||
///# use rust_bert::Config;
|
/// # use rust_bert::Config;
|
||||||
///# use std::path::Path;
|
/// # use std::path::Path;
|
||||||
///# use tch::kind::Kind::Int64;
|
/// # use tch::kind::Kind::Int64;
|
||||||
/// use rust_bert::distilbert::{DistilBertConfig, DistilBertForQuestionAnswering};
|
/// use rust_bert::distilbert::{DistilBertConfig, DistilBertForQuestionAnswering};
|
||||||
///# let config_path = Path::new("path/to/config.json");
|
/// # let config_path = Path::new("path/to/config.json");
|
||||||
///# let vocab_path = Path::new("path/to/vocab.txt");
|
/// # let vocab_path = Path::new("path/to/vocab.txt");
|
||||||
///# let device = Device::Cpu;
|
/// # let device = Device::Cpu;
|
||||||
///# let vs = nn::VarStore::new(device);
|
/// # let vs = nn::VarStore::new(device);
|
||||||
///# let config = DistilBertConfig::from_file(config_path);
|
/// # let config = DistilBertConfig::from_file(config_path);
|
||||||
///# let distilbert_model = DistilBertForQuestionAnswering::new(&vs.root(), &config);
|
/// # let distilbert_model = DistilBertForQuestionAnswering::new(&vs.root(), &config);
|
||||||
/// let (batch_size, sequence_length) = (64, 128);
|
/// let (batch_size, sequence_length) = (64, 128);
|
||||||
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
|
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
|
||||||
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
||||||
///
|
|
||||||
/// let (start_scores, end_score, all_hidden_states, all_attentions) = no_grad(|| {
|
|
||||||
/// distilbert_model
|
|
||||||
/// .forward_t(Some(input_tensor),
|
|
||||||
/// Some(mask),
|
|
||||||
/// None,
|
|
||||||
/// false).unwrap()
|
|
||||||
/// });
|
|
||||||
///
|
///
|
||||||
|
/// let (start_scores, end_score, all_hidden_states, all_attentions) = no_grad(|| {
|
||||||
|
/// distilbert_model
|
||||||
|
/// .forward_t(Some(input_tensor), Some(mask), None, false)
|
||||||
|
/// .unwrap()
|
||||||
|
/// });
|
||||||
/// ```
|
/// ```
|
||||||
///
|
pub fn forward_t(
|
||||||
pub fn forward_t(&self,
|
&self,
|
||||||
input: Option<Tensor>,
|
input: Option<Tensor>,
|
||||||
mask: Option<Tensor>,
|
mask: Option<Tensor>,
|
||||||
input_embeds: Option<Tensor>,
|
input_embeds: Option<Tensor>,
|
||||||
train: bool)
|
train: bool,
|
||||||
-> Result<(Tensor, Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
|
) -> Result<(Tensor, Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
|
||||||
let (output, all_hidden_states, all_attentions) = match self.distil_bert_model.forward_t(input, mask, input_embeds, train) {
|
let (output, all_hidden_states, all_attentions) =
|
||||||
Ok(value) => value,
|
match self
|
||||||
Err(err) => return Err(err)
|
.distil_bert_model
|
||||||
};
|
.forward_t(input, mask, input_embeds, train)
|
||||||
|
{
|
||||||
|
Ok(value) => value,
|
||||||
|
Err(err) => return Err(err),
|
||||||
|
};
|
||||||
|
|
||||||
let output = output
|
let output = output.apply_t(&self.dropout, train).apply(&self.qa_outputs);
|
||||||
.apply_t(&self.dropout, train)
|
|
||||||
.apply(&self.qa_outputs);
|
|
||||||
|
|
||||||
let logits = output.split(1, -1);
|
let logits = output.split(1, -1);
|
||||||
let (start_logits, end_logits) = (&logits[0], &logits[1]);
|
let (start_logits, end_logits) = (&logits[0], &logits[1]);
|
||||||
let start_logits = start_logits.squeeze1(-1);
|
let start_logits = start_logits.squeeze1(-1);
|
||||||
let end_logits = end_logits.squeeze1(-1);
|
let end_logits = end_logits.squeeze1(-1);
|
||||||
|
|
||||||
|
|
||||||
Ok((start_logits, end_logits, all_hidden_states, all_attentions))
|
Ok((start_logits, end_logits, all_hidden_states, all_attentions))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -552,10 +634,10 @@ impl DistilBertForTokenClassification {
|
|||||||
/// # Example
|
/// # Example
|
||||||
///
|
///
|
||||||
/// ```no_run
|
/// ```no_run
|
||||||
/// use tch::{nn, Device};
|
/// use rust_bert::distilbert::{DistilBertConfig, DistilBertForTokenClassification};
|
||||||
/// use rust_bert::Config;
|
/// use rust_bert::Config;
|
||||||
/// use std::path::Path;
|
/// use std::path::Path;
|
||||||
/// use rust_bert::distilbert::{DistilBertConfig, DistilBertForTokenClassification};
|
/// use tch::{nn, Device};
|
||||||
///
|
///
|
||||||
/// let config_path = Path::new("path/to/config.json");
|
/// let config_path = Path::new("path/to/config.json");
|
||||||
/// let device = Device::Cpu;
|
/// let device = Device::Cpu;
|
||||||
@ -563,16 +645,28 @@ impl DistilBertForTokenClassification {
|
|||||||
/// let config = DistilBertConfig::from_file(config_path);
|
/// let config = DistilBertConfig::from_file(config_path);
|
||||||
/// let distil_bert = DistilBertForTokenClassification::new(&(&p.root() / "distilbert"), &config);
|
/// let distil_bert = DistilBertForTokenClassification::new(&(&p.root() / "distilbert"), &config);
|
||||||
/// ```
|
/// ```
|
||||||
///
|
|
||||||
pub fn new(p: &nn::Path, config: &DistilBertConfig) -> DistilBertForTokenClassification {
|
pub fn new(p: &nn::Path, config: &DistilBertConfig) -> DistilBertForTokenClassification {
|
||||||
let distil_bert_model = DistilBertModel::new(&p, config);
|
let distil_bert_model = DistilBertModel::new(&p, config);
|
||||||
|
|
||||||
let num_labels = config.id2label.as_ref().expect("id2label must be provided for classifiers").len() as i64;
|
let num_labels = config
|
||||||
|
.id2label
|
||||||
|
.as_ref()
|
||||||
|
.expect("id2label must be provided for classifiers")
|
||||||
|
.len() as i64;
|
||||||
|
|
||||||
let classifier = nn::linear(&(p / "classifier"), config.dim, num_labels, Default::default());
|
let classifier = nn::linear(
|
||||||
|
&(p / "classifier"),
|
||||||
|
config.dim,
|
||||||
|
num_labels,
|
||||||
|
Default::default(),
|
||||||
|
);
|
||||||
let dropout = Dropout::new(config.seq_classif_dropout);
|
let dropout = Dropout::new(config.seq_classif_dropout);
|
||||||
|
|
||||||
DistilBertForTokenClassification { distil_bert_model, classifier, dropout }
|
DistilBertForTokenClassification {
|
||||||
|
distil_bert_model,
|
||||||
|
classifier,
|
||||||
|
dropout,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Forward pass through the model
|
/// Forward pass through the model
|
||||||
@ -593,41 +687,44 @@ impl DistilBertForTokenClassification {
|
|||||||
/// # Example
|
/// # Example
|
||||||
///
|
///
|
||||||
/// ```no_run
|
/// ```no_run
|
||||||
///# use tch::{nn, Device, Tensor, no_grad};
|
/// # use tch::{nn, Device, Tensor, no_grad};
|
||||||
///# use rust_bert::Config;
|
/// # use rust_bert::Config;
|
||||||
///# use std::path::Path;
|
/// # use std::path::Path;
|
||||||
///# use tch::kind::Kind::Int64;
|
/// # use tch::kind::Kind::Int64;
|
||||||
/// use rust_bert::distilbert::{DistilBertConfig, DistilBertForTokenClassification};
|
/// use rust_bert::distilbert::{DistilBertConfig, DistilBertForTokenClassification};
|
||||||
///# let config_path = Path::new("path/to/config.json");
|
/// # let config_path = Path::new("path/to/config.json");
|
||||||
///# let vocab_path = Path::new("path/to/vocab.txt");
|
/// # let vocab_path = Path::new("path/to/vocab.txt");
|
||||||
///# let device = Device::Cpu;
|
/// # let device = Device::Cpu;
|
||||||
///# let vs = nn::VarStore::new(device);
|
/// # let vs = nn::VarStore::new(device);
|
||||||
///# let config = DistilBertConfig::from_file(config_path);
|
/// # let config = DistilBertConfig::from_file(config_path);
|
||||||
///# let distilbert_model = DistilBertForTokenClassification::new(&vs.root(), &config);
|
/// # let distilbert_model = DistilBertForTokenClassification::new(&vs.root(), &config);
|
||||||
/// let (batch_size, sequence_length) = (64, 128);
|
/// let (batch_size, sequence_length) = (64, 128);
|
||||||
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
|
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
|
||||||
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
||||||
///
|
|
||||||
/// let (output, all_hidden_states, all_attentions) = no_grad(|| {
|
|
||||||
/// distilbert_model
|
|
||||||
/// .forward_t(Some(input_tensor),
|
|
||||||
/// Some(mask),
|
|
||||||
/// None,
|
|
||||||
/// false).unwrap()
|
|
||||||
/// });
|
|
||||||
///
|
///
|
||||||
|
/// let (output, all_hidden_states, all_attentions) = no_grad(|| {
|
||||||
|
/// distilbert_model
|
||||||
|
/// .forward_t(Some(input_tensor), Some(mask), None, false)
|
||||||
|
/// .unwrap()
|
||||||
|
/// });
|
||||||
/// ```
|
/// ```
|
||||||
///
|
pub fn forward_t(
|
||||||
pub fn forward_t(&self, input: Option<Tensor>, mask: Option<Tensor>, input_embeds: Option<Tensor>, train: bool)
|
&self,
|
||||||
-> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
|
input: Option<Tensor>,
|
||||||
let (output, all_hidden_states, all_attentions) = match self.distil_bert_model.forward_t(input, mask, input_embeds, train) {
|
mask: Option<Tensor>,
|
||||||
Ok(value) => value,
|
input_embeds: Option<Tensor>,
|
||||||
Err(err) => return Err(err)
|
train: bool,
|
||||||
};
|
) -> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
|
||||||
|
let (output, all_hidden_states, all_attentions) =
|
||||||
|
match self
|
||||||
|
.distil_bert_model
|
||||||
|
.forward_t(input, mask, input_embeds, train)
|
||||||
|
{
|
||||||
|
Ok(value) => value,
|
||||||
|
Err(err) => return Err(err),
|
||||||
|
};
|
||||||
|
|
||||||
let output = output
|
let output = output.apply_t(&self.dropout, train).apply(&self.classifier);
|
||||||
.apply_t(&self.dropout, train)
|
|
||||||
.apply(&self.classifier);
|
|
||||||
|
|
||||||
Ok((output, all_hidden_states, all_attentions))
|
Ok((output, all_hidden_states, all_attentions))
|
||||||
}
|
}
|
||||||
|
@ -10,22 +10,26 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
use tch::{nn, Tensor, Kind, Device};
|
|
||||||
use tch::nn::{ModuleT, embedding, EmbeddingConfig};
|
|
||||||
use crate::distilbert::distilbert::DistilBertConfig;
|
|
||||||
use crate::common::dropout::Dropout;
|
use crate::common::dropout::Dropout;
|
||||||
|
use crate::distilbert::distilbert::DistilBertConfig;
|
||||||
use tch::kind::Kind::Float;
|
use tch::kind::Kind::Float;
|
||||||
|
use tch::nn::{embedding, EmbeddingConfig, ModuleT};
|
||||||
|
use tch::{nn, Device, Kind, Tensor};
|
||||||
|
|
||||||
fn create_sinusoidal_embeddings(config: &DistilBertConfig, device: Device) -> nn::Embedding {
|
fn create_sinusoidal_embeddings(config: &DistilBertConfig, device: Device) -> nn::Embedding {
|
||||||
let mut sinusoidal_embedding: Vec<Tensor> = Vec::with_capacity(config.max_position_embeddings as usize);
|
let mut sinusoidal_embedding: Vec<Tensor> =
|
||||||
|
Vec::with_capacity(config.max_position_embeddings as usize);
|
||||||
for pos in 0..config.max_position_embeddings {
|
for pos in 0..config.max_position_embeddings {
|
||||||
let mut temp_vec: Vec<f64> = Vec::with_capacity(config.dim as usize);
|
let mut temp_vec: Vec<f64> = Vec::with_capacity(config.dim as usize);
|
||||||
for j in 0..config.dim {
|
for j in 0..config.dim {
|
||||||
if j % 2 == 0 {
|
if j % 2 == 0 {
|
||||||
temp_vec.push((pos as f64 / 10000f64.powf((2 * (j / 2)) as f64 / config.dim as f64)).sin());
|
temp_vec.push(
|
||||||
|
(pos as f64 / 10000f64.powf((2 * (j / 2)) as f64 / config.dim as f64)).sin(),
|
||||||
|
);
|
||||||
} else {
|
} else {
|
||||||
temp_vec.push((pos as f64 / 10000f64.powf((2 * (j / 2)) as f64 / config.dim as f64)).cos());
|
temp_vec.push(
|
||||||
|
(pos as f64 / 10000f64.powf((2 * (j / 2)) as f64 / config.dim as f64)).cos(),
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
let temp_vec = Tensor::of_slice(&temp_vec);
|
let temp_vec = Tensor::of_slice(&temp_vec);
|
||||||
@ -35,17 +39,21 @@ fn create_sinusoidal_embeddings(config: &DistilBertConfig, device: Device) -> nn
|
|||||||
.to_kind(Float)
|
.to_kind(Float)
|
||||||
.to_device(device);
|
.to_device(device);
|
||||||
|
|
||||||
let embedding_config = EmbeddingConfig { padding_idx: 0, ..Default::default() };
|
let embedding_config = EmbeddingConfig {
|
||||||
let mut embeddings = embedding(&nn::VarStore::new(device).root(),
|
padding_idx: 0,
|
||||||
config.max_position_embeddings,
|
..Default::default()
|
||||||
config.dim,
|
};
|
||||||
embedding_config);
|
let mut embeddings = embedding(
|
||||||
|
&nn::VarStore::new(device).root(),
|
||||||
|
config.max_position_embeddings,
|
||||||
|
config.dim,
|
||||||
|
embedding_config,
|
||||||
|
);
|
||||||
|
|
||||||
embeddings.ws = sinusoidal_embedding;
|
embeddings.ws = sinusoidal_embedding;
|
||||||
embeddings
|
embeddings
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct DistilBertEmbedding {
|
pub struct DistilBertEmbedding {
|
||||||
word_embeddings: nn::Embedding,
|
word_embeddings: nn::Embedding,
|
||||||
@ -56,24 +64,40 @@ pub struct DistilBertEmbedding {
|
|||||||
|
|
||||||
impl DistilBertEmbedding {
|
impl DistilBertEmbedding {
|
||||||
pub fn new(p: &nn::Path, config: &DistilBertConfig) -> DistilBertEmbedding {
|
pub fn new(p: &nn::Path, config: &DistilBertConfig) -> DistilBertEmbedding {
|
||||||
let embedding_config = EmbeddingConfig { padding_idx: 0, ..Default::default() };
|
let embedding_config = EmbeddingConfig {
|
||||||
|
padding_idx: 0,
|
||||||
let word_embeddings: nn::Embedding = embedding(p / "word_embeddings",
|
..Default::default()
|
||||||
config.vocab_size,
|
|
||||||
config.dim,
|
|
||||||
embedding_config);
|
|
||||||
let position_embeddings: nn::Embedding = match config.sinusoidal_pos_embds {
|
|
||||||
false => embedding(p / "position_embeddings",
|
|
||||||
config.max_position_embeddings,
|
|
||||||
config.dim,
|
|
||||||
embedding_config),
|
|
||||||
|
|
||||||
true => create_sinusoidal_embeddings(&config, p.device())
|
|
||||||
};
|
};
|
||||||
let layer_norm_config = nn::LayerNormConfig { eps: 1e-12, ..Default::default() };
|
|
||||||
let layer_norm: nn::LayerNorm = nn::layer_norm(p / "LayerNorm", vec![config.dim], layer_norm_config);
|
let word_embeddings: nn::Embedding = embedding(
|
||||||
|
p / "word_embeddings",
|
||||||
|
config.vocab_size,
|
||||||
|
config.dim,
|
||||||
|
embedding_config,
|
||||||
|
);
|
||||||
|
let position_embeddings: nn::Embedding = match config.sinusoidal_pos_embds {
|
||||||
|
false => embedding(
|
||||||
|
p / "position_embeddings",
|
||||||
|
config.max_position_embeddings,
|
||||||
|
config.dim,
|
||||||
|
embedding_config,
|
||||||
|
),
|
||||||
|
|
||||||
|
true => create_sinusoidal_embeddings(&config, p.device()),
|
||||||
|
};
|
||||||
|
let layer_norm_config = nn::LayerNormConfig {
|
||||||
|
eps: 1e-12,
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
let layer_norm: nn::LayerNorm =
|
||||||
|
nn::layer_norm(p / "LayerNorm", vec![config.dim], layer_norm_config);
|
||||||
let dropout: Dropout = Dropout::new(config.dropout);
|
let dropout: Dropout = Dropout::new(config.dropout);
|
||||||
DistilBertEmbedding { word_embeddings, position_embeddings, layer_norm, dropout }
|
DistilBertEmbedding {
|
||||||
|
word_embeddings,
|
||||||
|
position_embeddings,
|
||||||
|
layer_norm,
|
||||||
|
dropout,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn _get_word_embeddings(&self) -> &nn::Embedding {
|
pub fn _get_word_embeddings(&self) -> &nn::Embedding {
|
||||||
@ -94,10 +118,12 @@ impl ModuleT for DistilBertEmbedding {
|
|||||||
let word_embed = input.apply(&self.word_embeddings);
|
let word_embed = input.apply(&self.word_embeddings);
|
||||||
let position_embed = position_ids.apply(&self.position_embeddings);
|
let position_embed = position_ids.apply(&self.position_embeddings);
|
||||||
|
|
||||||
// position_embed.get(0).get(0).print();
|
// position_embed.get(0).get(0).print();
|
||||||
let embeddings = word_embed + position_embed;
|
let embeddings = word_embed + position_embed;
|
||||||
let embeddings = embeddings.apply(&self.layer_norm).apply_t(&self.dropout, train);
|
let embeddings = embeddings
|
||||||
|
.apply(&self.layer_norm)
|
||||||
|
.apply_t(&self.dropout, train);
|
||||||
|
|
||||||
embeddings
|
embeddings
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -18,18 +18,27 @@
|
|||||||
//! Pretrained models are available and can be downloaded using RemoteResources.
|
//! Pretrained models are available and can be downloaded using RemoteResources.
|
||||||
//!
|
//!
|
||||||
//! ```no_run
|
//! ```no_run
|
||||||
//!# fn main() -> failure::Fallible<()> {
|
//! # fn main() -> failure::Fallible<()> {
|
||||||
//!#
|
//! #
|
||||||
//! use rust_tokenizers::BertTokenizer;
|
//! use rust_tokenizers::BertTokenizer;
|
||||||
//! use tch::{nn, Device};
|
//! use tch::{nn, Device};
|
||||||
//!# use std::path::PathBuf;
|
//! # use std::path::PathBuf;
|
||||||
|
//! use rust_bert::distilbert::{
|
||||||
|
//! DistilBertConfig, DistilBertConfigResources, DistilBertModelMaskedLM,
|
||||||
|
//! DistilBertModelResources, DistilBertVocabResources,
|
||||||
|
//! };
|
||||||
|
//! use rust_bert::resources::{download_resource, LocalResource, RemoteResource, Resource};
|
||||||
//! use rust_bert::Config;
|
//! use rust_bert::Config;
|
||||||
//! use rust_bert::distilbert::{DistilBertModelMaskedLM, DistilBertConfig, DistilBertConfigResources, DistilBertVocabResources, DistilBertModelResources};
|
|
||||||
//! use rust_bert::resources::{Resource, download_resource, RemoteResource, LocalResource};
|
|
||||||
//!
|
//!
|
||||||
//! let config_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/config.json")});
|
//! let config_resource = Resource::Local(LocalResource {
|
||||||
//! let vocab_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/vocab.txt")});
|
//! local_path: PathBuf::from("path/to/config.json"),
|
||||||
//! let weights_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/model.ot")});
|
//! });
|
||||||
|
//! let vocab_resource = Resource::Local(LocalResource {
|
||||||
|
//! local_path: PathBuf::from("path/to/vocab.txt"),
|
||||||
|
//! });
|
||||||
|
//! let weights_resource = Resource::Local(LocalResource {
|
||||||
|
//! local_path: PathBuf::from("path/to/model.ot"),
|
||||||
|
//! });
|
||||||
//! let config_path = download_resource(&config_resource)?;
|
//! let config_path = download_resource(&config_resource)?;
|
||||||
//! let vocab_path = download_resource(&vocab_resource)?;
|
//! let vocab_path = download_resource(&vocab_resource)?;
|
||||||
//! let weights_path = download_resource(&weights_resource)?;
|
//! let weights_path = download_resource(&weights_resource)?;
|
||||||
@ -40,17 +49,17 @@
|
|||||||
//! let bert_model = DistilBertModelMaskedLM::new(&vs.root(), &config);
|
//! let bert_model = DistilBertModelMaskedLM::new(&vs.root(), &config);
|
||||||
//! vs.load(weights_path)?;
|
//! vs.load(weights_path)?;
|
||||||
//!
|
//!
|
||||||
//!# Ok(())
|
//! # Ok(())
|
||||||
//!# }
|
//! # }
|
||||||
//! ```
|
//! ```
|
||||||
|
|
||||||
|
mod attention;
|
||||||
|
|
||||||
mod distilbert;
|
mod distilbert;
|
||||||
mod embeddings;
|
mod embeddings;
|
||||||
mod attention;
|
|
||||||
mod transformer;
|
mod transformer;
|
||||||
|
|
||||||
pub use distilbert::{DistilBertModelResources, DistilBertConfigResources, DistilBertVocabResources,
|
pub use distilbert::{
|
||||||
DistilBertConfig, Activation, DistilBertModel, DistilBertForQuestionAnswering, DistilBertForTokenClassification,
|
Activation, DistilBertConfig, DistilBertConfigResources, DistilBertForQuestionAnswering,
|
||||||
DistilBertModelMaskedLM, DistilBertModelClassifier};
|
DistilBertForTokenClassification, DistilBertModel, DistilBertModelClassifier,
|
||||||
|
DistilBertModelMaskedLM, DistilBertModelResources, DistilBertVocabResources,
|
||||||
|
};
|
||||||
|
@ -10,13 +10,13 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
use tch::{Tensor, nn};
|
|
||||||
use crate::distilbert::distilbert::{DistilBertConfig, Activation};
|
|
||||||
use crate::distilbert::attention::MultiHeadSelfAttention;
|
|
||||||
use tch::nn::LayerNorm;
|
|
||||||
use std::borrow::BorrowMut;
|
|
||||||
use crate::common::dropout::Dropout;
|
|
||||||
use crate::common::activations::{_gelu, _relu};
|
use crate::common::activations::{_gelu, _relu};
|
||||||
|
use crate::common::dropout::Dropout;
|
||||||
|
use crate::distilbert::attention::MultiHeadSelfAttention;
|
||||||
|
use crate::distilbert::distilbert::{Activation, DistilBertConfig};
|
||||||
|
use std::borrow::BorrowMut;
|
||||||
|
use tch::nn::LayerNorm;
|
||||||
|
use tch::{nn, Tensor};
|
||||||
|
|
||||||
pub struct FeedForwardNetwork {
|
pub struct FeedForwardNetwork {
|
||||||
lin1: nn::Linear,
|
lin1: nn::Linear,
|
||||||
@ -27,18 +27,35 @@ pub struct FeedForwardNetwork {
|
|||||||
|
|
||||||
impl FeedForwardNetwork {
|
impl FeedForwardNetwork {
|
||||||
pub fn new(p: nn::Path, config: &DistilBertConfig) -> FeedForwardNetwork {
|
pub fn new(p: nn::Path, config: &DistilBertConfig) -> FeedForwardNetwork {
|
||||||
let lin1 = nn::linear(&p / "lin1", config.dim, config.hidden_dim, Default::default());
|
let lin1 = nn::linear(
|
||||||
let lin2 = nn::linear(&p / "lin2", config.hidden_dim, config.dim, Default::default());
|
&p / "lin1",
|
||||||
|
config.dim,
|
||||||
|
config.hidden_dim,
|
||||||
|
Default::default(),
|
||||||
|
);
|
||||||
|
let lin2 = nn::linear(
|
||||||
|
&p / "lin2",
|
||||||
|
config.hidden_dim,
|
||||||
|
config.dim,
|
||||||
|
Default::default(),
|
||||||
|
);
|
||||||
let dropout = Dropout::new(config.dropout);
|
let dropout = Dropout::new(config.dropout);
|
||||||
let activation = Box::new(match &config.activation {
|
let activation = Box::new(match &config.activation {
|
||||||
Activation::gelu => _gelu,
|
Activation::gelu => _gelu,
|
||||||
Activation::relu => _relu
|
Activation::relu => _relu,
|
||||||
});
|
});
|
||||||
FeedForwardNetwork { lin1, lin2, dropout, activation }
|
FeedForwardNetwork {
|
||||||
|
lin1,
|
||||||
|
lin2,
|
||||||
|
dropout,
|
||||||
|
activation,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn forward_t(&self, input: &Tensor, train: bool) -> Tensor {
|
pub fn forward_t(&self, input: &Tensor, train: bool) -> Tensor {
|
||||||
(self.activation)(&input.apply(&self.lin1)).apply(&self.lin2).apply_t(&self.dropout, train)
|
(self.activation)(&input.apply(&self.lin1))
|
||||||
|
.apply(&self.lin2)
|
||||||
|
.apply_t(&self.dropout, train)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -52,10 +69,15 @@ pub struct TransformerBlock {
|
|||||||
impl TransformerBlock {
|
impl TransformerBlock {
|
||||||
pub fn new(p: &nn::Path, config: &DistilBertConfig) -> TransformerBlock {
|
pub fn new(p: &nn::Path, config: &DistilBertConfig) -> TransformerBlock {
|
||||||
let attention = MultiHeadSelfAttention::new(p / "attention", &config);
|
let attention = MultiHeadSelfAttention::new(p / "attention", &config);
|
||||||
let layer_norm_config = nn::LayerNormConfig { eps: 1e-12, ..Default::default() };
|
let layer_norm_config = nn::LayerNormConfig {
|
||||||
let sa_layer_norm = nn::layer_norm(p / "sa_layer_norm", vec![config.dim], layer_norm_config);
|
eps: 1e-12,
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
let sa_layer_norm =
|
||||||
|
nn::layer_norm(p / "sa_layer_norm", vec![config.dim], layer_norm_config);
|
||||||
let ffn = FeedForwardNetwork::new(p / "ffn", &config);
|
let ffn = FeedForwardNetwork::new(p / "ffn", &config);
|
||||||
let output_layer_norm = nn::layer_norm(p / "output_layer_norm", vec![config.dim], layer_norm_config);
|
let output_layer_norm =
|
||||||
|
nn::layer_norm(p / "output_layer_norm", vec![config.dim], layer_norm_config);
|
||||||
|
|
||||||
TransformerBlock {
|
TransformerBlock {
|
||||||
attention,
|
attention,
|
||||||
@ -65,8 +87,15 @@ impl TransformerBlock {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn forward_t(&self, input: &Tensor, mask: &Option<Tensor>, train: bool) -> (Tensor, Option<Tensor>) {
|
pub fn forward_t(
|
||||||
let (output, sa_weights) = self.attention.forward_t(&input, &input, &input, mask, train);
|
&self,
|
||||||
|
input: &Tensor,
|
||||||
|
mask: &Option<Tensor>,
|
||||||
|
train: bool,
|
||||||
|
) -> (Tensor, Option<Tensor>) {
|
||||||
|
let (output, sa_weights) = self
|
||||||
|
.attention
|
||||||
|
.forward_t(&input, &input, &input, mask, train);
|
||||||
let output = (input + &output).apply(&self.sa_layer_norm);
|
let output = (input + &output).apply(&self.sa_layer_norm);
|
||||||
let output = (&output + self.ffn.forward_t(&output, train)).apply(&self.output_layer_norm);
|
let output = (&output + self.ffn.forward_t(&output, train)).apply(&self.output_layer_norm);
|
||||||
(output, sa_weights)
|
(output, sa_weights)
|
||||||
@ -84,25 +113,41 @@ impl Transformer {
|
|||||||
let p = &(p / "layer");
|
let p = &(p / "layer");
|
||||||
let output_attentions = match config.output_attentions {
|
let output_attentions = match config.output_attentions {
|
||||||
Some(value) => value,
|
Some(value) => value,
|
||||||
None => false
|
None => false,
|
||||||
};
|
};
|
||||||
let output_hidden_states = match config.output_hidden_states {
|
let output_hidden_states = match config.output_hidden_states {
|
||||||
Some(value) => value,
|
Some(value) => value,
|
||||||
None => false
|
None => false,
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut layers: Vec<TransformerBlock> = vec!();
|
let mut layers: Vec<TransformerBlock> = vec![];
|
||||||
for layer_index in 0..config.n_layers {
|
for layer_index in 0..config.n_layers {
|
||||||
layers.push(TransformerBlock::new(&(p / layer_index), config));
|
layers.push(TransformerBlock::new(&(p / layer_index), config));
|
||||||
};
|
}
|
||||||
|
|
||||||
Transformer { output_attentions, output_hidden_states, layers }
|
Transformer {
|
||||||
|
output_attentions,
|
||||||
|
output_hidden_states,
|
||||||
|
layers,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn forward_t(&self, input: &Tensor, mask: Option<Tensor>, train: bool)
|
pub fn forward_t(
|
||||||
-> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
|
&self,
|
||||||
let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states { Some(vec!()) } else { None };
|
input: &Tensor,
|
||||||
let mut all_attentions: Option<Vec<Tensor>> = if self.output_attentions { Some(vec!()) } else { None };
|
mask: Option<Tensor>,
|
||||||
|
train: bool,
|
||||||
|
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
|
||||||
|
let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states {
|
||||||
|
Some(vec![])
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
let mut all_attentions: Option<Vec<Tensor>> = if self.output_attentions {
|
||||||
|
Some(vec![])
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
let mut hidden_state = input.copy();
|
let mut hidden_state = input.copy();
|
||||||
let mut attention_weights: Option<Tensor>;
|
let mut attention_weights: Option<Tensor>;
|
||||||
@ -121,10 +166,10 @@ impl Transformer {
|
|||||||
attentions.push(attention_weights.as_ref().unwrap().copy());
|
attentions.push(attention_weights.as_ref().unwrap().copy());
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
None => break
|
None => break,
|
||||||
};
|
};
|
||||||
};
|
}
|
||||||
|
|
||||||
(hidden_state, all_hidden_states, all_attentions)
|
(hidden_state, all_hidden_states, all_attentions)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -12,15 +12,15 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
|
use crate::bert::encoder::BertEncoder;
|
||||||
|
use crate::bert::{Activation, BertConfig};
|
||||||
|
use crate::common::activations::{_gelu, _mish, _relu};
|
||||||
|
use crate::common::dropout::Dropout;
|
||||||
|
use crate::electra::embeddings::ElectraEmbeddings;
|
||||||
|
use crate::Config;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use crate::bert::{Activation, BertConfig};
|
use tch::{nn, Kind, Tensor};
|
||||||
use crate::Config;
|
|
||||||
use crate::electra::embeddings::ElectraEmbeddings;
|
|
||||||
use tch::{nn, Tensor, Kind};
|
|
||||||
use crate::bert::encoder::BertEncoder;
|
|
||||||
use crate::common::activations::{_gelu, _relu, _mish};
|
|
||||||
use crate::common::dropout::Dropout;
|
|
||||||
|
|
||||||
/// # Electra Pretrained model weight files
|
/// # Electra Pretrained model weight files
|
||||||
pub struct ElectraModelResources;
|
pub struct ElectraModelResources;
|
||||||
@ -33,23 +33,41 @@ pub struct ElectraVocabResources;
|
|||||||
|
|
||||||
impl ElectraModelResources {
|
impl ElectraModelResources {
|
||||||
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/electra. Modified with conversion to C-array format.
|
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/electra. Modified with conversion to C-array format.
|
||||||
pub const BASE_GENERATOR: (&'static str, &'static str) = ("electra-base-generator/model.ot", "https://cdn.huggingface.co/google/electra-base-generator/rust_model.ot");
|
pub const BASE_GENERATOR: (&'static str, &'static str) = (
|
||||||
|
"electra-base-generator/model.ot",
|
||||||
|
"https://cdn.huggingface.co/google/electra-base-generator/rust_model.ot",
|
||||||
|
);
|
||||||
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/electra. Modified with conversion to C-array format.
|
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/electra. Modified with conversion to C-array format.
|
||||||
pub const BASE_DISCRIMINATOR: (&'static str, &'static str) = ("electra-base-discriminator/model.ot", "https://cdn.huggingface.co/google/electra-base-discriminator/rust_model.ot");
|
pub const BASE_DISCRIMINATOR: (&'static str, &'static str) = (
|
||||||
|
"electra-base-discriminator/model.ot",
|
||||||
|
"https://cdn.huggingface.co/google/electra-base-discriminator/rust_model.ot",
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ElectraConfigResources {
|
impl ElectraConfigResources {
|
||||||
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/electra. Modified with conversion to C-array format.
|
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/electra. Modified with conversion to C-array format.
|
||||||
pub const BASE_GENERATOR: (&'static str, &'static str) = ("electra-base-generator/config.json", "https://cdn.huggingface.co/google/electra-base-generator/config.json");
|
pub const BASE_GENERATOR: (&'static str, &'static str) = (
|
||||||
|
"electra-base-generator/config.json",
|
||||||
|
"https://cdn.huggingface.co/google/electra-base-generator/config.json",
|
||||||
|
);
|
||||||
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/electra. Modified with conversion to C-array format.
|
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/electra. Modified with conversion to C-array format.
|
||||||
pub const BASE_DISCRIMINATOR: (&'static str, &'static str) = ("electra-base-discriminator/config.json", "https://cdn.huggingface.co/google/electra-base-discriminator/config.json");
|
pub const BASE_DISCRIMINATOR: (&'static str, &'static str) = (
|
||||||
|
"electra-base-discriminator/config.json",
|
||||||
|
"https://cdn.huggingface.co/google/electra-base-discriminator/config.json",
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ElectraVocabResources {
|
impl ElectraVocabResources {
|
||||||
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/electra. Modified with conversion to C-array format.
|
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/electra. Modified with conversion to C-array format.
|
||||||
pub const BASE_GENERATOR: (&'static str, &'static str) = ("electra-base-generator/vocab.txt", "https://cdn.huggingface.co/google/electra-base-generator/vocab.txt");
|
pub const BASE_GENERATOR: (&'static str, &'static str) = (
|
||||||
|
"electra-base-generator/vocab.txt",
|
||||||
|
"https://cdn.huggingface.co/google/electra-base-generator/vocab.txt",
|
||||||
|
);
|
||||||
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/electra. Modified with conversion to C-array format.
|
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/electra. Modified with conversion to C-array format.
|
||||||
pub const BASE_DISCRIMINATOR: (&'static str, &'static str) = ("electra-base-discriminator/vocab.txt", "https://cdn.huggingface.co/google/electra-base-discriminator/vocab.txt");
|
pub const BASE_DISCRIMINATOR: (&'static str, &'static str) = (
|
||||||
|
"electra-base-discriminator/vocab.txt",
|
||||||
|
"https://cdn.huggingface.co/google/electra-base-discriminator/vocab.txt",
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
@ -103,10 +121,10 @@ impl ElectraModel {
|
|||||||
/// # Example
|
/// # Example
|
||||||
///
|
///
|
||||||
/// ```no_run
|
/// ```no_run
|
||||||
/// use rust_bert::electra::{ElectraModel, ElectraConfig};
|
/// use rust_bert::electra::{ElectraConfig, ElectraModel};
|
||||||
/// use tch::{nn, Device};
|
|
||||||
/// use rust_bert::Config;
|
/// use rust_bert::Config;
|
||||||
/// use std::path::Path;
|
/// use std::path::Path;
|
||||||
|
/// use tch::{nn, Device};
|
||||||
///
|
///
|
||||||
/// let config_path = Path::new("path/to/config.json");
|
/// let config_path = Path::new("path/to/config.json");
|
||||||
/// let device = Device::Cpu;
|
/// let device = Device::Cpu;
|
||||||
@ -114,11 +132,15 @@ impl ElectraModel {
|
|||||||
/// let config = ElectraConfig::from_file(config_path);
|
/// let config = ElectraConfig::from_file(config_path);
|
||||||
/// let electra_model: ElectraModel = ElectraModel::new(&(&p.root() / "electra"), &config);
|
/// let electra_model: ElectraModel = ElectraModel::new(&(&p.root() / "electra"), &config);
|
||||||
/// ```
|
/// ```
|
||||||
///
|
|
||||||
pub fn new(p: &nn::Path, config: &ElectraConfig) -> ElectraModel {
|
pub fn new(p: &nn::Path, config: &ElectraConfig) -> ElectraModel {
|
||||||
let embeddings = ElectraEmbeddings::new(&(p / "embeddings"), config);
|
let embeddings = ElectraEmbeddings::new(&(p / "embeddings"), config);
|
||||||
let embeddings_project = if config.embedding_size != config.hidden_size {
|
let embeddings_project = if config.embedding_size != config.hidden_size {
|
||||||
Some(nn::linear(&(p / "embeddings_project"), config.embedding_size, config.hidden_size, Default::default()))
|
Some(nn::linear(
|
||||||
|
&(p / "embeddings_project"),
|
||||||
|
config.embedding_size,
|
||||||
|
config.hidden_size,
|
||||||
|
Default::default(),
|
||||||
|
))
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
};
|
};
|
||||||
@ -141,7 +163,11 @@ impl ElectraModel {
|
|||||||
label2id: config.label2id.clone(),
|
label2id: config.label2id.clone(),
|
||||||
};
|
};
|
||||||
let encoder = BertEncoder::new(&(p / "encoder"), &bert_config);
|
let encoder = BertEncoder::new(&(p / "encoder"), &bert_config);
|
||||||
ElectraModel { embeddings, embeddings_project, encoder }
|
ElectraModel {
|
||||||
|
embeddings,
|
||||||
|
embeddings_project,
|
||||||
|
encoder,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Forward pass through the model
|
/// Forward pass through the model
|
||||||
@ -164,80 +190,98 @@ impl ElectraModel {
|
|||||||
/// # Example
|
/// # Example
|
||||||
///
|
///
|
||||||
/// ```no_run
|
/// ```no_run
|
||||||
///# use rust_bert::electra::{ElectraModel, ElectraConfig};
|
/// # use rust_bert::electra::{ElectraModel, ElectraConfig};
|
||||||
///# use tch::{nn, Device, Tensor, no_grad};
|
/// # use tch::{nn, Device, Tensor, no_grad};
|
||||||
///# use rust_bert::Config;
|
/// # use rust_bert::Config;
|
||||||
///# use std::path::Path;
|
/// # use std::path::Path;
|
||||||
///# use tch::kind::Kind::Int64;
|
/// # use tch::kind::Kind::Int64;
|
||||||
///# let config_path = Path::new("path/to/config.json");
|
/// # let config_path = Path::new("path/to/config.json");
|
||||||
///# let device = Device::Cpu;
|
/// # let device = Device::Cpu;
|
||||||
///# let vs = nn::VarStore::new(device);
|
/// # let vs = nn::VarStore::new(device);
|
||||||
///# let config = ElectraConfig::from_file(config_path);
|
/// # let config = ElectraConfig::from_file(config_path);
|
||||||
///# let electra_model: ElectraModel = ElectraModel::new(&vs.root(), &config);
|
/// # let electra_model: ElectraModel = ElectraModel::new(&vs.root(), &config);
|
||||||
/// let (batch_size, sequence_length) = (64, 128);
|
/// let (batch_size, sequence_length) = (64, 128);
|
||||||
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
|
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
|
||||||
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
||||||
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
||||||
/// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true);
|
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
|
||||||
///
|
/// .expand(&[batch_size, sequence_length], true);
|
||||||
/// let (output, all_hidden_states, all_attentions) = no_grad(|| {
|
|
||||||
/// electra_model
|
|
||||||
/// .forward_t(Some(input_tensor),
|
|
||||||
/// Some(mask),
|
|
||||||
/// Some(token_type_ids),
|
|
||||||
/// Some(position_ids),
|
|
||||||
/// None,
|
|
||||||
/// false).unwrap()
|
|
||||||
/// });
|
|
||||||
///
|
///
|
||||||
|
/// let (output, all_hidden_states, all_attentions) = no_grad(|| {
|
||||||
|
/// electra_model
|
||||||
|
/// .forward_t(
|
||||||
|
/// Some(input_tensor),
|
||||||
|
/// Some(mask),
|
||||||
|
/// Some(token_type_ids),
|
||||||
|
/// Some(position_ids),
|
||||||
|
/// None,
|
||||||
|
/// false,
|
||||||
|
/// )
|
||||||
|
/// .unwrap()
|
||||||
|
/// });
|
||||||
/// ```
|
/// ```
|
||||||
///
|
pub fn forward_t(
|
||||||
pub fn forward_t(&self,
|
&self,
|
||||||
input_ids: Option<Tensor>,
|
input_ids: Option<Tensor>,
|
||||||
mask: Option<Tensor>,
|
mask: Option<Tensor>,
|
||||||
token_type_ids: Option<Tensor>,
|
token_type_ids: Option<Tensor>,
|
||||||
position_ids: Option<Tensor>,
|
position_ids: Option<Tensor>,
|
||||||
input_embeds: Option<Tensor>,
|
input_embeds: Option<Tensor>,
|
||||||
train: bool)
|
train: bool,
|
||||||
-> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
|
) -> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
|
||||||
let (input_shape, device) = match &input_ids {
|
let (input_shape, device) = match &input_ids {
|
||||||
Some(input_value) => match &input_embeds {
|
Some(input_value) => match &input_embeds {
|
||||||
Some(_) => { return Err("Only one of input ids or input embeddings may be set"); }
|
Some(_) => {
|
||||||
None => (input_value.size(), input_value.device())
|
return Err("Only one of input ids or input embeddings may be set");
|
||||||
}
|
}
|
||||||
|
None => (input_value.size(), input_value.device()),
|
||||||
|
},
|
||||||
None => match &input_embeds {
|
None => match &input_embeds {
|
||||||
Some(embeds) => (vec!(embeds.size()[0], embeds.size()[1]), embeds.device()),
|
Some(embeds) => (vec![embeds.size()[0], embeds.size()[1]], embeds.device()),
|
||||||
None => { return Err("At least one of input ids or input embeddings must be set"); }
|
None => {
|
||||||
}
|
return Err("At least one of input ids or input embeddings must be set");
|
||||||
|
}
|
||||||
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
let mask = match mask {
|
let mask = match mask {
|
||||||
Some(value) => value,
|
Some(value) => value,
|
||||||
None => Tensor::ones(&input_shape, (Kind::Int64, device))
|
None => Tensor::ones(&input_shape, (Kind::Int64, device)),
|
||||||
};
|
};
|
||||||
|
|
||||||
let extended_attention_mask = match mask.dim() {
|
let extended_attention_mask = match mask.dim() {
|
||||||
3 => mask.unsqueeze(1),
|
3 => mask.unsqueeze(1),
|
||||||
2 => mask.unsqueeze(1).unsqueeze(1),
|
2 => mask.unsqueeze(1).unsqueeze(1),
|
||||||
_ => { return Err("Invalid attention mask dimension, must be 2 or 3"); }
|
_ => {
|
||||||
|
return Err("Invalid attention mask dimension, must be 2 or 3");
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let hidden_states = match self.embeddings.forward_t(input_ids, token_type_ids, position_ids, input_embeds, train) {
|
let hidden_states = match self.embeddings.forward_t(
|
||||||
|
input_ids,
|
||||||
|
token_type_ids,
|
||||||
|
position_ids,
|
||||||
|
input_embeds,
|
||||||
|
train,
|
||||||
|
) {
|
||||||
Ok(value) => value,
|
Ok(value) => value,
|
||||||
Err(e) => { return Err(e); }
|
Err(e) => {
|
||||||
|
return Err(e);
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let hidden_states = match &self.embeddings_project {
|
let hidden_states = match &self.embeddings_project {
|
||||||
Some(layer) => hidden_states.apply(layer),
|
Some(layer) => hidden_states.apply(layer),
|
||||||
None => hidden_states
|
None => hidden_states,
|
||||||
};
|
};
|
||||||
|
|
||||||
let (hidden_state, all_hidden_states, all_attentions) =
|
let (hidden_state, all_hidden_states, all_attentions) = self.encoder.forward_t(
|
||||||
self.encoder.forward_t(&hidden_states,
|
&hidden_states,
|
||||||
&Some(extended_attention_mask),
|
&Some(extended_attention_mask),
|
||||||
&None,
|
&None,
|
||||||
&None,
|
&None,
|
||||||
train);
|
train,
|
||||||
|
);
|
||||||
|
|
||||||
Ok((hidden_state, all_hidden_states, all_attentions))
|
Ok((hidden_state, all_hidden_states, all_attentions))
|
||||||
}
|
}
|
||||||
@ -268,9 +312,9 @@ impl ElectraDiscriminatorHead {
|
|||||||
///
|
///
|
||||||
/// ```no_run
|
/// ```no_run
|
||||||
/// use rust_bert::electra::{ElectraConfig, ElectraDiscriminatorHead};
|
/// use rust_bert::electra::{ElectraConfig, ElectraDiscriminatorHead};
|
||||||
/// use tch::{nn, Device};
|
|
||||||
/// use rust_bert::Config;
|
/// use rust_bert::Config;
|
||||||
/// use std::path::Path;
|
/// use std::path::Path;
|
||||||
|
/// use tch::{nn, Device};
|
||||||
///
|
///
|
||||||
/// let config_path = Path::new("path/to/config.json");
|
/// let config_path = Path::new("path/to/config.json");
|
||||||
/// let device = Device::Cpu;
|
/// let device = Device::Cpu;
|
||||||
@ -278,16 +322,29 @@ impl ElectraDiscriminatorHead {
|
|||||||
/// let config = ElectraConfig::from_file(config_path);
|
/// let config = ElectraConfig::from_file(config_path);
|
||||||
/// let discriminator_head = ElectraDiscriminatorHead::new(&(&p.root() / "electra"), &config);
|
/// let discriminator_head = ElectraDiscriminatorHead::new(&(&p.root() / "electra"), &config);
|
||||||
/// ```
|
/// ```
|
||||||
///
|
|
||||||
pub fn new(p: &nn::Path, config: &ElectraConfig) -> ElectraDiscriminatorHead {
|
pub fn new(p: &nn::Path, config: &ElectraConfig) -> ElectraDiscriminatorHead {
|
||||||
let dense = nn::linear(&(p / "dense"), config.hidden_size, config.hidden_size, Default::default());
|
let dense = nn::linear(
|
||||||
let dense_prediction = nn::linear(&(p / "dense_prediction"), config.hidden_size, 1, Default::default());
|
&(p / "dense"),
|
||||||
|
config.hidden_size,
|
||||||
|
config.hidden_size,
|
||||||
|
Default::default(),
|
||||||
|
);
|
||||||
|
let dense_prediction = nn::linear(
|
||||||
|
&(p / "dense_prediction"),
|
||||||
|
config.hidden_size,
|
||||||
|
1,
|
||||||
|
Default::default(),
|
||||||
|
);
|
||||||
let activation = Box::new(match &config.hidden_act {
|
let activation = Box::new(match &config.hidden_act {
|
||||||
Activation::gelu => _gelu,
|
Activation::gelu => _gelu,
|
||||||
Activation::relu => _relu,
|
Activation::relu => _relu,
|
||||||
Activation::mish => _mish
|
Activation::mish => _mish,
|
||||||
});
|
});
|
||||||
ElectraDiscriminatorHead { dense, dense_prediction, activation }
|
ElectraDiscriminatorHead {
|
||||||
|
dense,
|
||||||
|
dense_prediction,
|
||||||
|
activation,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Forward pass through the discriminator head
|
/// Forward pass through the discriminator head
|
||||||
@ -305,25 +362,24 @@ impl ElectraDiscriminatorHead {
|
|||||||
/// # Example
|
/// # Example
|
||||||
///
|
///
|
||||||
/// ```no_run
|
/// ```no_run
|
||||||
///# use rust_bert::electra::{ElectraConfig, ElectraDiscriminatorHead};
|
/// # use rust_bert::electra::{ElectraConfig, ElectraDiscriminatorHead};
|
||||||
///# use tch::{nn, Device, Tensor, no_grad};
|
/// # use tch::{nn, Device, Tensor, no_grad};
|
||||||
///# use rust_bert::Config;
|
/// # use rust_bert::Config;
|
||||||
///# use std::path::Path;
|
/// # use std::path::Path;
|
||||||
///# use tch::kind::Kind::Float;
|
/// # use tch::kind::Kind::Float;
|
||||||
///# let config_path = Path::new("path/to/config.json");
|
/// # let config_path = Path::new("path/to/config.json");
|
||||||
///# let device = Device::Cpu;
|
/// # let device = Device::Cpu;
|
||||||
///# let vs = nn::VarStore::new(device);
|
/// # let vs = nn::VarStore::new(device);
|
||||||
///# let config = ElectraConfig::from_file(config_path);
|
/// # let config = ElectraConfig::from_file(config_path);
|
||||||
///# let discriminator_head = ElectraDiscriminatorHead::new(&vs.root(), &config);
|
/// # let discriminator_head = ElectraDiscriminatorHead::new(&vs.root(), &config);
|
||||||
/// let (batch_size, sequence_length) = (64, 128);
|
/// let (batch_size, sequence_length) = (64, 128);
|
||||||
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length, config.hidden_size], (Float, device));
|
/// let input_tensor = Tensor::rand(
|
||||||
///
|
/// &[batch_size, sequence_length, config.hidden_size],
|
||||||
/// let output = no_grad(|| {
|
/// (Float, device),
|
||||||
/// discriminator_head.forward(&input_tensor)
|
/// );
|
||||||
/// });
|
|
||||||
///
|
///
|
||||||
|
/// let output = no_grad(|| discriminator_head.forward(&input_tensor));
|
||||||
/// ```
|
/// ```
|
||||||
///
|
|
||||||
pub fn forward(&self, encoder_hidden_states: &Tensor) -> Tensor {
|
pub fn forward(&self, encoder_hidden_states: &Tensor) -> Tensor {
|
||||||
let output = encoder_hidden_states.apply(&self.dense);
|
let output = encoder_hidden_states.apply(&self.dense);
|
||||||
let output = (self.activation)(&output);
|
let output = (self.activation)(&output);
|
||||||
@ -356,9 +412,9 @@ impl ElectraGeneratorHead {
|
|||||||
///
|
///
|
||||||
/// ```no_run
|
/// ```no_run
|
||||||
/// use rust_bert::electra::{ElectraConfig, ElectraGeneratorHead};
|
/// use rust_bert::electra::{ElectraConfig, ElectraGeneratorHead};
|
||||||
/// use tch::{nn, Device};
|
|
||||||
/// use rust_bert::Config;
|
/// use rust_bert::Config;
|
||||||
/// use std::path::Path;
|
/// use std::path::Path;
|
||||||
|
/// use tch::{nn, Device};
|
||||||
///
|
///
|
||||||
/// let config_path = Path::new("path/to/config.json");
|
/// let config_path = Path::new("path/to/config.json");
|
||||||
/// let device = Device::Cpu;
|
/// let device = Device::Cpu;
|
||||||
@ -366,13 +422,25 @@ impl ElectraGeneratorHead {
|
|||||||
/// let config = ElectraConfig::from_file(config_path);
|
/// let config = ElectraConfig::from_file(config_path);
|
||||||
/// let generator_head = ElectraGeneratorHead::new(&(&p.root() / "electra"), &config);
|
/// let generator_head = ElectraGeneratorHead::new(&(&p.root() / "electra"), &config);
|
||||||
/// ```
|
/// ```
|
||||||
///
|
|
||||||
pub fn new(p: &nn::Path, config: &ElectraConfig) -> ElectraGeneratorHead {
|
pub fn new(p: &nn::Path, config: &ElectraConfig) -> ElectraGeneratorHead {
|
||||||
let layer_norm = nn::layer_norm(p / "LayerNorm", vec![config.embedding_size], Default::default());
|
let layer_norm = nn::layer_norm(
|
||||||
let dense = nn::linear(&(p / "dense"), config.hidden_size, config.embedding_size, Default::default());
|
p / "LayerNorm",
|
||||||
|
vec![config.embedding_size],
|
||||||
|
Default::default(),
|
||||||
|
);
|
||||||
|
let dense = nn::linear(
|
||||||
|
&(p / "dense"),
|
||||||
|
config.hidden_size,
|
||||||
|
config.embedding_size,
|
||||||
|
Default::default(),
|
||||||
|
);
|
||||||
let activation = Box::new(_gelu);
|
let activation = Box::new(_gelu);
|
||||||
|
|
||||||
ElectraGeneratorHead { layer_norm, dense, activation }
|
ElectraGeneratorHead {
|
||||||
|
layer_norm,
|
||||||
|
dense,
|
||||||
|
activation,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Forward pass through the generator head
|
/// Forward pass through the generator head
|
||||||
@ -388,25 +456,24 @@ impl ElectraGeneratorHead {
|
|||||||
/// # Example
|
/// # Example
|
||||||
///
|
///
|
||||||
/// ```no_run
|
/// ```no_run
|
||||||
///# use rust_bert::electra::{ElectraConfig, ElectraGeneratorHead};
|
/// # use rust_bert::electra::{ElectraConfig, ElectraGeneratorHead};
|
||||||
///# use tch::{nn, Device, Tensor, no_grad};
|
/// # use tch::{nn, Device, Tensor, no_grad};
|
||||||
///# use rust_bert::Config;
|
/// # use rust_bert::Config;
|
||||||
///# use std::path::Path;
|
/// # use std::path::Path;
|
||||||
///# use tch::kind::Kind::Float;
|
/// # use tch::kind::Kind::Float;
|
||||||
///# let config_path = Path::new("path/to/config.json");
|
/// # let config_path = Path::new("path/to/config.json");
|
||||||
///# let device = Device::Cpu;
|
/// # let device = Device::Cpu;
|
||||||
///# let vs = nn::VarStore::new(device);
|
/// # let vs = nn::VarStore::new(device);
|
||||||
///# let config = ElectraConfig::from_file(config_path);
|
/// # let config = ElectraConfig::from_file(config_path);
|
||||||
///# let generator_head = ElectraGeneratorHead::new(&vs.root(), &config);
|
/// # let generator_head = ElectraGeneratorHead::new(&vs.root(), &config);
|
||||||
/// let (batch_size, sequence_length) = (64, 128);
|
/// let (batch_size, sequence_length) = (64, 128);
|
||||||
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length, config.hidden_size], (Float, device));
|
/// let input_tensor = Tensor::rand(
|
||||||
///
|
/// &[batch_size, sequence_length, config.hidden_size],
|
||||||
/// let output = no_grad(|| {
|
/// (Float, device),
|
||||||
/// generator_head.forward(&input_tensor)
|
/// );
|
||||||
/// });
|
|
||||||
///
|
///
|
||||||
|
/// let output = no_grad(|| generator_head.forward(&input_tensor));
|
||||||
/// ```
|
/// ```
|
||||||
///
|
|
||||||
pub fn forward(&self, encoder_hidden_states: &Tensor) -> Tensor {
|
pub fn forward(&self, encoder_hidden_states: &Tensor) -> Tensor {
|
||||||
let output = encoder_hidden_states.apply(&self.dense);
|
let output = encoder_hidden_states.apply(&self.dense);
|
||||||
let output = (self.activation)(&output);
|
let output = (self.activation)(&output);
|
||||||
@ -438,10 +505,10 @@ impl ElectraForMaskedLM {
|
|||||||
/// # Example
|
/// # Example
|
||||||
///
|
///
|
||||||
/// ```no_run
|
/// ```no_run
|
||||||
/// use rust_bert::electra::{ElectraForMaskedLM, ElectraConfig};
|
/// use rust_bert::electra::{ElectraConfig, ElectraForMaskedLM};
|
||||||
/// use tch::{nn, Device};
|
|
||||||
/// use rust_bert::Config;
|
/// use rust_bert::Config;
|
||||||
/// use std::path::Path;
|
/// use std::path::Path;
|
||||||
|
/// use tch::{nn, Device};
|
||||||
///
|
///
|
||||||
/// let config_path = Path::new("path/to/config.json");
|
/// let config_path = Path::new("path/to/config.json");
|
||||||
/// let device = Device::Cpu;
|
/// let device = Device::Cpu;
|
||||||
@ -449,13 +516,21 @@ impl ElectraForMaskedLM {
|
|||||||
/// let config = ElectraConfig::from_file(config_path);
|
/// let config = ElectraConfig::from_file(config_path);
|
||||||
/// let electra_model: ElectraForMaskedLM = ElectraForMaskedLM::new(&p.root(), &config);
|
/// let electra_model: ElectraForMaskedLM = ElectraForMaskedLM::new(&p.root(), &config);
|
||||||
/// ```
|
/// ```
|
||||||
///
|
|
||||||
pub fn new(p: &nn::Path, config: &ElectraConfig) -> ElectraForMaskedLM {
|
pub fn new(p: &nn::Path, config: &ElectraConfig) -> ElectraForMaskedLM {
|
||||||
let electra = ElectraModel::new(&(p / "electra"), config);
|
let electra = ElectraModel::new(&(p / "electra"), config);
|
||||||
let generator_head = ElectraGeneratorHead::new(&(p / "generator_predictions"), config);
|
let generator_head = ElectraGeneratorHead::new(&(p / "generator_predictions"), config);
|
||||||
let lm_head = nn::linear(&(p / "generator_lm_head"), config.embedding_size, config.vocab_size, Default::default());
|
let lm_head = nn::linear(
|
||||||
|
&(p / "generator_lm_head"),
|
||||||
|
config.embedding_size,
|
||||||
|
config.vocab_size,
|
||||||
|
Default::default(),
|
||||||
|
);
|
||||||
|
|
||||||
ElectraForMaskedLM { electra, generator_head, lm_head }
|
ElectraForMaskedLM {
|
||||||
|
electra,
|
||||||
|
generator_head,
|
||||||
|
lm_head,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Forward pass through the model
|
/// Forward pass through the model
|
||||||
@ -478,46 +553,53 @@ impl ElectraForMaskedLM {
|
|||||||
/// # Example
|
/// # Example
|
||||||
///
|
///
|
||||||
/// ```no_run
|
/// ```no_run
|
||||||
///# use rust_bert::electra::{ElectraForMaskedLM, ElectraConfig};
|
/// # use rust_bert::electra::{ElectraForMaskedLM, ElectraConfig};
|
||||||
///# use tch::{nn, Device, Tensor, no_grad};
|
/// # use tch::{nn, Device, Tensor, no_grad};
|
||||||
///# use rust_bert::Config;
|
/// # use rust_bert::Config;
|
||||||
///# use std::path::Path;
|
/// # use std::path::Path;
|
||||||
///# use tch::kind::Kind::Int64;
|
/// # use tch::kind::Kind::Int64;
|
||||||
///# let config_path = Path::new("path/to/config.json");
|
/// # let config_path = Path::new("path/to/config.json");
|
||||||
///# let device = Device::Cpu;
|
/// # let device = Device::Cpu;
|
||||||
///# let vs = nn::VarStore::new(device);
|
/// # let vs = nn::VarStore::new(device);
|
||||||
///# let config = ElectraConfig::from_file(config_path);
|
/// # let config = ElectraConfig::from_file(config_path);
|
||||||
///# let electra_model: ElectraForMaskedLM = ElectraForMaskedLM::new(&vs.root(), &config);
|
/// # let electra_model: ElectraForMaskedLM = ElectraForMaskedLM::new(&vs.root(), &config);
|
||||||
/// let (batch_size, sequence_length) = (64, 128);
|
/// let (batch_size, sequence_length) = (64, 128);
|
||||||
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
|
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
|
||||||
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
||||||
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
||||||
/// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true);
|
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
|
||||||
///
|
/// .expand(&[batch_size, sequence_length], true);
|
||||||
/// let (output, all_hidden_states, all_attentions) = no_grad(|| {
|
|
||||||
/// electra_model
|
|
||||||
/// .forward_t(Some(input_tensor),
|
|
||||||
/// Some(mask),
|
|
||||||
/// Some(token_type_ids),
|
|
||||||
/// Some(position_ids),
|
|
||||||
/// None,
|
|
||||||
/// false)
|
|
||||||
/// });
|
|
||||||
///
|
///
|
||||||
|
/// let (output, all_hidden_states, all_attentions) = no_grad(|| {
|
||||||
|
/// electra_model.forward_t(
|
||||||
|
/// Some(input_tensor),
|
||||||
|
/// Some(mask),
|
||||||
|
/// Some(token_type_ids),
|
||||||
|
/// Some(position_ids),
|
||||||
|
/// None,
|
||||||
|
/// false,
|
||||||
|
/// )
|
||||||
|
/// });
|
||||||
/// ```
|
/// ```
|
||||||
///
|
pub fn forward_t(
|
||||||
pub fn forward_t(&self,
|
&self,
|
||||||
input_ids: Option<Tensor>,
|
input_ids: Option<Tensor>,
|
||||||
mask: Option<Tensor>,
|
mask: Option<Tensor>,
|
||||||
token_type_ids: Option<Tensor>,
|
token_type_ids: Option<Tensor>,
|
||||||
position_ids: Option<Tensor>,
|
position_ids: Option<Tensor>,
|
||||||
input_embeds: Option<Tensor>,
|
input_embeds: Option<Tensor>,
|
||||||
train: bool)
|
train: bool,
|
||||||
-> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
|
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
|
||||||
let (hidden_states,
|
let (hidden_states, all_hidden_states, all_attentions) = self
|
||||||
all_hidden_states,
|
.electra
|
||||||
all_attentions) = self.electra
|
.forward_t(
|
||||||
.forward_t(input_ids, mask, token_type_ids, position_ids, input_embeds, train)
|
input_ids,
|
||||||
|
mask,
|
||||||
|
token_type_ids,
|
||||||
|
position_ids,
|
||||||
|
input_embeds,
|
||||||
|
train,
|
||||||
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
let hidden_states = self.generator_head.forward(&hidden_states);
|
let hidden_states = self.generator_head.forward(&hidden_states);
|
||||||
let hidden_states = hidden_states.apply(&self.lm_head);
|
let hidden_states = hidden_states.apply(&self.lm_head);
|
||||||
@ -547,10 +629,10 @@ impl ElectraDiscriminator {
|
|||||||
/// # Example
|
/// # Example
|
||||||
///
|
///
|
||||||
/// ```no_run
|
/// ```no_run
|
||||||
/// use rust_bert::electra::{ElectraDiscriminator, ElectraConfig};
|
/// use rust_bert::electra::{ElectraConfig, ElectraDiscriminator};
|
||||||
/// use tch::{nn, Device};
|
|
||||||
/// use rust_bert::Config;
|
/// use rust_bert::Config;
|
||||||
/// use std::path::Path;
|
/// use std::path::Path;
|
||||||
|
/// use tch::{nn, Device};
|
||||||
///
|
///
|
||||||
/// let config_path = Path::new("path/to/config.json");
|
/// let config_path = Path::new("path/to/config.json");
|
||||||
/// let device = Device::Cpu;
|
/// let device = Device::Cpu;
|
||||||
@ -558,12 +640,15 @@ impl ElectraDiscriminator {
|
|||||||
/// let config = ElectraConfig::from_file(config_path);
|
/// let config = ElectraConfig::from_file(config_path);
|
||||||
/// let electra_model: ElectraDiscriminator = ElectraDiscriminator::new(&p.root(), &config);
|
/// let electra_model: ElectraDiscriminator = ElectraDiscriminator::new(&p.root(), &config);
|
||||||
/// ```
|
/// ```
|
||||||
///
|
|
||||||
pub fn new(p: &nn::Path, config: &ElectraConfig) -> ElectraDiscriminator {
|
pub fn new(p: &nn::Path, config: &ElectraConfig) -> ElectraDiscriminator {
|
||||||
let electra = ElectraModel::new(&(p / "electra"), config);
|
let electra = ElectraModel::new(&(p / "electra"), config);
|
||||||
let discriminator_head = ElectraDiscriminatorHead::new(&(p / "discriminator_predictions"), config);
|
let discriminator_head =
|
||||||
|
ElectraDiscriminatorHead::new(&(p / "discriminator_predictions"), config);
|
||||||
|
|
||||||
ElectraDiscriminator { electra, discriminator_head }
|
ElectraDiscriminator {
|
||||||
|
electra,
|
||||||
|
discriminator_head,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Forward pass through the model
|
/// Forward pass through the model
|
||||||
@ -586,16 +671,16 @@ impl ElectraDiscriminator {
|
|||||||
/// # Example
|
/// # Example
|
||||||
///
|
///
|
||||||
/// ```no_run
|
/// ```no_run
|
||||||
///# use rust_bert::electra::{ElectraDiscriminator, ElectraConfig};
|
/// # use rust_bert::electra::{ElectraDiscriminator, ElectraConfig};
|
||||||
///# use tch::{nn, Device, Tensor, no_grad};
|
/// # use tch::{nn, Device, Tensor, no_grad};
|
||||||
///# use rust_bert::Config;
|
/// # use rust_bert::Config;
|
||||||
///# use std::path::Path;
|
/// # use std::path::Path;
|
||||||
///# use tch::kind::Kind::Int64;
|
/// # use tch::kind::Kind::Int64;
|
||||||
///# let config_path = Path::new("path/to/config.json");
|
/// # let config_path = Path::new("path/to/config.json");
|
||||||
///# let device = Device::Cpu;
|
/// # let device = Device::Cpu;
|
||||||
///# let vs = nn::VarStore::new(device);
|
/// # let vs = nn::VarStore::new(device);
|
||||||
///# let config = ElectraConfig::from_file(config_path);
|
/// # let config = ElectraConfig::from_file(config_path);
|
||||||
///# let electra_model: ElectraDiscriminator = ElectraDiscriminator::new(&vs.root(), &config);
|
/// # let electra_model: ElectraDiscriminator = ElectraDiscriminator::new(&vs.root(), &config);
|
||||||
/// let (batch_size, sequence_length) = (64, 128);
|
/// let (batch_size, sequence_length) = (64, 128);
|
||||||
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
|
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
|
||||||
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
||||||
@ -611,21 +696,26 @@ impl ElectraDiscriminator {
|
|||||||
/// None,
|
/// None,
|
||||||
/// false)
|
/// false)
|
||||||
/// });
|
/// });
|
||||||
///
|
|
||||||
/// ```
|
/// ```
|
||||||
///
|
pub fn forward_t(
|
||||||
pub fn forward_t(&self,
|
&self,
|
||||||
input_ids: Option<Tensor>,
|
input_ids: Option<Tensor>,
|
||||||
mask: Option<Tensor>,
|
mask: Option<Tensor>,
|
||||||
token_type_ids: Option<Tensor>,
|
token_type_ids: Option<Tensor>,
|
||||||
position_ids: Option<Tensor>,
|
position_ids: Option<Tensor>,
|
||||||
input_embeds: Option<Tensor>,
|
input_embeds: Option<Tensor>,
|
||||||
train: bool)
|
train: bool,
|
||||||
-> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
|
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
|
||||||
let (hidden_states,
|
let (hidden_states, all_hidden_states, all_attentions) = self
|
||||||
all_hidden_states,
|
.electra
|
||||||
all_attentions) = self.electra
|
.forward_t(
|
||||||
.forward_t(input_ids, mask, token_type_ids, position_ids, input_embeds, train)
|
input_ids,
|
||||||
|
mask,
|
||||||
|
token_type_ids,
|
||||||
|
position_ids,
|
||||||
|
input_embeds,
|
||||||
|
train,
|
||||||
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
let probabilities = self.discriminator_head.forward(&hidden_states).sigmoid();
|
let probabilities = self.discriminator_head.forward(&hidden_states).sigmoid();
|
||||||
(probabilities, all_hidden_states, all_attentions)
|
(probabilities, all_hidden_states, all_attentions)
|
||||||
@ -656,24 +746,37 @@ impl ElectraForTokenClassification {
|
|||||||
/// # Example
|
/// # Example
|
||||||
///
|
///
|
||||||
/// ```no_run
|
/// ```no_run
|
||||||
/// use rust_bert::electra::{ElectraForTokenClassification, ElectraConfig};
|
/// use rust_bert::electra::{ElectraConfig, ElectraForTokenClassification};
|
||||||
/// use tch::{nn, Device};
|
|
||||||
/// use rust_bert::Config;
|
/// use rust_bert::Config;
|
||||||
/// use std::path::Path;
|
/// use std::path::Path;
|
||||||
|
/// use tch::{nn, Device};
|
||||||
/// let config_path = Path::new("path/to/config.json");
|
/// let config_path = Path::new("path/to/config.json");
|
||||||
/// let device = Device::Cpu;
|
/// let device = Device::Cpu;
|
||||||
/// let p = nn::VarStore::new(device);
|
/// let p = nn::VarStore::new(device);
|
||||||
/// let config = ElectraConfig::from_file(config_path);
|
/// let config = ElectraConfig::from_file(config_path);
|
||||||
/// let electra_model: ElectraForTokenClassification = ElectraForTokenClassification::new(&p.root(), &config);
|
/// let electra_model: ElectraForTokenClassification =
|
||||||
|
/// ElectraForTokenClassification::new(&p.root(), &config);
|
||||||
/// ```
|
/// ```
|
||||||
///
|
|
||||||
pub fn new(p: &nn::Path, config: &ElectraConfig) -> ElectraForTokenClassification {
|
pub fn new(p: &nn::Path, config: &ElectraConfig) -> ElectraForTokenClassification {
|
||||||
let electra = ElectraModel::new(&(p / "electra"), config);
|
let electra = ElectraModel::new(&(p / "electra"), config);
|
||||||
let dropout = Dropout::new(config.hidden_dropout_prob);
|
let dropout = Dropout::new(config.hidden_dropout_prob);
|
||||||
let num_labels = config.id2label.as_ref().expect("id2label must be provided for classifiers").len() as i64;
|
let num_labels = config
|
||||||
let classifier = nn::linear(&(p / "classifier"), config.hidden_size, num_labels, Default::default());
|
.id2label
|
||||||
|
.as_ref()
|
||||||
|
.expect("id2label must be provided for classifiers")
|
||||||
|
.len() as i64;
|
||||||
|
let classifier = nn::linear(
|
||||||
|
&(p / "classifier"),
|
||||||
|
config.hidden_size,
|
||||||
|
num_labels,
|
||||||
|
Default::default(),
|
||||||
|
);
|
||||||
|
|
||||||
ElectraForTokenClassification { electra, dropout, classifier }
|
ElectraForTokenClassification {
|
||||||
|
electra,
|
||||||
|
dropout,
|
||||||
|
classifier,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Forward pass through the model
|
/// Forward pass through the model
|
||||||
@ -696,16 +799,16 @@ impl ElectraForTokenClassification {
|
|||||||
/// # Example
|
/// # Example
|
||||||
///
|
///
|
||||||
/// ```no_run
|
/// ```no_run
|
||||||
///# use rust_bert::electra::{ElectraForTokenClassification, ElectraConfig};
|
/// # use rust_bert::electra::{ElectraForTokenClassification, ElectraConfig};
|
||||||
///# use tch::{nn, Device, Tensor, no_grad};
|
/// # use tch::{nn, Device, Tensor, no_grad};
|
||||||
///# use rust_bert::Config;
|
/// # use rust_bert::Config;
|
||||||
///# use std::path::Path;
|
/// # use std::path::Path;
|
||||||
///# use tch::kind::Kind::Int64;
|
/// # use tch::kind::Kind::Int64;
|
||||||
///# let config_path = Path::new("path/to/config.json");
|
/// # let config_path = Path::new("path/to/config.json");
|
||||||
///# let device = Device::Cpu;
|
/// # let device = Device::Cpu;
|
||||||
///# let vs = nn::VarStore::new(device);
|
/// # let vs = nn::VarStore::new(device);
|
||||||
///# let config = ElectraConfig::from_file(config_path);
|
/// # let config = ElectraConfig::from_file(config_path);
|
||||||
///# let electra_model: ElectraForTokenClassification = ElectraForTokenClassification::new(&vs.root(), &config);
|
/// # let electra_model: ElectraForTokenClassification = ElectraForTokenClassification::new(&vs.root(), &config);
|
||||||
/// let (batch_size, sequence_length) = (64, 128);
|
/// let (batch_size, sequence_length) = (64, 128);
|
||||||
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
|
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
|
||||||
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
||||||
@ -721,21 +824,26 @@ impl ElectraForTokenClassification {
|
|||||||
/// None,
|
/// None,
|
||||||
/// false)
|
/// false)
|
||||||
/// });
|
/// });
|
||||||
///
|
|
||||||
/// ```
|
/// ```
|
||||||
///
|
pub fn forward_t(
|
||||||
pub fn forward_t(&self,
|
&self,
|
||||||
input_ids: Option<Tensor>,
|
input_ids: Option<Tensor>,
|
||||||
mask: Option<Tensor>,
|
mask: Option<Tensor>,
|
||||||
token_type_ids: Option<Tensor>,
|
token_type_ids: Option<Tensor>,
|
||||||
position_ids: Option<Tensor>,
|
position_ids: Option<Tensor>,
|
||||||
input_embeds: Option<Tensor>,
|
input_embeds: Option<Tensor>,
|
||||||
train: bool)
|
train: bool,
|
||||||
-> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
|
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
|
||||||
let (hidden_states,
|
let (hidden_states, all_hidden_states, all_attentions) = self
|
||||||
all_hidden_states,
|
.electra
|
||||||
all_attentions) = self.electra
|
.forward_t(
|
||||||
.forward_t(input_ids, mask, token_type_ids, position_ids, input_embeds, train)
|
input_ids,
|
||||||
|
mask,
|
||||||
|
token_type_ids,
|
||||||
|
position_ids,
|
||||||
|
input_embeds,
|
||||||
|
train,
|
||||||
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
let output = hidden_states
|
let output = hidden_states
|
||||||
.apply_t(&self.dropout, train)
|
.apply_t(&self.dropout, train)
|
||||||
|
@ -12,10 +12,10 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
use tch::{nn, Tensor, Kind};
|
|
||||||
use crate::common::dropout::Dropout;
|
use crate::common::dropout::Dropout;
|
||||||
use crate::electra::electra::ElectraConfig;
|
use crate::electra::electra::ElectraConfig;
|
||||||
use tch::nn::{EmbeddingConfig, embedding};
|
use tch::nn::{embedding, EmbeddingConfig};
|
||||||
|
use tch::{nn, Kind, Tensor};
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
/// # Embeddings implementation for Electra model
|
/// # Embeddings implementation for Electra model
|
||||||
@ -34,49 +34,77 @@ impl ElectraEmbeddings {
|
|||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
|
|
||||||
let word_embeddings: nn::Embedding = embedding(p / "word_embeddings",
|
let word_embeddings: nn::Embedding = embedding(
|
||||||
config.vocab_size,
|
p / "word_embeddings",
|
||||||
config.embedding_size,
|
config.vocab_size,
|
||||||
embedding_config);
|
config.embedding_size,
|
||||||
|
embedding_config,
|
||||||
|
);
|
||||||
|
|
||||||
let position_embeddings: nn::Embedding = embedding(p / "position_embeddings",
|
let position_embeddings: nn::Embedding = embedding(
|
||||||
config.max_position_embeddings,
|
p / "position_embeddings",
|
||||||
config.embedding_size,
|
config.max_position_embeddings,
|
||||||
Default::default());
|
config.embedding_size,
|
||||||
|
Default::default(),
|
||||||
|
);
|
||||||
|
|
||||||
let token_type_embeddings: nn::Embedding = embedding(p / "token_type_embeddings",
|
let token_type_embeddings: nn::Embedding = embedding(
|
||||||
config.type_vocab_size,
|
p / "token_type_embeddings",
|
||||||
config.embedding_size,
|
config.type_vocab_size,
|
||||||
Default::default());
|
config.embedding_size,
|
||||||
|
Default::default(),
|
||||||
|
);
|
||||||
|
|
||||||
let layer_norm_eps = match config.layer_norm_eps {
|
let layer_norm_eps = match config.layer_norm_eps {
|
||||||
Some(value) => value,
|
Some(value) => value,
|
||||||
None => 1e-12
|
None => 1e-12,
|
||||||
};
|
};
|
||||||
let layer_norm_config = nn::LayerNormConfig { eps: layer_norm_eps, ..Default::default() };
|
let layer_norm_config = nn::LayerNormConfig {
|
||||||
let layer_norm: nn::LayerNorm = nn::layer_norm(p / "LayerNorm", vec![config.embedding_size], layer_norm_config);
|
eps: layer_norm_eps,
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
let layer_norm: nn::LayerNorm = nn::layer_norm(
|
||||||
|
p / "LayerNorm",
|
||||||
|
vec![config.embedding_size],
|
||||||
|
layer_norm_config,
|
||||||
|
);
|
||||||
let dropout: Dropout = Dropout::new(config.hidden_dropout_prob);
|
let dropout: Dropout = Dropout::new(config.hidden_dropout_prob);
|
||||||
ElectraEmbeddings { word_embeddings, position_embeddings, token_type_embeddings, layer_norm, dropout}
|
ElectraEmbeddings {
|
||||||
|
word_embeddings,
|
||||||
|
position_embeddings,
|
||||||
|
token_type_embeddings,
|
||||||
|
layer_norm,
|
||||||
|
dropout,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn forward_t(&self,
|
pub fn forward_t(
|
||||||
input_ids: Option<Tensor>,
|
&self,
|
||||||
token_type_ids: Option<Tensor>,
|
input_ids: Option<Tensor>,
|
||||||
position_ids: Option<Tensor>,
|
token_type_ids: Option<Tensor>,
|
||||||
input_embeds: Option<Tensor>,
|
position_ids: Option<Tensor>,
|
||||||
train: bool) -> Result<Tensor, &'static str> {
|
input_embeds: Option<Tensor>,
|
||||||
|
train: bool,
|
||||||
|
) -> Result<Tensor, &'static str> {
|
||||||
let (input_embeddings, input_shape) = match input_ids {
|
let (input_embeddings, input_shape) = match input_ids {
|
||||||
Some(input_value) => match input_embeds {
|
Some(input_value) => match input_embeds {
|
||||||
Some(_) => { return Err("Only one of input ids or input embeddings may be set"); }
|
Some(_) => {
|
||||||
None => (input_value.apply_t(&self.word_embeddings, train), input_value.size())
|
return Err("Only one of input ids or input embeddings may be set");
|
||||||
}
|
}
|
||||||
|
None => (
|
||||||
|
input_value.apply_t(&self.word_embeddings, train),
|
||||||
|
input_value.size(),
|
||||||
|
),
|
||||||
|
},
|
||||||
None => match input_embeds {
|
None => match input_embeds {
|
||||||
Some(embeds) => {
|
Some(embeds) => {
|
||||||
let size = vec!(embeds.size()[0], embeds.size()[1]);
|
let size = vec![embeds.size()[0], embeds.size()[1]];
|
||||||
(embeds, size)
|
(embeds, size)
|
||||||
},
|
}
|
||||||
None => { return Err("Only one of input ids or input embeddings may be set"); }
|
None => {
|
||||||
}
|
return Err("Only one of input ids or input embeddings may be set");
|
||||||
|
}
|
||||||
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
let seq_length = input_embeddings.as_ref().size()[1].to_owned();
|
let seq_length = input_embeddings.as_ref().size()[1].to_owned();
|
||||||
@ -84,19 +112,22 @@ impl ElectraEmbeddings {
|
|||||||
let position_ids = match position_ids {
|
let position_ids = match position_ids {
|
||||||
Some(value) => value,
|
Some(value) => value,
|
||||||
None => Tensor::arange(seq_length, (Kind::Int64, input_embeddings.device()))
|
None => Tensor::arange(seq_length, (Kind::Int64, input_embeddings.device()))
|
||||||
.unsqueeze(0).
|
.unsqueeze(0)
|
||||||
expand(&input_shape, true)
|
.expand(&input_shape, true),
|
||||||
};
|
};
|
||||||
|
|
||||||
let token_type_ids = match token_type_ids {
|
let token_type_ids = match token_type_ids {
|
||||||
Some(value) => value,
|
Some(value) => value,
|
||||||
None => Tensor::zeros(&input_shape, (Kind::Int64, input_embeddings.device()))
|
None => Tensor::zeros(&input_shape, (Kind::Int64, input_embeddings.device())),
|
||||||
};
|
};
|
||||||
|
|
||||||
let position_embeddings = position_ids.apply(&self.position_embeddings);
|
let position_embeddings = position_ids.apply(&self.position_embeddings);
|
||||||
let token_type_embeddings = token_type_ids.apply(&self.token_type_embeddings);
|
let token_type_embeddings = token_type_ids.apply(&self.token_type_embeddings);
|
||||||
|
|
||||||
let input_embeddings: Tensor = input_embeddings + position_embeddings + token_type_embeddings;
|
let input_embeddings: Tensor =
|
||||||
Ok(input_embeddings.apply(&self.layer_norm).apply_t(&self.dropout, train))
|
input_embeddings + position_embeddings + token_type_embeddings;
|
||||||
|
Ok(input_embeddings
|
||||||
|
.apply(&self.layer_norm)
|
||||||
|
.apply_t(&self.dropout, train))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -23,18 +23,24 @@
|
|||||||
//! Pretrained models are available and can be downloaded using RemoteResources.
|
//! Pretrained models are available and can be downloaded using RemoteResources.
|
||||||
//!
|
//!
|
||||||
//! ```no_run
|
//! ```no_run
|
||||||
//!# fn main() -> failure::Fallible<()> {
|
//! # fn main() -> failure::Fallible<()> {
|
||||||
//!#
|
//! #
|
||||||
//! use rust_tokenizers::BertTokenizer;
|
//! use rust_tokenizers::BertTokenizer;
|
||||||
//! use tch::{nn, Device};
|
//! use tch::{nn, Device};
|
||||||
//!# use std::path::PathBuf;
|
//! # use std::path::PathBuf;
|
||||||
//! use rust_bert::electra::{ElectraForMaskedLM, ElectraConfig};
|
//! use rust_bert::electra::{ElectraConfig, ElectraForMaskedLM};
|
||||||
|
//! use rust_bert::resources::{download_resource, LocalResource, Resource};
|
||||||
//! use rust_bert::Config;
|
//! use rust_bert::Config;
|
||||||
//! use rust_bert::resources::{Resource, download_resource, LocalResource};
|
|
||||||
//!
|
//!
|
||||||
//! let config_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/config.json")});
|
//! let config_resource = Resource::Local(LocalResource {
|
||||||
//! let vocab_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/vocab.txt")});
|
//! local_path: PathBuf::from("path/to/config.json"),
|
||||||
//! let weights_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/model.ot")});
|
//! });
|
||||||
|
//! let vocab_resource = Resource::Local(LocalResource {
|
||||||
|
//! local_path: PathBuf::from("path/to/vocab.txt"),
|
||||||
|
//! });
|
||||||
|
//! let weights_resource = Resource::Local(LocalResource {
|
||||||
|
//! local_path: PathBuf::from("path/to/model.ot"),
|
||||||
|
//! });
|
||||||
//! let config_path = download_resource(&config_resource)?;
|
//! let config_path = download_resource(&config_resource)?;
|
||||||
//! let vocab_path = download_resource(&vocab_resource)?;
|
//! let vocab_path = download_resource(&vocab_resource)?;
|
||||||
//! let weights_path = download_resource(&weights_resource)?;
|
//! let weights_path = download_resource(&weights_resource)?;
|
||||||
@ -45,13 +51,15 @@
|
|||||||
//! let electra_model = ElectraForMaskedLM::new(&vs.root(), &config);
|
//! let electra_model = ElectraForMaskedLM::new(&vs.root(), &config);
|
||||||
//! vs.load(weights_path)?;
|
//! vs.load(weights_path)?;
|
||||||
//!
|
//!
|
||||||
//!# Ok(())
|
//! # Ok(())
|
||||||
//!# }
|
//! # }
|
||||||
//! ```
|
//! ```
|
||||||
|
|
||||||
|
|
||||||
mod embeddings;
|
|
||||||
mod electra;
|
mod electra;
|
||||||
|
mod embeddings;
|
||||||
|
|
||||||
pub use electra::{ElectraModelResources, ElectraVocabResources, ElectraConfigResources, ElectraConfig,
|
pub use electra::{
|
||||||
ElectraModel, ElectraDiscriminator, ElectraForMaskedLM, ElectraDiscriminatorHead, ElectraGeneratorHead, ElectraForTokenClassification};
|
ElectraConfig, ElectraConfigResources, ElectraDiscriminator, ElectraDiscriminatorHead,
|
||||||
|
ElectraForMaskedLM, ElectraForTokenClassification, ElectraGeneratorHead, ElectraModel,
|
||||||
|
ElectraModelResources, ElectraVocabResources,
|
||||||
|
};
|
||||||
|
@ -12,11 +12,11 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
use tch::{Tensor, nn};
|
|
||||||
use crate::common::dropout::Dropout;
|
use crate::common::dropout::Dropout;
|
||||||
use crate::gpt2::gpt2::Gpt2Config;
|
use crate::gpt2::gpt2::Gpt2Config;
|
||||||
use tch::kind::Kind::Float;
|
use tch::kind::Kind::Float;
|
||||||
use tch::nn::{Init, Module};
|
use tch::nn::{Init, Module};
|
||||||
|
use tch::{nn, Tensor};
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct GPTConv1D {
|
pub struct GPTConv1D {
|
||||||
@ -26,7 +26,14 @@ pub struct GPTConv1D {
|
|||||||
|
|
||||||
impl GPTConv1D {
|
impl GPTConv1D {
|
||||||
pub fn new(p: &nn::Path, nf: i64, nx: i64) -> GPTConv1D {
|
pub fn new(p: &nn::Path, nf: i64, nx: i64) -> GPTConv1D {
|
||||||
let weight = p.var("weight", &[nx, nf], Init::Randn { mean: 0., stdev: 0.02 });
|
let weight = p.var(
|
||||||
|
"weight",
|
||||||
|
&[nx, nf],
|
||||||
|
Init::Randn {
|
||||||
|
mean: 0.,
|
||||||
|
stdev: 0.02,
|
||||||
|
},
|
||||||
|
);
|
||||||
let bias = p.var("bias", &[nf], Init::Const(0.));
|
let bias = p.var("bias", &[nf], Init::Const(0.));
|
||||||
GPTConv1D { weight, bias }
|
GPTConv1D { weight, bias }
|
||||||
}
|
}
|
||||||
@ -38,7 +45,6 @@ impl Module for GPTConv1D {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
pub struct Attention {
|
pub struct Attention {
|
||||||
bias: Tensor,
|
bias: Tensor,
|
||||||
c_attn: GPTConv1D,
|
c_attn: GPTConv1D,
|
||||||
@ -62,23 +68,27 @@ impl Attention {
|
|||||||
|
|
||||||
let attn_pdrop = match config.attn_pdrop {
|
let attn_pdrop = match config.attn_pdrop {
|
||||||
Some(value) => value,
|
Some(value) => value,
|
||||||
None => 0.1
|
None => 0.1,
|
||||||
};
|
};
|
||||||
|
|
||||||
let resid_pdrop = match config.resid_pdrop {
|
let resid_pdrop = match config.resid_pdrop {
|
||||||
Some(value) => value,
|
Some(value) => value,
|
||||||
None => 0.1
|
None => 0.1,
|
||||||
};
|
};
|
||||||
|
|
||||||
let output_attentions = match config.output_attentions {
|
let output_attentions = match config.output_attentions {
|
||||||
Some(value) => value,
|
Some(value) => value,
|
||||||
None => false
|
None => false,
|
||||||
};
|
};
|
||||||
|
|
||||||
let attn_dropout = Dropout::new(attn_pdrop);
|
let attn_dropout = Dropout::new(attn_pdrop);
|
||||||
let resid_dropout = Dropout::new(resid_pdrop);
|
let resid_dropout = Dropout::new(resid_pdrop);
|
||||||
|
|
||||||
assert_eq!(config.n_embd % config.n_head, 0, "Attention hidden states not a multiple of the number of heads");
|
assert_eq!(
|
||||||
|
config.n_embd % config.n_head,
|
||||||
|
0,
|
||||||
|
"Attention hidden states not a multiple of the number of heads"
|
||||||
|
);
|
||||||
let dim_per_head = config.n_embd / config.n_head;
|
let dim_per_head = config.n_embd / config.n_head;
|
||||||
|
|
||||||
Attention {
|
Attention {
|
||||||
@ -105,19 +115,31 @@ impl Attention {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn flatten(&self, x: Tensor) -> Tensor {
|
fn flatten(&self, x: Tensor) -> Tensor {
|
||||||
x.transpose(1, 2).contiguous().view((x.size()[0], -1, &self.n_head * self.dim_per_head))
|
x.transpose(1, 2)
|
||||||
|
.contiguous()
|
||||||
|
.view((x.size()[0], -1, &self.n_head * self.dim_per_head))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn attention(&self, q: &Tensor, k: &Tensor, v: &Tensor, attention_mask: &Option<Tensor>, train: bool)
|
fn attention(
|
||||||
-> (Tensor, Option<Tensor>) {
|
&self,
|
||||||
|
q: &Tensor,
|
||||||
|
k: &Tensor,
|
||||||
|
v: &Tensor,
|
||||||
|
attention_mask: &Option<Tensor>,
|
||||||
|
train: bool,
|
||||||
|
) -> (Tensor, Option<Tensor>) {
|
||||||
let mut w = q.matmul(&k);
|
let mut w = q.matmul(&k);
|
||||||
if self.scale { w = w / (*v.size().last().unwrap() as f64).sqrt(); }
|
if self.scale {
|
||||||
|
w = w / (*v.size().last().unwrap() as f64).sqrt();
|
||||||
|
}
|
||||||
|
|
||||||
let (nd, ns) = (w.size()[2], w.size()[3]);
|
let (nd, ns) = (w.size()[2], w.size()[3]);
|
||||||
let b = self.bias.narrow(2, ns - nd, nd).narrow(3, 0, ns);
|
let b = self.bias.narrow(2, ns - nd, nd).narrow(3, 0, ns);
|
||||||
|
|
||||||
let mut w: Tensor = w * &b + 1e4 * (&b - 1);
|
let mut w: Tensor = w * &b + 1e4 * (&b - 1);
|
||||||
if let Some(mask) = attention_mask { w = w + mask; }
|
if let Some(mask) = attention_mask {
|
||||||
|
w = w + mask;
|
||||||
|
}
|
||||||
w = w.softmax(-1, Float).apply_t(&self.attn_dropout, train);
|
w = w.softmax(-1, Float).apply_t(&self.attn_dropout, train);
|
||||||
let output = w.matmul(&v);
|
let output = w.matmul(&v);
|
||||||
|
|
||||||
@ -128,29 +150,36 @@ impl Attention {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn forward_t(&self, x: &Tensor, layer_past: &Option<Tensor>, attention_mask: &Option<Tensor>, train: bool)
|
pub fn forward_t(
|
||||||
-> (Tensor, Tensor, Option<Tensor>) {
|
&self,
|
||||||
|
x: &Tensor,
|
||||||
|
layer_past: &Option<Tensor>,
|
||||||
|
attention_mask: &Option<Tensor>,
|
||||||
|
train: bool,
|
||||||
|
) -> (Tensor, Tensor, Option<Tensor>) {
|
||||||
let x = x.apply(&self.c_attn).split(self.n_state, 2);
|
let x = x.apply(&self.c_attn).split(self.n_state, 2);
|
||||||
|
|
||||||
let (query, key, value) =
|
let (query, key, value) = (
|
||||||
(
|
self.split_heads(&x[0], false),
|
||||||
self.split_heads(&x[0], false),
|
self.split_heads(&x[1], true),
|
||||||
self.split_heads(&x[1], true),
|
self.split_heads(&x[2], false),
|
||||||
self.split_heads(&x[2], false)
|
);
|
||||||
);
|
|
||||||
let (key, value) = match layer_past {
|
let (key, value) = match layer_past {
|
||||||
Some(past) => {
|
Some(past) => {
|
||||||
let key = Tensor::cat(&[past.get(0).transpose(-2, -1), key], -1);
|
let key = Tensor::cat(&[past.get(0).transpose(-2, -1), key], -1);
|
||||||
let value = Tensor::cat(&[past.get(1), value], -2);
|
let value = Tensor::cat(&[past.get(1), value], -2);
|
||||||
(key, value)
|
(key, value)
|
||||||
}
|
}
|
||||||
None => (key, value)
|
None => (key, value),
|
||||||
};
|
};
|
||||||
let present = Tensor::stack(&[key.transpose(-2, -1), value.copy()], 0);
|
let present = Tensor::stack(&[key.transpose(-2, -1), value.copy()], 0);
|
||||||
let (a, attentions) = self.attention(&query, &key, &value, &attention_mask, train);
|
let (a, attentions) = self.attention(&query, &key, &value, &attention_mask, train);
|
||||||
|
|
||||||
let a = self.flatten(a).apply(&self.c_proj).apply_t(&self.resid_dropout, train);
|
let a = self
|
||||||
|
.flatten(a)
|
||||||
|
.apply(&self.c_proj)
|
||||||
|
.apply_t(&self.resid_dropout, train);
|
||||||
|
|
||||||
(a, present, attentions)
|
(a, present, attentions)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
522
src/gpt2/gpt2.rs
522
src/gpt2/gpt2.rs
@ -12,16 +12,16 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
use tch::{nn, Tensor};
|
|
||||||
use crate::common::dropout::Dropout;
|
use crate::common::dropout::Dropout;
|
||||||
use tch::nn::embedding;
|
use crate::common::linear::{linear_no_bias, LinearNoBias};
|
||||||
use crate::gpt2::transformer::Block;
|
use crate::gpt2::transformer::Block;
|
||||||
use tch::kind::Kind::Int64;
|
use crate::pipelines::generation::{Cache, LMHeadModel};
|
||||||
use std::borrow::BorrowMut;
|
|
||||||
use crate::common::linear::{LinearNoBias, linear_no_bias};
|
|
||||||
use crate::Config;
|
use crate::Config;
|
||||||
use crate::pipelines::generation::{LMHeadModel, Cache};
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::borrow::BorrowMut;
|
||||||
|
use tch::kind::Kind::Int64;
|
||||||
|
use tch::nn::embedding;
|
||||||
|
use tch::{nn, Tensor};
|
||||||
|
|
||||||
/// # GPT2 Pretrained model weight files
|
/// # GPT2 Pretrained model weight files
|
||||||
pub struct Gpt2ModelResources;
|
pub struct Gpt2ModelResources;
|
||||||
@ -37,54 +37,114 @@ pub struct Gpt2MergesResources;
|
|||||||
|
|
||||||
impl Gpt2ModelResources {
|
impl Gpt2ModelResources {
|
||||||
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
|
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
|
||||||
pub const GPT2: (&'static str, &'static str) = ("gpt2/model.ot", "https://cdn.huggingface.co/gpt2-rust_model.ot");
|
pub const GPT2: (&'static str, &'static str) = (
|
||||||
|
"gpt2/model.ot",
|
||||||
|
"https://cdn.huggingface.co/gpt2-rust_model.ot",
|
||||||
|
);
|
||||||
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
|
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
|
||||||
pub const GPT2_MEDIUM: (&'static str, &'static str) = ("gpt2-medium/model.ot", "https://cdn.huggingface.co/gpt2-medium-rust_model.ot");
|
pub const GPT2_MEDIUM: (&'static str, &'static str) = (
|
||||||
|
"gpt2-medium/model.ot",
|
||||||
|
"https://cdn.huggingface.co/gpt2-medium-rust_model.ot",
|
||||||
|
);
|
||||||
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
|
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
|
||||||
pub const GPT2_LARGE: (&'static str, &'static str) = ("gpt2-large/model.ot", "https://cdn.huggingface.co/gpt2-large-rust_model.ot");
|
pub const GPT2_LARGE: (&'static str, &'static str) = (
|
||||||
|
"gpt2-large/model.ot",
|
||||||
|
"https://cdn.huggingface.co/gpt2-large-rust_model.ot",
|
||||||
|
);
|
||||||
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
|
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
|
||||||
pub const GPT2_XL: (&'static str, &'static str) = ("gpt2-xl/model.ot", "https://cdn.huggingface.co/gpt2-xl-rust_model.ot");
|
pub const GPT2_XL: (&'static str, &'static str) = (
|
||||||
|
"gpt2-xl/model.ot",
|
||||||
|
"https://cdn.huggingface.co/gpt2-xl-rust_model.ot",
|
||||||
|
);
|
||||||
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
|
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
|
||||||
pub const DISTIL_GPT2: (&'static str, &'static str) = ("distilgpt2/model.ot", "https://cdn.huggingface.co/distilgpt2-rust_model.ot");
|
pub const DISTIL_GPT2: (&'static str, &'static str) = (
|
||||||
|
"distilgpt2/model.ot",
|
||||||
|
"https://cdn.huggingface.co/distilgpt2-rust_model.ot",
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Gpt2ConfigResources {
|
impl Gpt2ConfigResources {
|
||||||
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
|
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
|
||||||
pub const GPT2: (&'static str, &'static str) = ("gpt2/config.json", "https://cdn.huggingface.co/gpt2-config.json");
|
pub const GPT2: (&'static str, &'static str) = (
|
||||||
|
"gpt2/config.json",
|
||||||
|
"https://cdn.huggingface.co/gpt2-config.json",
|
||||||
|
);
|
||||||
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
|
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
|
||||||
pub const GPT2_MEDIUM: (&'static str, &'static str) = ("gpt2-medium/config.json", "https://cdn.huggingface.co/gpt2-medium-config.json");
|
pub const GPT2_MEDIUM: (&'static str, &'static str) = (
|
||||||
|
"gpt2-medium/config.json",
|
||||||
|
"https://cdn.huggingface.co/gpt2-medium-config.json",
|
||||||
|
);
|
||||||
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
|
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
|
||||||
pub const GPT2_LARGE: (&'static str, &'static str) = ("gpt2-large/config.json", "https://cdn.huggingface.co/gpt2-large-config.json");
|
pub const GPT2_LARGE: (&'static str, &'static str) = (
|
||||||
|
"gpt2-large/config.json",
|
||||||
|
"https://cdn.huggingface.co/gpt2-large-config.json",
|
||||||
|
);
|
||||||
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
|
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
|
||||||
pub const GPT2_XL: (&'static str, &'static str) = ("gpt2-xl/config.json", "https://cdn.huggingface.co/gpt2-xl-config.json");
|
pub const GPT2_XL: (&'static str, &'static str) = (
|
||||||
|
"gpt2-xl/config.json",
|
||||||
|
"https://cdn.huggingface.co/gpt2-xl-config.json",
|
||||||
|
);
|
||||||
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
|
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
|
||||||
pub const DISTIL_GPT2: (&'static str, &'static str) = ("distilgpt2/config.json", "https://cdn.huggingface.co/distilgpt2-config.json");
|
pub const DISTIL_GPT2: (&'static str, &'static str) = (
|
||||||
|
"distilgpt2/config.json",
|
||||||
|
"https://cdn.huggingface.co/distilgpt2-config.json",
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Gpt2VocabResources {
|
impl Gpt2VocabResources {
|
||||||
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
|
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
|
||||||
pub const GPT2: (&'static str, &'static str) = ("gpt2/vocab.txt", "https://cdn.huggingface.co/gpt2-vocab.json");
|
pub const GPT2: (&'static str, &'static str) = (
|
||||||
|
"gpt2/vocab.txt",
|
||||||
|
"https://cdn.huggingface.co/gpt2-vocab.json",
|
||||||
|
);
|
||||||
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
|
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
|
||||||
pub const GPT2_MEDIUM: (&'static str, &'static str) = ("gpt2-medium/vocab.txt", "https://cdn.huggingface.co/gpt2-medium-vocab.json");
|
pub const GPT2_MEDIUM: (&'static str, &'static str) = (
|
||||||
|
"gpt2-medium/vocab.txt",
|
||||||
|
"https://cdn.huggingface.co/gpt2-medium-vocab.json",
|
||||||
|
);
|
||||||
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
|
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
|
||||||
pub const GPT2_LARGE: (&'static str, &'static str) = ("gpt2-large/vocab.txt", "https://cdn.huggingface.co/gpt2-large-vocab.json");
|
pub const GPT2_LARGE: (&'static str, &'static str) = (
|
||||||
|
"gpt2-large/vocab.txt",
|
||||||
|
"https://cdn.huggingface.co/gpt2-large-vocab.json",
|
||||||
|
);
|
||||||
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
|
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
|
||||||
pub const GPT2_XL: (&'static str, &'static str) = ("gpt2-xl/vocab.txt", "https://cdn.huggingface.co/gpt2-xl-vocab.json");
|
pub const GPT2_XL: (&'static str, &'static str) = (
|
||||||
|
"gpt2-xl/vocab.txt",
|
||||||
|
"https://cdn.huggingface.co/gpt2-xl-vocab.json",
|
||||||
|
);
|
||||||
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
|
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
|
||||||
pub const DISTIL_GPT2: (&'static str, &'static str) = ("distilgpt2/vocab.txt", "https://cdn.huggingface.co/distilgpt2-vocab.json");
|
pub const DISTIL_GPT2: (&'static str, &'static str) = (
|
||||||
|
"distilgpt2/vocab.txt",
|
||||||
|
"https://cdn.huggingface.co/distilgpt2-vocab.json",
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Gpt2MergesResources {
|
impl Gpt2MergesResources {
|
||||||
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
|
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
|
||||||
pub const GPT2: (&'static str, &'static str) = ("gpt2/merges.txt", "https://cdn.huggingface.co/gpt2-merges.txt");
|
pub const GPT2: (&'static str, &'static str) = (
|
||||||
|
"gpt2/merges.txt",
|
||||||
|
"https://cdn.huggingface.co/gpt2-merges.txt",
|
||||||
|
);
|
||||||
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
|
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
|
||||||
pub const GPT2_MEDIUM: (&'static str, &'static str) = ("gpt2-medium/merges.txt", "https://cdn.huggingface.co/gpt2-medium-merges.txt");
|
pub const GPT2_MEDIUM: (&'static str, &'static str) = (
|
||||||
|
"gpt2-medium/merges.txt",
|
||||||
|
"https://cdn.huggingface.co/gpt2-medium-merges.txt",
|
||||||
|
);
|
||||||
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
|
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
|
||||||
pub const GPT2_LARGE: (&'static str, &'static str) = ("gpt2-large/merges.txt", "https://cdn.huggingface.co/gpt2-large-merges.txt");
|
pub const GPT2_LARGE: (&'static str, &'static str) = (
|
||||||
|
"gpt2-large/merges.txt",
|
||||||
|
"https://cdn.huggingface.co/gpt2-large-merges.txt",
|
||||||
|
);
|
||||||
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
|
/// Shared under Modified MIT license by the OpenAI team at https://github.com/openai/gpt-2/blob/master/LICENSE. Modified with conversion to C-array format.
|
||||||
pub const GPT2_XL: (&'static str, &'static str) = ("gpt2-xl/merges.txt", "https://cdn.huggingface.co/gpt2-xl-merges.txt");
|
pub const GPT2_XL: (&'static str, &'static str) = (
|
||||||
|
"gpt2-xl/merges.txt",
|
||||||
|
"https://cdn.huggingface.co/gpt2-xl-merges.txt",
|
||||||
|
);
|
||||||
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
|
/// Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models. Modified with conversion to C-array format.
|
||||||
pub const DISTIL_GPT2: (&'static str, &'static str) = ("distilgpt2/merges.txt", "https://cdn.huggingface.co/distilgpt2-merges.txt");
|
pub const DISTIL_GPT2: (&'static str, &'static str) = (
|
||||||
|
"distilgpt2/merges.txt",
|
||||||
|
"https://cdn.huggingface.co/distilgpt2-merges.txt",
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(non_camel_case_types)]
|
#[allow(non_camel_case_types)]
|
||||||
@ -156,10 +216,10 @@ impl Gpt2Model {
|
|||||||
/// # Example
|
/// # Example
|
||||||
///
|
///
|
||||||
/// ```no_run
|
/// ```no_run
|
||||||
/// use tch::{nn, Device};
|
/// use rust_bert::gpt2::{Gpt2Config, Gpt2Model};
|
||||||
/// use rust_bert::Config;
|
/// use rust_bert::Config;
|
||||||
/// use std::path::Path;
|
/// use std::path::Path;
|
||||||
/// use rust_bert::gpt2::{Gpt2Config, Gpt2Model};
|
/// use tch::{nn, Device};
|
||||||
///
|
///
|
||||||
/// let config_path = Path::new("path/to/config.json");
|
/// let config_path = Path::new("path/to/config.json");
|
||||||
/// let device = Device::Cpu;
|
/// let device = Device::Cpu;
|
||||||
@ -167,37 +227,58 @@ impl Gpt2Model {
|
|||||||
/// let config = Gpt2Config::from_file(config_path);
|
/// let config = Gpt2Config::from_file(config_path);
|
||||||
/// let gpt2: Gpt2Model = Gpt2Model::new(&(&p.root() / "gpt2"), &config);
|
/// let gpt2: Gpt2Model = Gpt2Model::new(&(&p.root() / "gpt2"), &config);
|
||||||
/// ```
|
/// ```
|
||||||
///
|
|
||||||
pub fn new(p: &nn::Path, config: &Gpt2Config) -> Gpt2Model {
|
pub fn new(p: &nn::Path, config: &Gpt2Config) -> Gpt2Model {
|
||||||
let p = &(p / "transformer");
|
let p = &(p / "transformer");
|
||||||
let wte = embedding(&(p / "wte"), config.vocab_size, config.n_embd, Default::default());
|
let wte = embedding(
|
||||||
let wpe = embedding(&(p / "wpe"), config.n_positions, config.n_embd, Default::default());
|
&(p / "wte"),
|
||||||
|
config.vocab_size,
|
||||||
|
config.n_embd,
|
||||||
|
Default::default(),
|
||||||
|
);
|
||||||
|
let wpe = embedding(
|
||||||
|
&(p / "wpe"),
|
||||||
|
config.n_positions,
|
||||||
|
config.n_embd,
|
||||||
|
Default::default(),
|
||||||
|
);
|
||||||
|
|
||||||
let embd_pdrop = match config.embd_pdrop {
|
let embd_pdrop = match config.embd_pdrop {
|
||||||
Some(value) => value,
|
Some(value) => value,
|
||||||
None => 0.1
|
None => 0.1,
|
||||||
};
|
};
|
||||||
let drop = Dropout::new(embd_pdrop);
|
let drop = Dropout::new(embd_pdrop);
|
||||||
let layer_norm_config = nn::LayerNormConfig { eps: config.layer_norm_epsilon, ..Default::default() };
|
let layer_norm_config = nn::LayerNormConfig {
|
||||||
|
eps: config.layer_norm_epsilon,
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
let ln_f = nn::layer_norm(p / "ln_f", vec![config.n_embd], layer_norm_config);
|
let ln_f = nn::layer_norm(p / "ln_f", vec![config.n_embd], layer_norm_config);
|
||||||
let mut h: Vec<Block> = vec!();
|
let mut h: Vec<Block> = vec![];
|
||||||
let h_path = &(p / "h");
|
let h_path = &(p / "h");
|
||||||
for layer_index in 0..config.n_layer {
|
for layer_index in 0..config.n_layer {
|
||||||
h.push(Block::new(&(h_path / layer_index), config, true));
|
h.push(Block::new(&(h_path / layer_index), config, true));
|
||||||
};
|
}
|
||||||
let output_attentions = match config.output_attentions {
|
let output_attentions = match config.output_attentions {
|
||||||
Some(value) => value,
|
Some(value) => value,
|
||||||
None => false
|
None => false,
|
||||||
};
|
};
|
||||||
let output_past = match config.output_past {
|
let output_past = match config.output_past {
|
||||||
Some(value) => value,
|
Some(value) => value,
|
||||||
None => true
|
None => true,
|
||||||
};
|
};
|
||||||
let output_hidden_states = match config.output_hidden_states {
|
let output_hidden_states = match config.output_hidden_states {
|
||||||
Some(value) => value,
|
Some(value) => value,
|
||||||
None => false
|
None => false,
|
||||||
};
|
};
|
||||||
Gpt2Model { wte, wpe, drop, ln_f, h, output_past, output_hidden_states, output_attentions }
|
Gpt2Model {
|
||||||
|
wte,
|
||||||
|
wpe,
|
||||||
|
drop,
|
||||||
|
ln_f,
|
||||||
|
h,
|
||||||
|
output_past,
|
||||||
|
output_hidden_states,
|
||||||
|
output_attentions,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Forward pass through the model
|
/// Forward pass through the model
|
||||||
@ -222,63 +303,101 @@ impl Gpt2Model {
|
|||||||
/// # Example
|
/// # Example
|
||||||
///
|
///
|
||||||
/// ```no_run
|
/// ```no_run
|
||||||
///# use tch::{nn, Device, Tensor, no_grad};
|
/// # use tch::{nn, Device, Tensor, no_grad};
|
||||||
///# use rust_bert::Config;
|
/// # use rust_bert::Config;
|
||||||
///# use std::path::Path;
|
/// # use std::path::Path;
|
||||||
///# use tch::kind::Kind::{Int64, Double};
|
/// # use tch::kind::Kind::{Int64, Double};
|
||||||
/// use rust_bert::gpt2::{Gpt2Model, Gpt2Config};
|
/// use rust_bert::gpt2::{Gpt2Config, Gpt2Model};
|
||||||
///# let config_path = Path::new("path/to/config.json");
|
/// # let config_path = Path::new("path/to/config.json");
|
||||||
///# let vocab_path = Path::new("path/to/vocab.txt");
|
/// # let vocab_path = Path::new("path/to/vocab.txt");
|
||||||
///# let device = Device::Cpu;
|
/// # let device = Device::Cpu;
|
||||||
///# let vs = nn::VarStore::new(device);
|
/// # let vs = nn::VarStore::new(device);
|
||||||
///# let config = Gpt2Config::from_file(config_path);
|
/// # let config = Gpt2Config::from_file(config_path);
|
||||||
///# let gpt2_model: Gpt2Model = Gpt2Model::new(&vs.root(), &config);
|
/// # let gpt2_model: Gpt2Model = Gpt2Model::new(&vs.root(), &config);
|
||||||
/// let (batch_size, sequence_length, past_sequence_length) = (64, 128, 56);
|
/// let (batch_size, sequence_length, past_sequence_length) = (64, 128, 56);
|
||||||
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
|
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
|
||||||
/// let mut past: Vec<Tensor> = Vec::with_capacity(config.n_layer as usize);
|
/// let mut past: Vec<Tensor> = Vec::with_capacity(config.n_layer as usize);
|
||||||
/// for _ in 0..config.n_layer as usize {
|
/// for _ in 0..config.n_layer as usize {
|
||||||
/// past.push(Tensor::rand(&[2, batch_size, config.n_head, past_sequence_length, config.n_embd / config.n_head], (Double, device)))
|
/// past.push(Tensor::rand(
|
||||||
|
/// &[
|
||||||
|
/// 2,
|
||||||
|
/// batch_size,
|
||||||
|
/// config.n_head,
|
||||||
|
/// past_sequence_length,
|
||||||
|
/// config.n_embd / config.n_head,
|
||||||
|
/// ],
|
||||||
|
/// (Double, device),
|
||||||
|
/// ))
|
||||||
/// }
|
/// }
|
||||||
/// let attention_mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
/// let attention_mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
||||||
/// let token_type_ids = Tensor::ones(&[batch_size, sequence_length], (Int64, device));
|
/// let token_type_ids = Tensor::ones(&[batch_size, sequence_length], (Int64, device));
|
||||||
/// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true);
|
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
|
||||||
///
|
/// .expand(&[batch_size, sequence_length], true);
|
||||||
/// let (output, past, hidden_states, attentions) = no_grad(|| {
|
|
||||||
/// gpt2_model
|
|
||||||
/// .forward_t(&Some(input_tensor),
|
|
||||||
/// &Some(past),
|
|
||||||
/// &Some(attention_mask),
|
|
||||||
/// &Some(token_type_ids),
|
|
||||||
/// &Some(position_ids),
|
|
||||||
/// &None,
|
|
||||||
/// false).unwrap()
|
|
||||||
/// });
|
|
||||||
///
|
///
|
||||||
|
/// let (output, past, hidden_states, attentions) = no_grad(|| {
|
||||||
|
/// gpt2_model
|
||||||
|
/// .forward_t(
|
||||||
|
/// &Some(input_tensor),
|
||||||
|
/// &Some(past),
|
||||||
|
/// &Some(attention_mask),
|
||||||
|
/// &Some(token_type_ids),
|
||||||
|
/// &Some(position_ids),
|
||||||
|
/// &None,
|
||||||
|
/// false,
|
||||||
|
/// )
|
||||||
|
/// .unwrap()
|
||||||
|
/// });
|
||||||
/// ```
|
/// ```
|
||||||
///
|
pub fn forward_t(
|
||||||
pub fn forward_t(&self,
|
&self,
|
||||||
input_ids: &Option<Tensor>,
|
input_ids: &Option<Tensor>,
|
||||||
layer_past: &Option<Vec<Tensor>>,
|
layer_past: &Option<Vec<Tensor>>,
|
||||||
attention_mask: &Option<Tensor>,
|
attention_mask: &Option<Tensor>,
|
||||||
token_type_ids: &Option<Tensor>,
|
token_type_ids: &Option<Tensor>,
|
||||||
position_ids: &Option<Tensor>,
|
position_ids: &Option<Tensor>,
|
||||||
input_embeds: &Option<Tensor>,
|
input_embeds: &Option<Tensor>,
|
||||||
train: bool) -> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
|
train: bool,
|
||||||
|
) -> Result<
|
||||||
|
(
|
||||||
|
Tensor,
|
||||||
|
Option<Vec<Tensor>>,
|
||||||
|
Option<Vec<Tensor>>,
|
||||||
|
Option<Vec<Tensor>>,
|
||||||
|
),
|
||||||
|
&'static str,
|
||||||
|
> {
|
||||||
let (input_embeddings, seq_length) = match input_ids {
|
let (input_embeddings, seq_length) = match input_ids {
|
||||||
Some(input_value) => match input_embeds {
|
Some(input_value) => match input_embeds {
|
||||||
Some(_) => { return Err("Only one of input ids or input embeddings may be set"); }
|
Some(_) => {
|
||||||
None => (input_value.apply(&self.wte), *input_value.size().last().unwrap())
|
return Err("Only one of input ids or input embeddings may be set");
|
||||||
}
|
}
|
||||||
|
None => (
|
||||||
|
input_value.apply(&self.wte),
|
||||||
|
*input_value.size().last().unwrap(),
|
||||||
|
),
|
||||||
|
},
|
||||||
None => match input_embeds {
|
None => match input_embeds {
|
||||||
Some(embeds) => (embeds.copy(), embeds.size()[1]),
|
Some(embeds) => (embeds.copy(), embeds.size()[1]),
|
||||||
None => { return Err("At least one of input ids or input embeddings must be set"); }
|
None => {
|
||||||
}
|
return Err("At least one of input ids or input embeddings must be set");
|
||||||
|
}
|
||||||
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
let (layer_past, layer_past_length) = match layer_past {
|
let (layer_past, layer_past_length) = match layer_past {
|
||||||
Some(value) => {
|
Some(value) => {
|
||||||
assert_eq!(value.len(), self.h.len(), "Past activations vector must be of length equal to the number of layers");
|
assert_eq!(
|
||||||
(value.iter().map(|v| Some(v.copy())).collect::<Vec<Option<Tensor>>>(), value[0].size()[3])
|
value.len(),
|
||||||
|
self.h.len(),
|
||||||
|
"Past activations vector must be of length equal to the number of layers"
|
||||||
|
);
|
||||||
|
(
|
||||||
|
value
|
||||||
|
.iter()
|
||||||
|
.map(|v| Some(v.copy()))
|
||||||
|
.collect::<Vec<Option<Tensor>>>(),
|
||||||
|
value[0].size()[3],
|
||||||
|
)
|
||||||
}
|
}
|
||||||
None => {
|
None => {
|
||||||
let mut out = Vec::with_capacity(self.h.len());
|
let mut out = Vec::with_capacity(self.h.len());
|
||||||
@ -289,31 +408,45 @@ impl Gpt2Model {
|
|||||||
|
|
||||||
let position_ids = match position_ids {
|
let position_ids = match position_ids {
|
||||||
Some(value) => value.copy(),
|
Some(value) => value.copy(),
|
||||||
None => Tensor::arange1(layer_past_length, seq_length + layer_past_length, (Int64, input_embeddings.device())).unsqueeze(0)
|
None => Tensor::arange1(
|
||||||
|
layer_past_length,
|
||||||
|
seq_length + layer_past_length,
|
||||||
|
(Int64, input_embeddings.device()),
|
||||||
|
)
|
||||||
|
.unsqueeze(0),
|
||||||
};
|
};
|
||||||
|
|
||||||
let attention_mask: Option<Tensor> = match attention_mask {
|
let attention_mask: Option<Tensor> = match attention_mask {
|
||||||
Some(value) => {
|
Some(value) => Some(
|
||||||
Some(
|
(value
|
||||||
(value
|
.view((input_embeddings.size()[0], -1))
|
||||||
.view((input_embeddings.size()[0], -1))
|
.unsqueeze(1)
|
||||||
.unsqueeze(1)
|
.unsqueeze(2)
|
||||||
.unsqueeze(2)
|
- 1.0)
|
||||||
- 1.0
|
* 10000.0,
|
||||||
) * 10000.0)
|
),
|
||||||
}
|
None => None,
|
||||||
None => None
|
|
||||||
};
|
};
|
||||||
|
|
||||||
let position_embeds = position_ids.apply(&self.wpe);
|
let position_embeds = position_ids.apply(&self.wpe);
|
||||||
let token_type_embeds = match token_type_ids {
|
let token_type_embeds = match token_type_ids {
|
||||||
Some(value) => value.apply(&self.wte),
|
Some(value) => value.apply(&self.wte),
|
||||||
None => Tensor::zeros_like(&position_embeds)
|
None => Tensor::zeros_like(&position_embeds),
|
||||||
|
};
|
||||||
|
let mut hidden_state: Tensor =
|
||||||
|
(input_embeddings + position_embeds + token_type_embeds).apply_t(&self.drop, train);
|
||||||
|
let mut all_presents: Option<Vec<Tensor>> =
|
||||||
|
if self.output_past { Some(vec![]) } else { None };
|
||||||
|
let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states {
|
||||||
|
Some(vec![])
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
let mut all_attentions: Option<Vec<Tensor>> = if self.output_attentions {
|
||||||
|
Some(vec![])
|
||||||
|
} else {
|
||||||
|
None
|
||||||
};
|
};
|
||||||
let mut hidden_state: Tensor = (input_embeddings + position_embeds + token_type_embeds).apply_t(&self.drop, train);
|
|
||||||
let mut all_presents: Option<Vec<Tensor>> = if self.output_past { Some(vec!()) } else { None };
|
|
||||||
let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states { Some(vec!()) } else { None };
|
|
||||||
let mut all_attentions: Option<Vec<Tensor>> = if self.output_attentions { Some(vec!()) } else { None };
|
|
||||||
|
|
||||||
let mut layer_iter = self.h.iter().zip(layer_past);
|
let mut layer_iter = self.h.iter().zip(layer_past);
|
||||||
loop {
|
loop {
|
||||||
@ -333,11 +466,16 @@ impl Gpt2Model {
|
|||||||
attentions.push(temp.2.as_ref().unwrap().copy());
|
attentions.push(temp.2.as_ref().unwrap().copy());
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
None => break
|
None => break,
|
||||||
};
|
};
|
||||||
};
|
}
|
||||||
|
|
||||||
Ok((hidden_state.apply(&self.ln_f), all_presents, all_hidden_states, all_attentions))
|
Ok((
|
||||||
|
hidden_state.apply(&self.ln_f),
|
||||||
|
all_presents,
|
||||||
|
all_hidden_states,
|
||||||
|
all_attentions,
|
||||||
|
))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -362,10 +500,10 @@ impl GPT2LMHeadModel {
|
|||||||
/// # Example
|
/// # Example
|
||||||
///
|
///
|
||||||
/// ```no_run
|
/// ```no_run
|
||||||
/// use tch::{nn, Device};
|
/// use rust_bert::gpt2::{GPT2LMHeadModel, Gpt2Config};
|
||||||
/// use rust_bert::Config;
|
/// use rust_bert::Config;
|
||||||
/// use std::path::Path;
|
/// use std::path::Path;
|
||||||
/// use rust_bert::gpt2::{Gpt2Config, GPT2LMHeadModel};
|
/// use tch::{nn, Device};
|
||||||
///
|
///
|
||||||
/// let config_path = Path::new("path/to/config.json");
|
/// let config_path = Path::new("path/to/config.json");
|
||||||
/// let device = Device::Cpu;
|
/// let device = Device::Cpu;
|
||||||
@ -373,11 +511,18 @@ impl GPT2LMHeadModel {
|
|||||||
/// let config = Gpt2Config::from_file(config_path);
|
/// let config = Gpt2Config::from_file(config_path);
|
||||||
/// let gpt2: GPT2LMHeadModel = GPT2LMHeadModel::new(&(&p.root() / "gpt2"), &config);
|
/// let gpt2: GPT2LMHeadModel = GPT2LMHeadModel::new(&(&p.root() / "gpt2"), &config);
|
||||||
/// ```
|
/// ```
|
||||||
///
|
|
||||||
pub fn new(p: &nn::Path, config: &Gpt2Config) -> GPT2LMHeadModel {
|
pub fn new(p: &nn::Path, config: &Gpt2Config) -> GPT2LMHeadModel {
|
||||||
let transformer = Gpt2Model::new(&p, config);
|
let transformer = Gpt2Model::new(&p, config);
|
||||||
let lm_head = linear_no_bias(&(p / "lm_head"), config.n_embd, config.vocab_size, Default::default());
|
let lm_head = linear_no_bias(
|
||||||
GPT2LMHeadModel { transformer, lm_head }
|
&(p / "lm_head"),
|
||||||
|
config.n_embd,
|
||||||
|
config.vocab_size,
|
||||||
|
Default::default(),
|
||||||
|
);
|
||||||
|
GPT2LMHeadModel {
|
||||||
|
transformer,
|
||||||
|
lm_head,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -408,75 +553,104 @@ impl LMHeadModel for GPT2LMHeadModel {
|
|||||||
/// # Example
|
/// # Example
|
||||||
///
|
///
|
||||||
/// ```no_run
|
/// ```no_run
|
||||||
///# use tch::{nn, Device, Tensor, no_grad};
|
/// # use tch::{nn, Device, Tensor, no_grad};
|
||||||
///# use rust_bert::Config;
|
/// # use rust_bert::Config;
|
||||||
///# use std::path::Path;
|
/// # use std::path::Path;
|
||||||
///# use tch::kind::Kind::{Int64, Double};
|
/// # use tch::kind::Kind::{Int64, Double};
|
||||||
/// use rust_bert::gpt2::{Gpt2Config, GPT2LMHeadModel};
|
/// use rust_bert::gpt2::{GPT2LMHeadModel, Gpt2Config};
|
||||||
/// use rust_bert::pipelines::generation::{LMHeadModel, Cache};
|
/// use rust_bert::pipelines::generation::{Cache, LMHeadModel};
|
||||||
///# let config_path = Path::new("path/to/config.json");
|
/// # let config_path = Path::new("path/to/config.json");
|
||||||
///# let vocab_path = Path::new("path/to/vocab.txt");
|
/// # let vocab_path = Path::new("path/to/vocab.txt");
|
||||||
///# let device = Device::Cpu;
|
/// # let device = Device::Cpu;
|
||||||
///# let vs = nn::VarStore::new(device);
|
/// # let vs = nn::VarStore::new(device);
|
||||||
///# let config = Gpt2Config::from_file(config_path);
|
/// # let config = Gpt2Config::from_file(config_path);
|
||||||
///# let mut gpt2_model: GPT2LMHeadModel = GPT2LMHeadModel::new(&vs.root(), &config);
|
/// # let mut gpt2_model: GPT2LMHeadModel = GPT2LMHeadModel::new(&vs.root(), &config);
|
||||||
/// let (batch_size, sequence_length, past_sequence_length) = (64, 128, 56);
|
/// let (batch_size, sequence_length, past_sequence_length) = (64, 128, 56);
|
||||||
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
|
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
|
||||||
/// let mut past: Vec<Tensor> = Vec::with_capacity(config.n_layer as usize);
|
/// let mut past: Vec<Tensor> = Vec::with_capacity(config.n_layer as usize);
|
||||||
/// for _ in 0..config.n_layer as usize {
|
/// for _ in 0..config.n_layer as usize {
|
||||||
/// past.push(Tensor::rand(&[2, batch_size, config.n_head, past_sequence_length, config.n_embd / config.n_head], (Double, device)))
|
/// past.push(Tensor::rand(
|
||||||
|
/// &[
|
||||||
|
/// 2,
|
||||||
|
/// batch_size,
|
||||||
|
/// config.n_head,
|
||||||
|
/// past_sequence_length,
|
||||||
|
/// config.n_embd / config.n_head,
|
||||||
|
/// ],
|
||||||
|
/// (Double, device),
|
||||||
|
/// ))
|
||||||
/// }
|
/// }
|
||||||
/// let attention_mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
/// let attention_mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
||||||
/// let token_type_ids = Tensor::ones(&[batch_size, sequence_length], (Int64, device));
|
/// let token_type_ids = Tensor::ones(&[batch_size, sequence_length], (Int64, device));
|
||||||
/// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true);
|
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
|
||||||
///
|
/// .expand(&[batch_size, sequence_length], true);
|
||||||
/// let (output, _, past, hidden_states, attentions) = no_grad(|| {
|
|
||||||
/// gpt2_model
|
|
||||||
/// .forward_t(&Some(input_tensor),
|
|
||||||
/// Cache::GPT2Cache(Some(past)),
|
|
||||||
/// &Some(attention_mask),
|
|
||||||
/// &Some(token_type_ids),
|
|
||||||
/// &Some(position_ids),
|
|
||||||
/// &None,
|
|
||||||
/// None,
|
|
||||||
/// &None,
|
|
||||||
/// false).unwrap()
|
|
||||||
/// });
|
|
||||||
///
|
///
|
||||||
|
/// let (output, _, past, hidden_states, attentions) = no_grad(|| {
|
||||||
|
/// gpt2_model
|
||||||
|
/// .forward_t(
|
||||||
|
/// &Some(input_tensor),
|
||||||
|
/// Cache::GPT2Cache(Some(past)),
|
||||||
|
/// &Some(attention_mask),
|
||||||
|
/// &Some(token_type_ids),
|
||||||
|
/// &Some(position_ids),
|
||||||
|
/// &None,
|
||||||
|
/// None,
|
||||||
|
/// &None,
|
||||||
|
/// false,
|
||||||
|
/// )
|
||||||
|
/// .unwrap()
|
||||||
|
/// });
|
||||||
/// ```
|
/// ```
|
||||||
///
|
fn forward_t(
|
||||||
fn forward_t(&self,
|
&self,
|
||||||
input_ids: &Option<Tensor>,
|
input_ids: &Option<Tensor>,
|
||||||
layer_past: Cache,
|
layer_past: Cache,
|
||||||
attention_mask: &Option<Tensor>,
|
attention_mask: &Option<Tensor>,
|
||||||
token_type_ids: &Option<Tensor>,
|
token_type_ids: &Option<Tensor>,
|
||||||
position_ids: &Option<Tensor>,
|
position_ids: &Option<Tensor>,
|
||||||
input_embeds: &Option<Tensor>,
|
input_embeds: &Option<Tensor>,
|
||||||
_encoder_outputs: Option<&Tensor>,
|
_encoder_outputs: Option<&Tensor>,
|
||||||
_decoder_input_ids: &Option<Tensor>,
|
_decoder_input_ids: &Option<Tensor>,
|
||||||
train: bool) -> Result<(Tensor, Option<Tensor>, Cache, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
|
train: bool,
|
||||||
let (output,
|
) -> Result<
|
||||||
past,
|
(
|
||||||
all_hidden_states,
|
Tensor,
|
||||||
all_attentions) = match layer_past {
|
Option<Tensor>,
|
||||||
Cache::GPT2Cache(layer_past) => Ok(self.transformer.forward_t(input_ids,
|
Cache,
|
||||||
&layer_past,
|
Option<Vec<Tensor>>,
|
||||||
attention_mask,
|
Option<Vec<Tensor>>,
|
||||||
token_type_ids,
|
),
|
||||||
position_ids,
|
&'static str,
|
||||||
input_embeds,
|
> {
|
||||||
train)?),
|
let (output, past, all_hidden_states, all_attentions) = match layer_past {
|
||||||
Cache::None => Ok(self.transformer.forward_t(input_ids,
|
Cache::GPT2Cache(layer_past) => Ok(self.transformer.forward_t(
|
||||||
&None,
|
input_ids,
|
||||||
attention_mask,
|
&layer_past,
|
||||||
token_type_ids,
|
attention_mask,
|
||||||
position_ids,
|
token_type_ids,
|
||||||
input_embeds,
|
position_ids,
|
||||||
train)?),
|
input_embeds,
|
||||||
_ => Err("Cache not compatible with GPT2 model")
|
train,
|
||||||
|
)?),
|
||||||
|
Cache::None => Ok(self.transformer.forward_t(
|
||||||
|
input_ids,
|
||||||
|
&None,
|
||||||
|
attention_mask,
|
||||||
|
token_type_ids,
|
||||||
|
position_ids,
|
||||||
|
input_embeds,
|
||||||
|
train,
|
||||||
|
)?),
|
||||||
|
_ => Err("Cache not compatible with GPT2 model"),
|
||||||
}?;
|
}?;
|
||||||
|
|
||||||
let lm_logits = output.apply(&self.lm_head);
|
let lm_logits = output.apply(&self.lm_head);
|
||||||
Ok((lm_logits, None, Cache::GPT2Cache(past), all_hidden_states, all_attentions))
|
Ok((
|
||||||
|
lm_logits,
|
||||||
|
None,
|
||||||
|
Cache::GPT2Cache(past),
|
||||||
|
all_hidden_states,
|
||||||
|
all_attentions,
|
||||||
|
))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -14,19 +14,27 @@
|
|||||||
//! Pretrained models are available and can be downloaded using RemoteResources.
|
//! Pretrained models are available and can be downloaded using RemoteResources.
|
||||||
//!
|
//!
|
||||||
//! ```no_run
|
//! ```no_run
|
||||||
//!# fn main() -> failure::Fallible<()> {
|
//! # fn main() -> failure::Fallible<()> {
|
||||||
//!#
|
//! #
|
||||||
//! use rust_tokenizers::Gpt2Tokenizer;
|
//! use rust_tokenizers::Gpt2Tokenizer;
|
||||||
//! use tch::{nn, Device};
|
//! use tch::{nn, Device};
|
||||||
//!# use std::path::PathBuf;
|
//! # use std::path::PathBuf;
|
||||||
|
//! use rust_bert::gpt2::{GPT2LMHeadModel, Gpt2Config};
|
||||||
|
//! use rust_bert::resources::{download_resource, LocalResource, Resource};
|
||||||
//! use rust_bert::Config;
|
//! use rust_bert::Config;
|
||||||
//! use rust_bert::gpt2::{Gpt2Config, GPT2LMHeadModel};
|
|
||||||
//! use rust_bert::resources::{Resource, download_resource, LocalResource};
|
|
||||||
//!
|
//!
|
||||||
//! let config_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/config.json")});
|
//! let config_resource = Resource::Local(LocalResource {
|
||||||
//! let vocab_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/vocab.txt")});
|
//! local_path: PathBuf::from("path/to/config.json"),
|
||||||
//! let merges_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/vocab.txt")});
|
//! });
|
||||||
//! let weights_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/model.ot")});
|
//! let vocab_resource = Resource::Local(LocalResource {
|
||||||
|
//! local_path: PathBuf::from("path/to/vocab.txt"),
|
||||||
|
//! });
|
||||||
|
//! let merges_resource = Resource::Local(LocalResource {
|
||||||
|
//! local_path: PathBuf::from("path/to/vocab.txt"),
|
||||||
|
//! });
|
||||||
|
//! let weights_resource = Resource::Local(LocalResource {
|
||||||
|
//! local_path: PathBuf::from("path/to/model.ot"),
|
||||||
|
//! });
|
||||||
//! let config_path = download_resource(&config_resource)?;
|
//! let config_path = download_resource(&config_resource)?;
|
||||||
//! let vocab_path = download_resource(&vocab_resource)?;
|
//! let vocab_path = download_resource(&vocab_resource)?;
|
||||||
//! let merges_path = download_resource(&merges_resource)?;
|
//! let merges_path = download_resource(&merges_resource)?;
|
||||||
@ -34,18 +42,24 @@
|
|||||||
//!
|
//!
|
||||||
//! let device = Device::cuda_if_available();
|
//! let device = Device::cuda_if_available();
|
||||||
//! let mut vs = nn::VarStore::new(device);
|
//! let mut vs = nn::VarStore::new(device);
|
||||||
//! let tokenizer: Gpt2Tokenizer = Gpt2Tokenizer::from_file(vocab_path.to_str().unwrap(), merges_path.to_str().unwrap(), true);
|
//! let tokenizer: Gpt2Tokenizer = Gpt2Tokenizer::from_file(
|
||||||
|
//! vocab_path.to_str().unwrap(),
|
||||||
|
//! merges_path.to_str().unwrap(),
|
||||||
|
//! true,
|
||||||
|
//! );
|
||||||
//! let config = Gpt2Config::from_file(config_path);
|
//! let config = Gpt2Config::from_file(config_path);
|
||||||
//! let gpt2_model = GPT2LMHeadModel::new(&vs.root(), &config);
|
//! let gpt2_model = GPT2LMHeadModel::new(&vs.root(), &config);
|
||||||
//! vs.load(weights_path)?;
|
//! vs.load(weights_path)?;
|
||||||
//!
|
//!
|
||||||
//!# Ok(())
|
//! # Ok(())
|
||||||
//!# }
|
//! # }
|
||||||
//! ```
|
//! ```
|
||||||
|
|
||||||
mod gpt2;
|
|
||||||
pub(crate) mod attention;
|
pub(crate) mod attention;
|
||||||
|
mod gpt2;
|
||||||
pub(crate) mod transformer;
|
pub(crate) mod transformer;
|
||||||
|
|
||||||
pub use gpt2::{Gpt2ModelResources, Gpt2ConfigResources, Gpt2VocabResources, Gpt2MergesResources,
|
pub use gpt2::{
|
||||||
Gpt2Config, Gpt2Model, GptActivation, GPT2LMHeadModel};
|
GPT2LMHeadModel, Gpt2Config, Gpt2ConfigResources, Gpt2MergesResources, Gpt2Model,
|
||||||
|
Gpt2ModelResources, Gpt2VocabResources, GptActivation,
|
||||||
|
};
|
||||||
|
@ -12,11 +12,11 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
use crate::gpt2::attention::{GPTConv1D, Attention};
|
|
||||||
use tch::{Tensor, nn};
|
|
||||||
use crate::common::dropout::Dropout;
|
|
||||||
use crate::gpt2::gpt2::{Gpt2Config, GptActivation};
|
|
||||||
use crate::common::activations::{_gelu_new, _relu, _swish};
|
use crate::common::activations::{_gelu_new, _relu, _swish};
|
||||||
|
use crate::common::dropout::Dropout;
|
||||||
|
use crate::gpt2::attention::{Attention, GPTConv1D};
|
||||||
|
use crate::gpt2::gpt2::{Gpt2Config, GptActivation};
|
||||||
|
use tch::{nn, Tensor};
|
||||||
|
|
||||||
pub struct MLP {
|
pub struct MLP {
|
||||||
c_fc: GPTConv1D,
|
c_fc: GPTConv1D,
|
||||||
@ -35,14 +35,19 @@ impl MLP {
|
|||||||
GptActivation::relu => _relu,
|
GptActivation::relu => _relu,
|
||||||
GptActivation::swish => _swish,
|
GptActivation::swish => _swish,
|
||||||
},
|
},
|
||||||
None => _gelu_new
|
None => _gelu_new,
|
||||||
});
|
});
|
||||||
let resid_pdrop = match config.resid_pdrop {
|
let resid_pdrop = match config.resid_pdrop {
|
||||||
Some(value) => value,
|
Some(value) => value,
|
||||||
None => 0.1
|
None => 0.1,
|
||||||
};
|
};
|
||||||
let dropout = Dropout::new(resid_pdrop);
|
let dropout = Dropout::new(resid_pdrop);
|
||||||
MLP { c_fc, c_proj, activation, dropout }
|
MLP {
|
||||||
|
c_fc,
|
||||||
|
c_proj,
|
||||||
|
activation,
|
||||||
|
dropout,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn forward_t(&self, x: &Tensor, train: bool) -> Tensor {
|
pub fn forward_t(&self, x: &Tensor, train: bool) -> Tensor {
|
||||||
@ -60,21 +65,36 @@ pub struct Block {
|
|||||||
|
|
||||||
impl Block {
|
impl Block {
|
||||||
pub fn new(p: &nn::Path, config: &Gpt2Config, scale: bool) -> Block {
|
pub fn new(p: &nn::Path, config: &Gpt2Config, scale: bool) -> Block {
|
||||||
let layer_norm_config = nn::LayerNormConfig { eps: config.layer_norm_epsilon, ..Default::default() };
|
let layer_norm_config = nn::LayerNormConfig {
|
||||||
|
eps: config.layer_norm_epsilon,
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
let ln_1 = nn::layer_norm(p / "ln_1", vec![config.n_embd], layer_norm_config);
|
let ln_1 = nn::layer_norm(p / "ln_1", vec![config.n_embd], layer_norm_config);
|
||||||
let ln_2 = nn::layer_norm(p / "ln_2", vec![config.n_embd], layer_norm_config);
|
let ln_2 = nn::layer_norm(p / "ln_2", vec![config.n_embd], layer_norm_config);
|
||||||
let attn = Attention::new(&(p / "attn"), config, scale);
|
let attn = Attention::new(&(p / "attn"), config, scale);
|
||||||
let mlp = MLP::new(&(p / "mlp"), config);
|
let mlp = MLP::new(&(p / "mlp"), config);
|
||||||
|
|
||||||
Block { ln_1, attn, ln_2, mlp }
|
Block {
|
||||||
|
ln_1,
|
||||||
|
attn,
|
||||||
|
ln_2,
|
||||||
|
mlp,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn forward_t(&self, x: &Tensor, layer_past: &Option<Tensor>, attention_mask: &Option<Tensor>, train: bool)
|
pub fn forward_t(
|
||||||
-> (Tensor, Tensor, Option<Tensor>) {
|
&self,
|
||||||
let (output, present, attentions) = self.attn.forward_t(&x.apply(&self.ln_1), layer_past, attention_mask, train);
|
x: &Tensor,
|
||||||
|
layer_past: &Option<Tensor>,
|
||||||
|
attention_mask: &Option<Tensor>,
|
||||||
|
train: bool,
|
||||||
|
) -> (Tensor, Tensor, Option<Tensor>) {
|
||||||
|
let (output, present, attentions) =
|
||||||
|
self.attn
|
||||||
|
.forward_t(&x.apply(&self.ln_1), layer_past, attention_mask, train);
|
||||||
let x = x + output;
|
let x = x + output;
|
||||||
let m = self.mlp.forward_t(&x.apply(&self.ln_2), train);
|
let m = self.mlp.forward_t(&x.apply(&self.ln_2), train);
|
||||||
let x = x + m;
|
let x = x + m;
|
||||||
(x, present, attentions)
|
(x, present, attentions)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
25
src/lib.rs
25
src/lib.rs
@ -16,14 +16,14 @@
|
|||||||
//!
|
//!
|
||||||
//! More information on these can be found in the [`pipelines` module](./pipelines/index.html)
|
//! More information on these can be found in the [`pipelines` module](./pipelines/index.html)
|
||||||
//! ```no_run
|
//! ```no_run
|
||||||
//! use rust_bert::pipelines::question_answering::{QuestionAnsweringModel, QaInput};
|
//! use rust_bert::pipelines::question_answering::{QaInput, QuestionAnsweringModel};
|
||||||
//!
|
//!
|
||||||
//!# fn main() -> failure::Fallible<()> {
|
//! # fn main() -> failure::Fallible<()> {
|
||||||
//! let qa_model = QuestionAnsweringModel::new(Default::default())?;
|
//! let qa_model = QuestionAnsweringModel::new(Default::default())?;
|
||||||
//!
|
//!
|
||||||
//! let question = String::from("Where does Amy live ?");
|
//! let question = String::from("Where does Amy live ?");
|
||||||
//! let context = String::from("Amy lives in Amsterdam");
|
//! let context = String::from("Amy lives in Amsterdam");
|
||||||
//! let answers = qa_model.predict(&vec!(QaInput { question, context }), 1, 32);
|
//! let answers = qa_model.predict(&vec![QaInput { question, context }], 1, 32);
|
||||||
//! # Ok(())
|
//! # Ok(())
|
||||||
//! # }
|
//! # }
|
||||||
//! ```
|
//! ```
|
||||||
@ -55,19 +55,18 @@
|
|||||||
//! - Set-up a virtual environment and install dependencies
|
//! - Set-up a virtual environment and install dependencies
|
||||||
//! - run the conversion script python /utils/download-dependencies_{MODEL_TO_DOWNLOAD}.py. The dependencies will be downloaded to the user's home directory, under ~/rustbert/{}
|
//! - run the conversion script python /utils/download-dependencies_{MODEL_TO_DOWNLOAD}.py. The dependencies will be downloaded to the user's home directory, under ~/rustbert/{}
|
||||||
//! 3. Run the example cargo run --release
|
//! 3. Run the example cargo run --release
|
||||||
//!
|
|
||||||
|
|
||||||
pub mod distilbert;
|
|
||||||
pub mod bert;
|
|
||||||
pub mod roberta;
|
|
||||||
pub mod openai_gpt;
|
|
||||||
pub mod gpt2;
|
|
||||||
pub mod bart;
|
|
||||||
pub mod electra;
|
|
||||||
pub mod marian;
|
|
||||||
pub mod albert;
|
pub mod albert;
|
||||||
|
pub mod bart;
|
||||||
|
pub mod bert;
|
||||||
mod common;
|
mod common;
|
||||||
|
pub mod distilbert;
|
||||||
|
pub mod electra;
|
||||||
|
pub mod gpt2;
|
||||||
|
pub mod marian;
|
||||||
|
pub mod openai_gpt;
|
||||||
pub mod pipelines;
|
pub mod pipelines;
|
||||||
|
pub mod roberta;
|
||||||
|
|
||||||
pub use common::Config;
|
|
||||||
pub use common::resources;
|
pub use common::resources;
|
||||||
|
pub use common::Config;
|
||||||
|
@ -11,10 +11,10 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
use crate::bart::{BartModel, BartConfig, LayerState};
|
use crate::bart::{BartConfig, BartModel, LayerState};
|
||||||
use tch::{Tensor, nn};
|
use crate::pipelines::generation::{Cache, LMHeadModel};
|
||||||
use crate::pipelines::generation::{LMHeadModel, Cache};
|
|
||||||
use tch::nn::Init;
|
use tch::nn::Init;
|
||||||
|
use tch::{nn, Tensor};
|
||||||
|
|
||||||
/// # Marian Pretrained model weight files
|
/// # Marian Pretrained model weight files
|
||||||
pub struct MarianModelResources;
|
pub struct MarianModelResources;
|
||||||
@ -33,78 +33,174 @@ pub struct MarianPrefix;
|
|||||||
|
|
||||||
impl MarianModelResources {
|
impl MarianModelResources {
|
||||||
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format.
|
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format.
|
||||||
pub const ENGLISH2ROMANCE: (&'static str, &'static str) = ("marian-mt-en-ROMANCE/model.ot", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-ROMANCE/rust_model.ot");
|
pub const ENGLISH2ROMANCE: (&'static str, &'static str) = (
|
||||||
|
"marian-mt-en-ROMANCE/model.ot",
|
||||||
|
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-ROMANCE/rust_model.ot",
|
||||||
|
);
|
||||||
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format.
|
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format.
|
||||||
pub const ROMANCE2ENGLISH: (&'static str, &'static str) = ("marian-mt-ROMANCE-en/model.ot", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-ROMANCE-en/rust_model.ot");
|
pub const ROMANCE2ENGLISH: (&'static str, &'static str) = (
|
||||||
|
"marian-mt-ROMANCE-en/model.ot",
|
||||||
|
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-ROMANCE-en/rust_model.ot",
|
||||||
|
);
|
||||||
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format.
|
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format.
|
||||||
pub const ENGLISH2GERMAN: (&'static str, &'static str) = ("marian-mt-en-de/model.ot", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-de/rust_model.ot");
|
pub const ENGLISH2GERMAN: (&'static str, &'static str) = (
|
||||||
|
"marian-mt-en-de/model.ot",
|
||||||
|
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-de/rust_model.ot",
|
||||||
|
);
|
||||||
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format.
|
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format.
|
||||||
pub const GERMAN2ENGLISH: (&'static str, &'static str) = ("marian-mt-de-en/model.ot", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-de-en/rust_model.ot");
|
pub const GERMAN2ENGLISH: (&'static str, &'static str) = (
|
||||||
|
"marian-mt-de-en/model.ot",
|
||||||
|
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-de-en/rust_model.ot",
|
||||||
|
);
|
||||||
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format.
|
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format.
|
||||||
pub const ENGLISH2RUSSIAN: (&'static str, &'static str) = ("marian-mt-en-ru/model.ot", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-ru/rust_model.ot");
|
pub const ENGLISH2RUSSIAN: (&'static str, &'static str) = (
|
||||||
|
"marian-mt-en-ru/model.ot",
|
||||||
|
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-ru/rust_model.ot",
|
||||||
|
);
|
||||||
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format.
|
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format.
|
||||||
pub const RUSSIAN2ENGLISH: (&'static str, &'static str) = ("marian-mt-ru-en/model.ot", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-ru-en/rust_model.ot");
|
pub const RUSSIAN2ENGLISH: (&'static str, &'static str) = (
|
||||||
|
"marian-mt-ru-en/model.ot",
|
||||||
|
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-ru-en/rust_model.ot",
|
||||||
|
);
|
||||||
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format.
|
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format.
|
||||||
pub const FRENCH2GERMAN: (&'static str, &'static str) = ("marian-mt-fr-de/model.ot", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-fr-de/rust_model.ot");
|
pub const FRENCH2GERMAN: (&'static str, &'static str) = (
|
||||||
|
"marian-mt-fr-de/model.ot",
|
||||||
|
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-fr-de/rust_model.ot",
|
||||||
|
);
|
||||||
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format.
|
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT. Modified with conversion to C-array format.
|
||||||
pub const GERMAN2FRENCH: (&'static str, &'static str) = ("marian-mt-de-fr/model.ot", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-de-fr/rust_model.ot");
|
pub const GERMAN2FRENCH: (&'static str, &'static str) = (
|
||||||
|
"marian-mt-de-fr/model.ot",
|
||||||
|
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-de-fr/rust_model.ot",
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
impl MarianConfigResources {
|
impl MarianConfigResources {
|
||||||
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
|
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
|
||||||
pub const ENGLISH2ROMANCE: (&'static str, &'static str) = ("marian-mt-en-ROMANCE/config.json", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-ROMANCE/config.json");
|
pub const ENGLISH2ROMANCE: (&'static str, &'static str) = (
|
||||||
|
"marian-mt-en-ROMANCE/config.json",
|
||||||
|
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-ROMANCE/config.json",
|
||||||
|
);
|
||||||
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
|
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
|
||||||
pub const ROMANCE2ENGLISH: (&'static str, &'static str) = ("marian-mt-ROMANCE-en/config.json", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-ROMANCE-en/config.json");
|
pub const ROMANCE2ENGLISH: (&'static str, &'static str) = (
|
||||||
|
"marian-mt-ROMANCE-en/config.json",
|
||||||
|
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-ROMANCE-en/config.json",
|
||||||
|
);
|
||||||
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
|
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
|
||||||
pub const ENGLISH2GERMAN: (&'static str, &'static str) = ("marian-mt-en-de/config.json", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-de/config.json");
|
pub const ENGLISH2GERMAN: (&'static str, &'static str) = (
|
||||||
|
"marian-mt-en-de/config.json",
|
||||||
|
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-de/config.json",
|
||||||
|
);
|
||||||
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
|
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
|
||||||
pub const GERMAN2ENGLISH: (&'static str, &'static str) = ("marian-mt-de-en/config.json", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-de-en/config.json");
|
pub const GERMAN2ENGLISH: (&'static str, &'static str) = (
|
||||||
|
"marian-mt-de-en/config.json",
|
||||||
|
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-de-en/config.json",
|
||||||
|
);
|
||||||
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
|
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
|
||||||
pub const ENGLISH2RUSSIAN: (&'static str, &'static str) = ("marian-mt-en-ru/config.json", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-ru/config.json");
|
pub const ENGLISH2RUSSIAN: (&'static str, &'static str) = (
|
||||||
|
"marian-mt-en-ru/config.json",
|
||||||
|
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-ru/config.json",
|
||||||
|
);
|
||||||
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
|
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
|
||||||
pub const RUSSIAN2ENGLISH: (&'static str, &'static str) = ("marian-mt-ru-en/config.json", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-ru-en/config.json");
|
pub const RUSSIAN2ENGLISH: (&'static str, &'static str) = (
|
||||||
|
"marian-mt-ru-en/config.json",
|
||||||
|
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-ru-en/config.json",
|
||||||
|
);
|
||||||
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
|
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
|
||||||
pub const FRENCH2GERMAN: (&'static str, &'static str) = ("marian-mt-fr-de/config.json", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-fr-de/config.json");
|
pub const FRENCH2GERMAN: (&'static str, &'static str) = (
|
||||||
|
"marian-mt-fr-de/config.json",
|
||||||
|
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-fr-de/config.json",
|
||||||
|
);
|
||||||
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
|
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
|
||||||
pub const GERMAN2FRENCH: (&'static str, &'static str) = ("marian-mt-de-fr/config.json", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-de-fr/config.json");
|
pub const GERMAN2FRENCH: (&'static str, &'static str) = (
|
||||||
|
"marian-mt-de-fr/config.json",
|
||||||
|
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-de-fr/config.json",
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
impl MarianVocabResources {
|
impl MarianVocabResources {
|
||||||
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
|
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
|
||||||
pub const ENGLISH2ROMANCE: (&'static str, &'static str) = ("marian-mt-en-ROMANCE/vocab.json", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-ROMANCE/vocab.json");
|
pub const ENGLISH2ROMANCE: (&'static str, &'static str) = (
|
||||||
|
"marian-mt-en-ROMANCE/vocab.json",
|
||||||
|
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-ROMANCE/vocab.json",
|
||||||
|
);
|
||||||
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
|
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
|
||||||
pub const ROMANCE2ENGLISH: (&'static str, &'static str) = ("marian-mt-ROMANCE-en/vocab.json", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-ROMANCE-en/vocab.json");
|
pub const ROMANCE2ENGLISH: (&'static str, &'static str) = (
|
||||||
|
"marian-mt-ROMANCE-en/vocab.json",
|
||||||
|
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-ROMANCE-en/vocab.json",
|
||||||
|
);
|
||||||
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
|
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
|
||||||
pub const ENGLISH2GERMAN: (&'static str, &'static str) = ("marian-mt-en-de/vocab.json", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-de/vocab.json");
|
pub const ENGLISH2GERMAN: (&'static str, &'static str) = (
|
||||||
|
"marian-mt-en-de/vocab.json",
|
||||||
|
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-de/vocab.json",
|
||||||
|
);
|
||||||
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
|
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
|
||||||
pub const GERMAN2ENGLISH: (&'static str, &'static str) = ("marian-mt-de-en/vocab.json", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-de-en/vocab.json");
|
pub const GERMAN2ENGLISH: (&'static str, &'static str) = (
|
||||||
|
"marian-mt-de-en/vocab.json",
|
||||||
|
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-de-en/vocab.json",
|
||||||
|
);
|
||||||
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
|
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
|
||||||
pub const ENGLISH2RUSSIAN: (&'static str, &'static str) = ("marian-mt-en-ru/vocab.json", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-ru/vocab.json");
|
pub const ENGLISH2RUSSIAN: (&'static str, &'static str) = (
|
||||||
|
"marian-mt-en-ru/vocab.json",
|
||||||
|
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-ru/vocab.json",
|
||||||
|
);
|
||||||
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
|
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
|
||||||
pub const RUSSIAN2ENGLISH: (&'static str, &'static str) = ("marian-mt-ru-en/vocab.json", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-ru-en/vocab.json");
|
pub const RUSSIAN2ENGLISH: (&'static str, &'static str) = (
|
||||||
|
"marian-mt-ru-en/vocab.json",
|
||||||
|
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-ru-en/vocab.json",
|
||||||
|
);
|
||||||
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
|
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
|
||||||
pub const FRENCH2GERMAN: (&'static str, &'static str) = ("marian-mt-fr-de/vocab.json", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-fr-de/vocab.json");
|
pub const FRENCH2GERMAN: (&'static str, &'static str) = (
|
||||||
|
"marian-mt-fr-de/vocab.json",
|
||||||
|
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-fr-de/vocab.json",
|
||||||
|
);
|
||||||
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
|
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
|
||||||
pub const GERMAN2FRENCH: (&'static str, &'static str) = ("marian-mt-de-fr/vocab.json", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-de-fr/vocab.json");
|
pub const GERMAN2FRENCH: (&'static str, &'static str) = (
|
||||||
|
"marian-mt-de-fr/vocab.json",
|
||||||
|
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-de-fr/vocab.json",
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
impl MarianSpmResources {
|
impl MarianSpmResources {
|
||||||
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
|
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
|
||||||
pub const ENGLISH2ROMANCE: (&'static str, &'static str) = ("marian-mt-en-ROMANCE/spiece.model", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-ROMANCE/source.spm");
|
pub const ENGLISH2ROMANCE: (&'static str, &'static str) = (
|
||||||
|
"marian-mt-en-ROMANCE/spiece.model",
|
||||||
|
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-ROMANCE/source.spm",
|
||||||
|
);
|
||||||
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
|
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
|
||||||
pub const ROMANCE2ENGLISH: (&'static str, &'static str) = ("marian-mt-ROMANCE-en/spiece.model", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-ROMANCE-en/source.spm");
|
pub const ROMANCE2ENGLISH: (&'static str, &'static str) = (
|
||||||
|
"marian-mt-ROMANCE-en/spiece.model",
|
||||||
|
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-ROMANCE-en/source.spm",
|
||||||
|
);
|
||||||
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
|
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
|
||||||
pub const ENGLISH2GERMAN: (&'static str, &'static str) = ("marian-mt-en-de/spiece.model", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-de/source.spm");
|
pub const ENGLISH2GERMAN: (&'static str, &'static str) = (
|
||||||
|
"marian-mt-en-de/spiece.model",
|
||||||
|
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-de/source.spm",
|
||||||
|
);
|
||||||
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
|
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
|
||||||
pub const GERMAN2ENGLISH: (&'static str, &'static str) = ("marian-mt-de-en/spiece.model", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-de-en/source.spm");
|
pub const GERMAN2ENGLISH: (&'static str, &'static str) = (
|
||||||
|
"marian-mt-de-en/spiece.model",
|
||||||
|
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-de-en/source.spm",
|
||||||
|
);
|
||||||
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
|
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
|
||||||
pub const ENGLISH2RUSSIAN: (&'static str, &'static str) = ("marian-mt-en-ru/spiece.model", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-ru/source.spm");
|
pub const ENGLISH2RUSSIAN: (&'static str, &'static str) = (
|
||||||
|
"marian-mt-en-ru/spiece.model",
|
||||||
|
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-en-ru/source.spm",
|
||||||
|
);
|
||||||
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
|
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
|
||||||
pub const RUSSIAN2ENGLISH: (&'static str, &'static str) = ("marian-mt-ru-en/spiece.model", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-ru-en/source.spm");
|
pub const RUSSIAN2ENGLISH: (&'static str, &'static str) = (
|
||||||
|
"marian-mt-ru-en/spiece.model",
|
||||||
|
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-ru-en/source.spm",
|
||||||
|
);
|
||||||
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
|
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
|
||||||
pub const FRENCH2GERMAN: (&'static str, &'static str) = ("marian-mt-fr-de/spiece.model", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-fr-de/source.spm");
|
pub const FRENCH2GERMAN: (&'static str, &'static str) = (
|
||||||
|
"marian-mt-fr-de/spiece.model",
|
||||||
|
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-fr-de/source.spm",
|
||||||
|
);
|
||||||
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
|
/// Shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
|
||||||
pub const GERMAN2FRENCH: (&'static str, &'static str) = ("marian-mt-de-fr/spiece.model", "https://cdn.huggingface.co/Helsinki-NLP/opus-mt-de-fr/source.spm");
|
pub const GERMAN2FRENCH: (&'static str, &'static str) = (
|
||||||
|
"marian-mt-de-fr/spiece.model",
|
||||||
|
"https://cdn.huggingface.co/Helsinki-NLP/opus-mt-de-fr/source.spm",
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
impl MarianPrefix {
|
impl MarianPrefix {
|
||||||
@ -150,23 +246,34 @@ impl MarianForConditionalGeneration {
|
|||||||
/// # Example
|
/// # Example
|
||||||
///
|
///
|
||||||
/// ```no_run
|
/// ```no_run
|
||||||
/// use tch::{nn, Device};
|
/// use rust_bert::bart::{BartConfig, BartForConditionalGeneration};
|
||||||
/// use rust_bert::Config;
|
/// use rust_bert::Config;
|
||||||
/// use std::path::Path;
|
/// use std::path::Path;
|
||||||
/// use rust_bert::bart::{BartConfig, BartForConditionalGeneration};
|
/// use tch::{nn, Device};
|
||||||
///
|
///
|
||||||
/// let config_path = Path::new("path/to/config.json");
|
/// let config_path = Path::new("path/to/config.json");
|
||||||
/// let device = Device::Cpu;
|
/// let device = Device::Cpu;
|
||||||
/// let p = nn::VarStore::new(device);
|
/// let p = nn::VarStore::new(device);
|
||||||
/// let config = BartConfig::from_file(config_path);
|
/// let config = BartConfig::from_file(config_path);
|
||||||
/// let generation_mode = true;
|
/// let generation_mode = true;
|
||||||
/// let bart: BartForConditionalGeneration = BartForConditionalGeneration::new(&(&p.root() / "bart"), &config, generation_mode);
|
/// let bart: BartForConditionalGeneration =
|
||||||
|
/// BartForConditionalGeneration::new(&(&p.root() / "bart"), &config, generation_mode);
|
||||||
/// ```
|
/// ```
|
||||||
///
|
pub fn new(
|
||||||
pub fn new(p: &nn::Path, config: &BartConfig, generation_mode: bool) -> MarianForConditionalGeneration {
|
p: &nn::Path,
|
||||||
|
config: &BartConfig,
|
||||||
|
generation_mode: bool,
|
||||||
|
) -> MarianForConditionalGeneration {
|
||||||
let base_model = BartModel::new(&(p / "model"), config, generation_mode);
|
let base_model = BartModel::new(&(p / "model"), config, generation_mode);
|
||||||
let final_logits_bias = p.var("final_logits_bias", &[1, config.vocab_size], Init::Const(0.));
|
let final_logits_bias = p.var(
|
||||||
MarianForConditionalGeneration { base_model, final_logits_bias }
|
"final_logits_bias",
|
||||||
|
&[1, config.vocab_size],
|
||||||
|
Init::Const(0.),
|
||||||
|
);
|
||||||
|
MarianForConditionalGeneration {
|
||||||
|
base_model,
|
||||||
|
final_logits_bias,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Forward pass through the model
|
/// Forward pass through the model
|
||||||
@ -193,64 +300,101 @@ impl MarianForConditionalGeneration {
|
|||||||
/// # Example
|
/// # Example
|
||||||
///
|
///
|
||||||
/// ```no_run
|
/// ```no_run
|
||||||
///# use tch::{nn, Device, Tensor, no_grad};
|
/// # use tch::{nn, Device, Tensor, no_grad};
|
||||||
///# use rust_bert::Config;
|
/// # use rust_bert::Config;
|
||||||
///# use std::path::Path;
|
/// # use std::path::Path;
|
||||||
///# use tch::kind::Kind::{Int64, Double};
|
/// # use tch::kind::Kind::{Int64, Double};
|
||||||
/// use rust_bert::bart::{BartConfig};
|
/// use rust_bert::bart::BartConfig;
|
||||||
/// use rust_bert::marian::MarianForConditionalGeneration;
|
/// use rust_bert::marian::MarianForConditionalGeneration;
|
||||||
///# let config_path = Path::new("path/to/config.json");
|
/// # let config_path = Path::new("path/to/config.json");
|
||||||
///# let vocab_path = Path::new("path/to/vocab.txt");
|
/// # let vocab_path = Path::new("path/to/vocab.txt");
|
||||||
///# let device = Device::Cpu;
|
/// # let device = Device::Cpu;
|
||||||
///# let vs = nn::VarStore::new(device);
|
/// # let vs = nn::VarStore::new(device);
|
||||||
///# let config = BartConfig::from_file(config_path);
|
/// # let config = BartConfig::from_file(config_path);
|
||||||
///# let mut marian_model = MarianForConditionalGeneration::new(&vs.root(), &config, false);
|
/// # let mut marian_model = MarianForConditionalGeneration::new(&vs.root(), &config, false);
|
||||||
/// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56);
|
/// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56);
|
||||||
/// let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device));
|
/// let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device));
|
||||||
/// let target_tensor = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device));
|
/// let target_tensor = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device));
|
||||||
/// let encoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
|
/// let encoder_attention_mask =
|
||||||
/// let decoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
|
/// Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
|
||||||
///
|
/// let decoder_attention_mask =
|
||||||
/// let (decoder_output, encoder_hidden_states, cache,
|
/// Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
|
||||||
/// all_encoder_hidden_states, all_encoder_attentions,
|
|
||||||
/// all_decoder_hidden_states, all_decoder_attentions) = no_grad(|| {
|
|
||||||
/// marian_model
|
|
||||||
/// .forward_t(Some(&input_tensor),
|
|
||||||
/// Some(&encoder_attention_mask),
|
|
||||||
/// None,
|
|
||||||
/// Some(&target_tensor),
|
|
||||||
/// Some(&decoder_attention_mask),
|
|
||||||
/// None,
|
|
||||||
/// false)
|
|
||||||
/// });
|
|
||||||
///
|
///
|
||||||
|
/// let (
|
||||||
|
/// decoder_output,
|
||||||
|
/// encoder_hidden_states,
|
||||||
|
/// cache,
|
||||||
|
/// all_encoder_hidden_states,
|
||||||
|
/// all_encoder_attentions,
|
||||||
|
/// all_decoder_hidden_states,
|
||||||
|
/// all_decoder_attentions,
|
||||||
|
/// ) = no_grad(|| {
|
||||||
|
/// marian_model.forward_t(
|
||||||
|
/// Some(&input_tensor),
|
||||||
|
/// Some(&encoder_attention_mask),
|
||||||
|
/// None,
|
||||||
|
/// Some(&target_tensor),
|
||||||
|
/// Some(&decoder_attention_mask),
|
||||||
|
/// None,
|
||||||
|
/// false,
|
||||||
|
/// )
|
||||||
|
/// });
|
||||||
/// ```
|
/// ```
|
||||||
///
|
pub fn forward_t(
|
||||||
pub fn forward_t(&self,
|
&self,
|
||||||
input_ids: Option<&Tensor>,
|
input_ids: Option<&Tensor>,
|
||||||
attention_mask: Option<&Tensor>,
|
attention_mask: Option<&Tensor>,
|
||||||
encoder_outputs: Option<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>)>,
|
encoder_outputs: Option<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>)>,
|
||||||
decoder_input_ids: Option<&Tensor>,
|
decoder_input_ids: Option<&Tensor>,
|
||||||
decoder_attention_mask: Option<&Tensor>,
|
decoder_attention_mask: Option<&Tensor>,
|
||||||
old_layer_states: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
|
old_layer_states: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
|
||||||
train: bool)
|
train: bool,
|
||||||
-> (Tensor, Tensor, Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
|
) -> (
|
||||||
Option<Vec<Tensor>>, Option<Vec<Tensor>>,
|
Tensor,
|
||||||
Option<Vec<Tensor>>, Option<Vec<Tensor>>)
|
Tensor,
|
||||||
{
|
Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
|
||||||
let (decoder_outputs, encoder_hidden_states, decoder_cache,
|
Option<Vec<Tensor>>,
|
||||||
all_decoder_hidden_states, all_decoder_attentions,
|
Option<Vec<Tensor>>,
|
||||||
all_encoder_hidden_states, all_encoder_attentions) =
|
Option<Vec<Tensor>>,
|
||||||
self.base_model.forward_t(input_ids, attention_mask, decoder_input_ids, encoder_outputs, decoder_attention_mask, old_layer_states, train);
|
Option<Vec<Tensor>>,
|
||||||
|
) {
|
||||||
|
let (
|
||||||
|
decoder_outputs,
|
||||||
|
encoder_hidden_states,
|
||||||
|
decoder_cache,
|
||||||
|
all_decoder_hidden_states,
|
||||||
|
all_decoder_attentions,
|
||||||
|
all_encoder_hidden_states,
|
||||||
|
all_encoder_attentions,
|
||||||
|
) = self.base_model.forward_t(
|
||||||
|
input_ids,
|
||||||
|
attention_mask,
|
||||||
|
decoder_input_ids,
|
||||||
|
encoder_outputs,
|
||||||
|
decoder_attention_mask,
|
||||||
|
old_layer_states,
|
||||||
|
train,
|
||||||
|
);
|
||||||
|
|
||||||
let lm_logits = decoder_outputs.linear::<Tensor>(&self.base_model.embeddings.ws, None);
|
let lm_logits = decoder_outputs.linear::<Tensor>(&self.base_model.embeddings.ws, None);
|
||||||
(lm_logits, encoder_hidden_states, decoder_cache,
|
(
|
||||||
all_decoder_hidden_states, all_decoder_attentions,
|
lm_logits,
|
||||||
all_encoder_hidden_states, all_encoder_attentions)
|
encoder_hidden_states,
|
||||||
|
decoder_cache,
|
||||||
|
all_decoder_hidden_states,
|
||||||
|
all_decoder_attentions,
|
||||||
|
all_encoder_hidden_states,
|
||||||
|
all_encoder_attentions,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn encode(&self, input_ids: &Tensor, attention_mask: Option<&Tensor>) -> Tensor {
|
pub fn encode(&self, input_ids: &Tensor, attention_mask: Option<&Tensor>) -> Tensor {
|
||||||
let (encoder_hidden_states, _, _) = self.base_model.encoder.forward_t(input_ids, attention_mask, &self.base_model.embeddings, false);
|
let (encoder_hidden_states, _, _) = self.base_model.encoder.forward_t(
|
||||||
|
input_ids,
|
||||||
|
attention_mask,
|
||||||
|
&self.base_model.embeddings,
|
||||||
|
false,
|
||||||
|
);
|
||||||
encoder_hidden_states
|
encoder_hidden_states
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -283,68 +427,97 @@ impl LMHeadModel for MarianForConditionalGeneration {
|
|||||||
/// # Example
|
/// # Example
|
||||||
///
|
///
|
||||||
/// ```no_run
|
/// ```no_run
|
||||||
///# use tch::{nn, Device, Tensor, no_grad};
|
/// # use tch::{nn, Device, Tensor, no_grad};
|
||||||
///# use rust_bert::Config;
|
/// # use rust_bert::Config;
|
||||||
///# use std::path::Path;
|
/// # use std::path::Path;
|
||||||
///# use tch::kind::Kind::{Int64, Double};
|
/// # use tch::kind::Kind::{Int64, Double};
|
||||||
/// use rust_bert::bart::{BartConfig};
|
/// use rust_bert::bart::BartConfig;
|
||||||
/// use rust_bert::marian::MarianForConditionalGeneration;
|
/// use rust_bert::marian::MarianForConditionalGeneration;
|
||||||
///# let config_path = Path::new("path/to/config.json");
|
/// # let config_path = Path::new("path/to/config.json");
|
||||||
///# let vocab_path = Path::new("path/to/vocab.txt");
|
/// # let vocab_path = Path::new("path/to/vocab.txt");
|
||||||
///# let device = Device::Cpu;
|
/// # let device = Device::Cpu;
|
||||||
///# let vs = nn::VarStore::new(device);
|
/// # let vs = nn::VarStore::new(device);
|
||||||
///# let config = BartConfig::from_file(config_path);
|
/// # let config = BartConfig::from_file(config_path);
|
||||||
///# let marian_model = MarianForConditionalGeneration::new(&vs.root(), &config, false);
|
/// # let marian_model = MarianForConditionalGeneration::new(&vs.root(), &config, false);
|
||||||
/// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56);
|
/// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56);
|
||||||
/// let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device));
|
/// let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device));
|
||||||
/// let target_tensor = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device));
|
/// let target_tensor = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device));
|
||||||
/// let encoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
|
/// let encoder_attention_mask =
|
||||||
/// let decoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
|
/// Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
|
||||||
///
|
/// let decoder_attention_mask =
|
||||||
/// let (decoder_output, encoder_hidden_states, cache,
|
/// Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
|
||||||
/// all_encoder_hidden_states, all_encoder_attentions,
|
|
||||||
/// all_decoder_hidden_states, all_decoder_attentions) = no_grad(|| {
|
|
||||||
/// marian_model
|
|
||||||
/// .forward_t(Some(&input_tensor),
|
|
||||||
/// Some(&encoder_attention_mask),
|
|
||||||
/// None,
|
|
||||||
/// Some(&target_tensor),
|
|
||||||
/// Some(&decoder_attention_mask),
|
|
||||||
/// None,
|
|
||||||
/// false)
|
|
||||||
/// });
|
|
||||||
///
|
///
|
||||||
|
/// let (
|
||||||
|
/// decoder_output,
|
||||||
|
/// encoder_hidden_states,
|
||||||
|
/// cache,
|
||||||
|
/// all_encoder_hidden_states,
|
||||||
|
/// all_encoder_attentions,
|
||||||
|
/// all_decoder_hidden_states,
|
||||||
|
/// all_decoder_attentions,
|
||||||
|
/// ) = no_grad(|| {
|
||||||
|
/// marian_model.forward_t(
|
||||||
|
/// Some(&input_tensor),
|
||||||
|
/// Some(&encoder_attention_mask),
|
||||||
|
/// None,
|
||||||
|
/// Some(&target_tensor),
|
||||||
|
/// Some(&decoder_attention_mask),
|
||||||
|
/// None,
|
||||||
|
/// false,
|
||||||
|
/// )
|
||||||
|
/// });
|
||||||
/// ```
|
/// ```
|
||||||
///
|
fn forward_t(
|
||||||
fn forward_t(&self,
|
&self,
|
||||||
input_ids: &Option<Tensor>,
|
input_ids: &Option<Tensor>,
|
||||||
cache: Cache,
|
cache: Cache,
|
||||||
attention_mask: &Option<Tensor>,
|
attention_mask: &Option<Tensor>,
|
||||||
_token_type_ids: &Option<Tensor>,
|
_token_type_ids: &Option<Tensor>,
|
||||||
_position_ids: &Option<Tensor>,
|
_position_ids: &Option<Tensor>,
|
||||||
_input_embeds: &Option<Tensor>,
|
_input_embeds: &Option<Tensor>,
|
||||||
encoder_outputs: Option<&Tensor>,
|
encoder_outputs: Option<&Tensor>,
|
||||||
decoder_input_ids: &Option<Tensor>,
|
decoder_input_ids: &Option<Tensor>,
|
||||||
train: bool) -> Result<(Tensor, Option<Tensor>, Cache, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
|
train: bool,
|
||||||
|
) -> Result<
|
||||||
|
(
|
||||||
|
Tensor,
|
||||||
|
Option<Tensor>,
|
||||||
|
Cache,
|
||||||
|
Option<Vec<Tensor>>,
|
||||||
|
Option<Vec<Tensor>>,
|
||||||
|
),
|
||||||
|
&'static str,
|
||||||
|
> {
|
||||||
let (decoder_output, encoder_hidden_states, new_cache, _, _, _, _) = match cache {
|
let (decoder_output, encoder_hidden_states, new_cache, _, _, _, _) = match cache {
|
||||||
Cache::BARTCache(cached_layer_states) => self.base_model.forward_t(input_ids.as_ref(),
|
Cache::BARTCache(cached_layer_states) => self.base_model.forward_t(
|
||||||
attention_mask.as_ref(),
|
input_ids.as_ref(),
|
||||||
decoder_input_ids.as_ref(),
|
attention_mask.as_ref(),
|
||||||
Some((encoder_outputs.as_ref().unwrap().copy(), None, None)),
|
decoder_input_ids.as_ref(),
|
||||||
None,
|
Some((encoder_outputs.as_ref().unwrap().copy(), None, None)),
|
||||||
cached_layer_states,
|
None,
|
||||||
train),
|
cached_layer_states,
|
||||||
Cache::None => self.base_model.forward_t(input_ids.as_ref(),
|
train,
|
||||||
attention_mask.as_ref(),
|
),
|
||||||
decoder_input_ids.as_ref(),
|
Cache::None => self.base_model.forward_t(
|
||||||
Some((encoder_outputs.as_ref().unwrap().copy(), None, None)),
|
input_ids.as_ref(),
|
||||||
None,
|
attention_mask.as_ref(),
|
||||||
None,
|
decoder_input_ids.as_ref(),
|
||||||
train),
|
Some((encoder_outputs.as_ref().unwrap().copy(), None, None)),
|
||||||
_ => Err("Cache not compatible with Marian Model")?
|
None,
|
||||||
|
None,
|
||||||
|
train,
|
||||||
|
),
|
||||||
|
_ => Err("Cache not compatible with Marian Model")?,
|
||||||
};
|
};
|
||||||
|
|
||||||
let lm_logits = decoder_output.linear::<Tensor>(&self.base_model.embeddings.ws, None) + &self.final_logits_bias;
|
let lm_logits = decoder_output.linear::<Tensor>(&self.base_model.embeddings.ws, None)
|
||||||
Ok((lm_logits, Some(encoder_hidden_states), Cache::BARTCache(new_cache), None, None))
|
+ &self.final_logits_bias;
|
||||||
|
Ok((
|
||||||
|
lm_logits,
|
||||||
|
Some(encoder_hidden_states),
|
||||||
|
Cache::BARTCache(new_cache),
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -15,20 +15,28 @@
|
|||||||
//! Pretrained models for a number of language pairs are available and can be downloaded using RemoteResources. These are shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
|
//! Pretrained models for a number of language pairs are available and can be downloaded using RemoteResources. These are shared under Creative Commons Attribution 4.0 International License license by the Opus-MT team from Language Technology at the University of Helsinki at https://github.com/Helsinki-NLP/Opus-MT.
|
||||||
//!
|
//!
|
||||||
//! ```no_run
|
//! ```no_run
|
||||||
//!# fn main() -> failure::Fallible<()> {
|
//! # fn main() -> failure::Fallible<()> {
|
||||||
//!#
|
//! #
|
||||||
//! use tch::{nn, Device};
|
//! use tch::{nn, Device};
|
||||||
//!# use std::path::PathBuf;
|
//! # use std::path::PathBuf;
|
||||||
//! use rust_bert::Config;
|
|
||||||
//! use rust_bert::bart::{BartConfig, BartModel};
|
//! use rust_bert::bart::{BartConfig, BartModel};
|
||||||
//! use rust_bert::resources::{Resource, download_resource, LocalResource};
|
|
||||||
//! use rust_tokenizers::preprocessing::tokenizer::marian_tokenizer::MarianTokenizer;
|
|
||||||
//! use rust_bert::marian::MarianForConditionalGeneration;
|
//! use rust_bert::marian::MarianForConditionalGeneration;
|
||||||
|
//! use rust_bert::resources::{download_resource, LocalResource, Resource};
|
||||||
|
//! use rust_bert::Config;
|
||||||
|
//! use rust_tokenizers::preprocessing::tokenizer::marian_tokenizer::MarianTokenizer;
|
||||||
//!
|
//!
|
||||||
//! let config_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/config.json")});
|
//! let config_resource = Resource::Local(LocalResource {
|
||||||
//! let vocab_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/vocab.json")});
|
//! local_path: PathBuf::from("path/to/config.json"),
|
||||||
//! let sentence_piece_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/spiece.model")});
|
//! });
|
||||||
//! let weights_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/model.ot")});
|
//! let vocab_resource = Resource::Local(LocalResource {
|
||||||
|
//! local_path: PathBuf::from("path/to/vocab.json"),
|
||||||
|
//! });
|
||||||
|
//! let sentence_piece_resource = Resource::Local(LocalResource {
|
||||||
|
//! local_path: PathBuf::from("path/to/spiece.model"),
|
||||||
|
//! });
|
||||||
|
//! let weights_resource = Resource::Local(LocalResource {
|
||||||
|
//! local_path: PathBuf::from("path/to/model.ot"),
|
||||||
|
//! });
|
||||||
//! let config_path = download_resource(&config_resource)?;
|
//! let config_path = download_resource(&config_resource)?;
|
||||||
//! let vocab_path = download_resource(&vocab_resource)?;
|
//! let vocab_path = download_resource(&vocab_resource)?;
|
||||||
//! let spiece_path = download_resource(&sentence_piece_resource)?;
|
//! let spiece_path = download_resource(&sentence_piece_resource)?;
|
||||||
@ -36,15 +44,22 @@
|
|||||||
//!
|
//!
|
||||||
//! let device = Device::cuda_if_available();
|
//! let device = Device::cuda_if_available();
|
||||||
//! let mut vs = nn::VarStore::new(device);
|
//! let mut vs = nn::VarStore::new(device);
|
||||||
//! let tokenizer = MarianTokenizer::from_files(vocab_path.to_str().unwrap(), spiece_path.to_str().unwrap(), true);
|
//! let tokenizer = MarianTokenizer::from_files(
|
||||||
|
//! vocab_path.to_str().unwrap(),
|
||||||
|
//! spiece_path.to_str().unwrap(),
|
||||||
|
//! true,
|
||||||
|
//! );
|
||||||
//! let config = BartConfig::from_file(config_path);
|
//! let config = BartConfig::from_file(config_path);
|
||||||
//! let marian_model = MarianForConditionalGeneration::new(&vs.root(), &config, false);
|
//! let marian_model = MarianForConditionalGeneration::new(&vs.root(), &config, false);
|
||||||
//! vs.load(weights_path)?;
|
//! vs.load(weights_path)?;
|
||||||
//!
|
//!
|
||||||
//!# Ok(())
|
//! # Ok(())
|
||||||
//!# }
|
//! # }
|
||||||
//! ```
|
//! ```
|
||||||
|
|
||||||
mod marian;
|
mod marian;
|
||||||
|
|
||||||
pub use marian::{MarianForConditionalGeneration, MarianModelResources, MarianConfigResources, MarianVocabResources, MarianSpmResources, MarianPrefix};
|
pub use marian::{
|
||||||
|
MarianConfigResources, MarianForConditionalGeneration, MarianModelResources, MarianPrefix,
|
||||||
|
MarianSpmResources, MarianVocabResources,
|
||||||
|
};
|
||||||
|
@ -14,19 +14,27 @@
|
|||||||
//! Pretrained models are available and can be downloaded using RemoteResources.
|
//! Pretrained models are available and can be downloaded using RemoteResources.
|
||||||
//!
|
//!
|
||||||
//! ```no_run
|
//! ```no_run
|
||||||
//!# fn main() -> failure::Fallible<()> {
|
//! # fn main() -> failure::Fallible<()> {
|
||||||
//! use rust_tokenizers::OpenAiGptTokenizer;
|
//! use rust_tokenizers::OpenAiGptTokenizer;
|
||||||
//! use tch::{nn, Device};
|
//! use tch::{nn, Device};
|
||||||
//!# use std::path::PathBuf;
|
//! # use std::path::PathBuf;
|
||||||
//! use rust_bert::Config;
|
|
||||||
//! use rust_bert::gpt2::Gpt2Config;
|
//! use rust_bert::gpt2::Gpt2Config;
|
||||||
//! use rust_bert::openai_gpt::OpenAiGptModel;
|
//! use rust_bert::openai_gpt::OpenAiGptModel;
|
||||||
//! use rust_bert::resources::{Resource, download_resource, LocalResource};
|
//! use rust_bert::resources::{download_resource, LocalResource, Resource};
|
||||||
|
//! use rust_bert::Config;
|
||||||
//!
|
//!
|
||||||
//! let config_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/config.json")});
|
//! let config_resource = Resource::Local(LocalResource {
|
||||||
//! let vocab_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/vocab.txt")});
|
//! local_path: PathBuf::from("path/to/config.json"),
|
||||||
//! let merges_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/vocab.txt")});
|
//! });
|
||||||
//! let weights_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/model.ot")});
|
//! let vocab_resource = Resource::Local(LocalResource {
|
||||||
|
//! local_path: PathBuf::from("path/to/vocab.txt"),
|
||||||
|
//! });
|
||||||
|
//! let merges_resource = Resource::Local(LocalResource {
|
||||||
|
//! local_path: PathBuf::from("path/to/vocab.txt"),
|
||||||
|
//! });
|
||||||
|
//! let weights_resource = Resource::Local(LocalResource {
|
||||||
|
//! local_path: PathBuf::from("path/to/model.ot"),
|
||||||
|
//! });
|
||||||
//! let config_path = download_resource(&config_resource)?;
|
//! let config_path = download_resource(&config_resource)?;
|
||||||
//! let vocab_path = download_resource(&vocab_resource)?;
|
//! let vocab_path = download_resource(&vocab_resource)?;
|
||||||
//! let merges_path = download_resource(&merges_resource)?;
|
//! let merges_path = download_resource(&merges_resource)?;
|
||||||
@ -34,17 +42,23 @@
|
|||||||
//!
|
//!
|
||||||
//! let device = Device::cuda_if_available();
|
//! let device = Device::cuda_if_available();
|
||||||
//! let mut vs = nn::VarStore::new(device);
|
//! let mut vs = nn::VarStore::new(device);
|
||||||
//! let tokenizer: OpenAiGptTokenizer = OpenAiGptTokenizer::from_file(vocab_path.to_str().unwrap(), merges_path.to_str().unwrap(), true);
|
//! let tokenizer: OpenAiGptTokenizer = OpenAiGptTokenizer::from_file(
|
||||||
|
//! vocab_path.to_str().unwrap(),
|
||||||
|
//! merges_path.to_str().unwrap(),
|
||||||
|
//! true,
|
||||||
|
//! );
|
||||||
//! let config = Gpt2Config::from_file(config_path);
|
//! let config = Gpt2Config::from_file(config_path);
|
||||||
//! let gpt_model = OpenAiGptModel::new(&vs.root(), &config);
|
//! let gpt_model = OpenAiGptModel::new(&vs.root(), &config);
|
||||||
//! vs.load(weights_path)?;
|
//! vs.load(weights_path)?;
|
||||||
//!
|
//!
|
||||||
//!# Ok(())
|
//! # Ok(())
|
||||||
//!# }
|
//! # }
|
||||||
//! ```
|
//! ```
|
||||||
|
|
||||||
mod openai_gpt;
|
mod openai_gpt;
|
||||||
mod transformer;
|
mod transformer;
|
||||||
|
|
||||||
pub use openai_gpt::{OpenAiGptModelResources, OpenAiGptConfigResources, OpenAiGptVocabResources, OpenAiGptMergesResources,
|
pub use openai_gpt::{
|
||||||
OpenAiGptModel, OpenAIGPTLMHeadModel};
|
OpenAIGPTLMHeadModel, OpenAiGptConfigResources, OpenAiGptMergesResources, OpenAiGptModel,
|
||||||
|
OpenAiGptModelResources, OpenAiGptVocabResources,
|
||||||
|
};
|
||||||
|
@ -12,15 +12,15 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
use tch::{nn, Tensor};
|
|
||||||
use crate::common::dropout::Dropout;
|
use crate::common::dropout::Dropout;
|
||||||
use tch::nn::embedding;
|
use crate::common::linear::{linear_no_bias, LinearNoBias};
|
||||||
use tch::kind::Kind::Int64;
|
|
||||||
use std::borrow::BorrowMut;
|
|
||||||
use crate::common::linear::{LinearNoBias, linear_no_bias};
|
|
||||||
use crate::openai_gpt::transformer::Block;
|
|
||||||
use crate::gpt2::Gpt2Config;
|
use crate::gpt2::Gpt2Config;
|
||||||
use crate::pipelines::generation::{LMHeadModel, Cache};
|
use crate::openai_gpt::transformer::Block;
|
||||||
|
use crate::pipelines::generation::{Cache, LMHeadModel};
|
||||||
|
use std::borrow::BorrowMut;
|
||||||
|
use tch::kind::Kind::Int64;
|
||||||
|
use tch::nn::embedding;
|
||||||
|
use tch::{nn, Tensor};
|
||||||
|
|
||||||
/// # GPT Pretrained model weight files
|
/// # GPT Pretrained model weight files
|
||||||
pub struct OpenAiGptModelResources;
|
pub struct OpenAiGptModelResources;
|
||||||
@ -36,22 +36,34 @@ pub struct OpenAiGptMergesResources;
|
|||||||
|
|
||||||
impl OpenAiGptModelResources {
|
impl OpenAiGptModelResources {
|
||||||
/// Shared under MIT license by the OpenAI team at https://github.com/openai/finetune-transformer-lm. Modified with conversion to C-array format.
|
/// Shared under MIT license by the OpenAI team at https://github.com/openai/finetune-transformer-lm. Modified with conversion to C-array format.
|
||||||
pub const GPT: (&'static str, &'static str) = ("openai-gpt/model.ot", "https://cdn.huggingface.co/openai-gpt-rust_model.ot");
|
pub const GPT: (&'static str, &'static str) = (
|
||||||
|
"openai-gpt/model.ot",
|
||||||
|
"https://cdn.huggingface.co/openai-gpt-rust_model.ot",
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
impl OpenAiGptConfigResources {
|
impl OpenAiGptConfigResources {
|
||||||
/// Shared under MIT license by the OpenAI team at https://github.com/openai/finetune-transformer-lm. Modified with conversion to C-array format.
|
/// Shared under MIT license by the OpenAI team at https://github.com/openai/finetune-transformer-lm. Modified with conversion to C-array format.
|
||||||
pub const GPT: (&'static str, &'static str) = ("openai-gpt/config.json", "https://cdn.huggingface.co/openai-gpt-config.json");
|
pub const GPT: (&'static str, &'static str) = (
|
||||||
|
"openai-gpt/config.json",
|
||||||
|
"https://cdn.huggingface.co/openai-gpt-config.json",
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
impl OpenAiGptVocabResources {
|
impl OpenAiGptVocabResources {
|
||||||
/// Shared under MIT license by the OpenAI team at https://github.com/openai/finetune-transformer-lm. Modified with conversion to C-array format.
|
/// Shared under MIT license by the OpenAI team at https://github.com/openai/finetune-transformer-lm. Modified with conversion to C-array format.
|
||||||
pub const GPT: (&'static str, &'static str) = ("openai-gpt/vocab.txt", "https://cdn.huggingface.co/openai-gpt-vocab.json");
|
pub const GPT: (&'static str, &'static str) = (
|
||||||
|
"openai-gpt/vocab.txt",
|
||||||
|
"https://cdn.huggingface.co/openai-gpt-vocab.json",
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
impl OpenAiGptMergesResources {
|
impl OpenAiGptMergesResources {
|
||||||
/// Shared under MIT license by the OpenAI team at https://github.com/openai/finetune-transformer-lm. Modified with conversion to C-array format.
|
/// Shared under MIT license by the OpenAI team at https://github.com/openai/finetune-transformer-lm. Modified with conversion to C-array format.
|
||||||
pub const GPT: (&'static str, &'static str) = ("openai-gpt/merges.txt", "https://cdn.huggingface.co/openai-gpt-merges.txt");
|
pub const GPT: (&'static str, &'static str) = (
|
||||||
|
"openai-gpt/merges.txt",
|
||||||
|
"https://cdn.huggingface.co/openai-gpt-merges.txt",
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// # GPT Base model
|
/// # GPT Base model
|
||||||
@ -82,11 +94,11 @@ impl OpenAiGptModel {
|
|||||||
/// # Example
|
/// # Example
|
||||||
///
|
///
|
||||||
/// ```no_run
|
/// ```no_run
|
||||||
/// use tch::{nn, Device};
|
|
||||||
/// use rust_bert::Config;
|
|
||||||
/// use std::path::Path;
|
|
||||||
/// use rust_bert::gpt2::Gpt2Config;
|
/// use rust_bert::gpt2::Gpt2Config;
|
||||||
/// use rust_bert::openai_gpt::OpenAiGptModel;
|
/// use rust_bert::openai_gpt::OpenAiGptModel;
|
||||||
|
/// use rust_bert::Config;
|
||||||
|
/// use std::path::Path;
|
||||||
|
/// use tch::{nn, Device};
|
||||||
///
|
///
|
||||||
/// let config_path = Path::new("path/to/config.json");
|
/// let config_path = Path::new("path/to/config.json");
|
||||||
/// let device = Device::Cpu;
|
/// let device = Device::Cpu;
|
||||||
@ -94,30 +106,46 @@ impl OpenAiGptModel {
|
|||||||
/// let config = Gpt2Config::from_file(config_path);
|
/// let config = Gpt2Config::from_file(config_path);
|
||||||
/// let gpt2: OpenAiGptModel = OpenAiGptModel::new(&(&p.root() / "gpt"), &config);
|
/// let gpt2: OpenAiGptModel = OpenAiGptModel::new(&(&p.root() / "gpt"), &config);
|
||||||
/// ```
|
/// ```
|
||||||
///
|
|
||||||
pub fn new(p: &nn::Path, config: &Gpt2Config) -> OpenAiGptModel {
|
pub fn new(p: &nn::Path, config: &Gpt2Config) -> OpenAiGptModel {
|
||||||
let tokens_embed = embedding(&(p / "tokens_embed"), config.vocab_size, config.n_embd, Default::default());
|
let tokens_embed = embedding(
|
||||||
let positions_embed = embedding(&(p / "positions_embed"), config.n_positions, config.n_embd, Default::default());
|
&(p / "tokens_embed"),
|
||||||
|
config.vocab_size,
|
||||||
|
config.n_embd,
|
||||||
|
Default::default(),
|
||||||
|
);
|
||||||
|
let positions_embed = embedding(
|
||||||
|
&(p / "positions_embed"),
|
||||||
|
config.n_positions,
|
||||||
|
config.n_embd,
|
||||||
|
Default::default(),
|
||||||
|
);
|
||||||
|
|
||||||
let embd_pdrop = match config.embd_pdrop {
|
let embd_pdrop = match config.embd_pdrop {
|
||||||
Some(value) => value,
|
Some(value) => value,
|
||||||
None => 0.1
|
None => 0.1,
|
||||||
};
|
};
|
||||||
let drop = Dropout::new(embd_pdrop);
|
let drop = Dropout::new(embd_pdrop);
|
||||||
let mut h: Vec<Block> = vec!();
|
let mut h: Vec<Block> = vec![];
|
||||||
let h_path = &(p / "h");
|
let h_path = &(p / "h");
|
||||||
for layer_index in 0..config.n_layer {
|
for layer_index in 0..config.n_layer {
|
||||||
h.push(Block::new(&(h_path / layer_index), config, true));
|
h.push(Block::new(&(h_path / layer_index), config, true));
|
||||||
};
|
}
|
||||||
let output_attentions = match config.output_attentions {
|
let output_attentions = match config.output_attentions {
|
||||||
Some(value) => value,
|
Some(value) => value,
|
||||||
None => false
|
None => false,
|
||||||
};
|
};
|
||||||
let output_hidden_states = match config.output_hidden_states {
|
let output_hidden_states = match config.output_hidden_states {
|
||||||
Some(value) => value,
|
Some(value) => value,
|
||||||
None => false
|
None => false,
|
||||||
};
|
};
|
||||||
OpenAiGptModel { tokens_embed, positions_embed, drop, h, output_hidden_states, output_attentions }
|
OpenAiGptModel {
|
||||||
|
tokens_embed,
|
||||||
|
positions_embed,
|
||||||
|
drop,
|
||||||
|
h,
|
||||||
|
output_hidden_states,
|
||||||
|
output_attentions,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Forward pass through the model
|
/// Forward pass through the model
|
||||||
@ -140,80 +168,99 @@ impl OpenAiGptModel {
|
|||||||
/// # Example
|
/// # Example
|
||||||
///
|
///
|
||||||
/// ```no_run
|
/// ```no_run
|
||||||
///# use tch::{nn, Device, Tensor, no_grad};
|
/// # use tch::{nn, Device, Tensor, no_grad};
|
||||||
///# use rust_bert::Config;
|
/// # use rust_bert::Config;
|
||||||
///# use std::path::Path;
|
/// # use std::path::Path;
|
||||||
///# use tch::kind::Kind::{Int64, Double};
|
/// # use tch::kind::Kind::{Int64, Double};
|
||||||
/// use rust_bert::gpt2::Gpt2Config;
|
/// use rust_bert::gpt2::Gpt2Config;
|
||||||
/// use rust_bert::openai_gpt::OpenAiGptModel;
|
/// use rust_bert::openai_gpt::OpenAiGptModel;
|
||||||
///# let config_path = Path::new("path/to/config.json");
|
/// # let config_path = Path::new("path/to/config.json");
|
||||||
///# let vocab_path = Path::new("path/to/vocab.txt");
|
/// # let vocab_path = Path::new("path/to/vocab.txt");
|
||||||
///# let device = Device::Cpu;
|
/// # let device = Device::Cpu;
|
||||||
///# let vs = nn::VarStore::new(device);
|
/// # let vs = nn::VarStore::new(device);
|
||||||
///# let config = Gpt2Config::from_file(config_path);
|
/// # let config = Gpt2Config::from_file(config_path);
|
||||||
///# let gpt_model: OpenAiGptModel = OpenAiGptModel::new(&vs.root(), &config);
|
/// # let gpt_model: OpenAiGptModel = OpenAiGptModel::new(&vs.root(), &config);
|
||||||
/// let (batch_size, sequence_length, past_sequence_length) = (64, 128, 56);
|
/// let (batch_size, sequence_length, past_sequence_length) = (64, 128, 56);
|
||||||
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
|
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
|
||||||
/// let attention_mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
/// let attention_mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
||||||
/// let token_type_ids = Tensor::ones(&[batch_size, sequence_length], (Int64, device));
|
/// let token_type_ids = Tensor::ones(&[batch_size, sequence_length], (Int64, device));
|
||||||
/// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true);
|
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
|
||||||
///
|
/// .expand(&[batch_size, sequence_length], true);
|
||||||
/// let (output, hidden_states, attentions) = no_grad(|| {
|
|
||||||
/// gpt_model
|
|
||||||
/// .forward_t(&Some(input_tensor),
|
|
||||||
/// &Some(attention_mask),
|
|
||||||
/// &Some(token_type_ids),
|
|
||||||
/// &Some(position_ids),
|
|
||||||
/// &None,
|
|
||||||
/// false).unwrap()
|
|
||||||
/// });
|
|
||||||
///
|
///
|
||||||
|
/// let (output, hidden_states, attentions) = no_grad(|| {
|
||||||
|
/// gpt_model
|
||||||
|
/// .forward_t(
|
||||||
|
/// &Some(input_tensor),
|
||||||
|
/// &Some(attention_mask),
|
||||||
|
/// &Some(token_type_ids),
|
||||||
|
/// &Some(position_ids),
|
||||||
|
/// &None,
|
||||||
|
/// false,
|
||||||
|
/// )
|
||||||
|
/// .unwrap()
|
||||||
|
/// });
|
||||||
/// ```
|
/// ```
|
||||||
///
|
pub fn forward_t(
|
||||||
pub fn forward_t(&self,
|
&self,
|
||||||
input_ids: &Option<Tensor>,
|
input_ids: &Option<Tensor>,
|
||||||
attention_mask: &Option<Tensor>,
|
attention_mask: &Option<Tensor>,
|
||||||
token_type_ids: &Option<Tensor>,
|
token_type_ids: &Option<Tensor>,
|
||||||
position_ids: &Option<Tensor>,
|
position_ids: &Option<Tensor>,
|
||||||
input_embeds: &Option<Tensor>,
|
input_embeds: &Option<Tensor>,
|
||||||
train: bool) -> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
|
train: bool,
|
||||||
|
) -> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
|
||||||
let (input_embeddings, seq_length) = match input_ids {
|
let (input_embeddings, seq_length) = match input_ids {
|
||||||
Some(input_value) => match input_embeds {
|
Some(input_value) => match input_embeds {
|
||||||
Some(_) => { return Err("Only one of input ids or input embeddings may be set"); }
|
Some(_) => {
|
||||||
None => (input_value.apply(&self.tokens_embed), *input_value.size().last().unwrap())
|
return Err("Only one of input ids or input embeddings may be set");
|
||||||
}
|
}
|
||||||
|
None => (
|
||||||
|
input_value.apply(&self.tokens_embed),
|
||||||
|
*input_value.size().last().unwrap(),
|
||||||
|
),
|
||||||
|
},
|
||||||
None => match input_embeds {
|
None => match input_embeds {
|
||||||
Some(embeds) => (embeds.copy(), embeds.size()[1]),
|
Some(embeds) => (embeds.copy(), embeds.size()[1]),
|
||||||
None => { return Err("At least one of input ids or input embeddings must be set"); }
|
None => {
|
||||||
}
|
return Err("At least one of input ids or input embeddings must be set");
|
||||||
|
}
|
||||||
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
let position_ids = match position_ids {
|
let position_ids = match position_ids {
|
||||||
Some(value) => value.copy(),
|
Some(value) => value.copy(),
|
||||||
None => Tensor::arange(seq_length, (Int64, input_embeddings.device())).unsqueeze(0)
|
None => Tensor::arange(seq_length, (Int64, input_embeddings.device())).unsqueeze(0),
|
||||||
};
|
};
|
||||||
|
|
||||||
let attention_mask: Option<Tensor> = match attention_mask {
|
let attention_mask: Option<Tensor> = match attention_mask {
|
||||||
Some(value) => {
|
Some(value) => Some(
|
||||||
Some(
|
(value
|
||||||
(value
|
.view((input_embeddings.size()[0], -1))
|
||||||
.view((input_embeddings.size()[0], -1))
|
.unsqueeze(1)
|
||||||
.unsqueeze(1)
|
.unsqueeze(2)
|
||||||
.unsqueeze(2)
|
- 1.0)
|
||||||
- 1.0
|
* 10000.0,
|
||||||
) * 10000.0)
|
),
|
||||||
}
|
None => None,
|
||||||
None => None
|
|
||||||
};
|
};
|
||||||
|
|
||||||
let position_embeds = position_ids.apply(&self.positions_embed);
|
let position_embeds = position_ids.apply(&self.positions_embed);
|
||||||
let token_type_embeds = match token_type_ids {
|
let token_type_embeds = match token_type_ids {
|
||||||
Some(value) => value.apply(&self.tokens_embed),
|
Some(value) => value.apply(&self.tokens_embed),
|
||||||
None => Tensor::zeros_like(&position_embeds)
|
None => Tensor::zeros_like(&position_embeds),
|
||||||
|
};
|
||||||
|
let mut hidden_state: Tensor =
|
||||||
|
(input_embeddings + position_embeds + token_type_embeds).apply_t(&self.drop, train);
|
||||||
|
let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states {
|
||||||
|
Some(vec![])
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
let mut all_attentions: Option<Vec<Tensor>> = if self.output_attentions {
|
||||||
|
Some(vec![])
|
||||||
|
} else {
|
||||||
|
None
|
||||||
};
|
};
|
||||||
let mut hidden_state: Tensor = (input_embeddings + position_embeds + token_type_embeds).apply_t(&self.drop, train);
|
|
||||||
let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states { Some(vec!()) } else { None };
|
|
||||||
let mut all_attentions: Option<Vec<Tensor>> = if self.output_attentions { Some(vec!()) } else { None };
|
|
||||||
|
|
||||||
let mut layers = self.h.iter();
|
let mut layers = self.h.iter();
|
||||||
loop {
|
loop {
|
||||||
@ -229,9 +276,9 @@ impl OpenAiGptModel {
|
|||||||
attentions.push(temp.1.as_ref().unwrap().copy());
|
attentions.push(temp.1.as_ref().unwrap().copy());
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
None => break
|
None => break,
|
||||||
};
|
};
|
||||||
};
|
}
|
||||||
|
|
||||||
Ok((hidden_state, all_hidden_states, all_attentions))
|
Ok((hidden_state, all_hidden_states, all_attentions))
|
||||||
}
|
}
|
||||||
@ -258,11 +305,11 @@ impl OpenAIGPTLMHeadModel {
|
|||||||
/// # Example
|
/// # Example
|
||||||
///
|
///
|
||||||
/// ```no_run
|
/// ```no_run
|
||||||
/// use tch::{nn, Device};
|
|
||||||
/// use rust_bert::Config;
|
|
||||||
/// use std::path::Path;
|
|
||||||
/// use rust_bert::gpt2::Gpt2Config;
|
/// use rust_bert::gpt2::Gpt2Config;
|
||||||
/// use rust_bert::openai_gpt::OpenAIGPTLMHeadModel;
|
/// use rust_bert::openai_gpt::OpenAIGPTLMHeadModel;
|
||||||
|
/// use rust_bert::Config;
|
||||||
|
/// use std::path::Path;
|
||||||
|
/// use tch::{nn, Device};
|
||||||
///
|
///
|
||||||
/// let config_path = Path::new("path/to/config.json");
|
/// let config_path = Path::new("path/to/config.json");
|
||||||
/// let device = Device::Cpu;
|
/// let device = Device::Cpu;
|
||||||
@ -270,11 +317,18 @@ impl OpenAIGPTLMHeadModel {
|
|||||||
/// let config = Gpt2Config::from_file(config_path);
|
/// let config = Gpt2Config::from_file(config_path);
|
||||||
/// let gpt2: OpenAIGPTLMHeadModel = OpenAIGPTLMHeadModel::new(&(&p.root() / "gpt"), &config);
|
/// let gpt2: OpenAIGPTLMHeadModel = OpenAIGPTLMHeadModel::new(&(&p.root() / "gpt"), &config);
|
||||||
/// ```
|
/// ```
|
||||||
///
|
|
||||||
pub fn new(p: &nn::Path, config: &Gpt2Config) -> OpenAIGPTLMHeadModel {
|
pub fn new(p: &nn::Path, config: &Gpt2Config) -> OpenAIGPTLMHeadModel {
|
||||||
let transformer = OpenAiGptModel::new(&p, config);
|
let transformer = OpenAiGptModel::new(&p, config);
|
||||||
let lm_head = linear_no_bias(&(p / "lm_head"), config.n_embd, config.vocab_size, Default::default());
|
let lm_head = linear_no_bias(
|
||||||
OpenAIGPTLMHeadModel { transformer, lm_head }
|
&(p / "lm_head"),
|
||||||
|
config.n_embd,
|
||||||
|
config.vocab_size,
|
||||||
|
Default::default(),
|
||||||
|
);
|
||||||
|
OpenAIGPTLMHeadModel {
|
||||||
|
transformer,
|
||||||
|
lm_head,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -305,19 +359,19 @@ impl LMHeadModel for OpenAIGPTLMHeadModel {
|
|||||||
/// # Example
|
/// # Example
|
||||||
///
|
///
|
||||||
/// ```no_run
|
/// ```no_run
|
||||||
///# use tch::{nn, Device, Tensor, no_grad};
|
/// # use tch::{nn, Device, Tensor, no_grad};
|
||||||
///# use rust_bert::Config;
|
/// # use rust_bert::Config;
|
||||||
///# use std::path::Path;
|
/// # use std::path::Path;
|
||||||
///# use tch::kind::Kind::{Int64, Double};
|
/// # use tch::kind::Kind::{Int64, Double};
|
||||||
/// use rust_bert::gpt2::Gpt2Config;
|
/// use rust_bert::gpt2::Gpt2Config;
|
||||||
/// use rust_bert::openai_gpt::OpenAIGPTLMHeadModel;
|
/// use rust_bert::openai_gpt::OpenAIGPTLMHeadModel;
|
||||||
/// use rust_bert::pipelines::generation::{LMHeadModel, Cache};
|
/// use rust_bert::pipelines::generation::{LMHeadModel, Cache};
|
||||||
///# let config_path = Path::new("path/to/config.json");
|
/// # let config_path = Path::new("path/to/config.json");
|
||||||
///# let vocab_path = Path::new("path/to/vocab.txt");
|
/// # let vocab_path = Path::new("path/to/vocab.txt");
|
||||||
///# let device = Device::Cpu;
|
/// # let device = Device::Cpu;
|
||||||
///# let vs = nn::VarStore::new(device);
|
/// # let vs = nn::VarStore::new(device);
|
||||||
///# let config = Gpt2Config::from_file(config_path);
|
/// # let config = Gpt2Config::from_file(config_path);
|
||||||
///# let mut gpt_model: OpenAIGPTLMHeadModel = OpenAIGPTLMHeadModel::new(&vs.root(), &config);
|
/// # let mut gpt_model: OpenAIGPTLMHeadModel = OpenAIGPTLMHeadModel::new(&vs.root(), &config);
|
||||||
/// let (batch_size, sequence_length, past_sequence_length) = (64, 128, 56);
|
/// let (batch_size, sequence_length, past_sequence_length) = (64, 128, 56);
|
||||||
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
|
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
|
||||||
/// let attention_mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
/// let attention_mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
||||||
@ -336,29 +390,44 @@ impl LMHeadModel for OpenAIGPTLMHeadModel {
|
|||||||
/// &None,
|
/// &None,
|
||||||
/// false).unwrap()
|
/// false).unwrap()
|
||||||
/// });
|
/// });
|
||||||
///
|
|
||||||
/// ```
|
/// ```
|
||||||
///
|
fn forward_t(
|
||||||
fn forward_t(&self,
|
&self,
|
||||||
input_ids: &Option<Tensor>,
|
input_ids: &Option<Tensor>,
|
||||||
_layer_past: Cache,
|
_layer_past: Cache,
|
||||||
attention_mask: &Option<Tensor>,
|
attention_mask: &Option<Tensor>,
|
||||||
token_type_ids: &Option<Tensor>,
|
token_type_ids: &Option<Tensor>,
|
||||||
position_ids: &Option<Tensor>,
|
position_ids: &Option<Tensor>,
|
||||||
input_embeds: &Option<Tensor>,
|
input_embeds: &Option<Tensor>,
|
||||||
_encoder_outputs: Option<&Tensor>,
|
_encoder_outputs: Option<&Tensor>,
|
||||||
_decoder_input_ids: &Option<Tensor>,
|
_decoder_input_ids: &Option<Tensor>,
|
||||||
train: bool) -> Result<(Tensor, Option<Tensor>, Cache, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
|
train: bool,
|
||||||
let (output,
|
) -> Result<
|
||||||
all_hidden_states,
|
(
|
||||||
all_attentions) = self.transformer.forward_t(input_ids,
|
Tensor,
|
||||||
attention_mask,
|
Option<Tensor>,
|
||||||
token_type_ids,
|
Cache,
|
||||||
position_ids,
|
Option<Vec<Tensor>>,
|
||||||
input_embeds,
|
Option<Vec<Tensor>>,
|
||||||
train)?;
|
),
|
||||||
|
&'static str,
|
||||||
|
> {
|
||||||
|
let (output, all_hidden_states, all_attentions) = self.transformer.forward_t(
|
||||||
|
input_ids,
|
||||||
|
attention_mask,
|
||||||
|
token_type_ids,
|
||||||
|
position_ids,
|
||||||
|
input_embeds,
|
||||||
|
train,
|
||||||
|
)?;
|
||||||
|
|
||||||
let lm_logits = output.apply(&self.lm_head);
|
let lm_logits = output.apply(&self.lm_head);
|
||||||
Ok((lm_logits, None, Cache::None, all_hidden_states, all_attentions))
|
Ok((
|
||||||
|
lm_logits,
|
||||||
|
None,
|
||||||
|
Cache::None,
|
||||||
|
all_hidden_states,
|
||||||
|
all_attentions,
|
||||||
|
))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -13,9 +13,9 @@
|
|||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
use crate::gpt2::attention::Attention;
|
use crate::gpt2::attention::Attention;
|
||||||
use tch::{Tensor, nn};
|
|
||||||
use crate::gpt2::transformer::MLP;
|
use crate::gpt2::transformer::MLP;
|
||||||
use crate::gpt2::Gpt2Config;
|
use crate::gpt2::Gpt2Config;
|
||||||
|
use tch::{nn, Tensor};
|
||||||
|
|
||||||
pub struct Block {
|
pub struct Block {
|
||||||
ln_1: nn::LayerNorm,
|
ln_1: nn::LayerNorm,
|
||||||
@ -26,21 +26,33 @@ pub struct Block {
|
|||||||
|
|
||||||
impl Block {
|
impl Block {
|
||||||
pub fn new(p: &nn::Path, config: &Gpt2Config, scale: bool) -> Block {
|
pub fn new(p: &nn::Path, config: &Gpt2Config, scale: bool) -> Block {
|
||||||
let layer_norm_config = nn::LayerNormConfig { eps: config.layer_norm_epsilon, ..Default::default() };
|
let layer_norm_config = nn::LayerNormConfig {
|
||||||
|
eps: config.layer_norm_epsilon,
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
let ln_1 = nn::layer_norm(p / "ln_1", vec![config.n_embd], layer_norm_config);
|
let ln_1 = nn::layer_norm(p / "ln_1", vec![config.n_embd], layer_norm_config);
|
||||||
let ln_2 = nn::layer_norm(p / "ln_2", vec![config.n_embd], layer_norm_config);
|
let ln_2 = nn::layer_norm(p / "ln_2", vec![config.n_embd], layer_norm_config);
|
||||||
let attn = Attention::new(&(p / "attn"), config, scale);
|
let attn = Attention::new(&(p / "attn"), config, scale);
|
||||||
let mlp = MLP::new(&(p / "mlp"), config);
|
let mlp = MLP::new(&(p / "mlp"), config);
|
||||||
|
|
||||||
Block { ln_1, attn, ln_2, mlp }
|
Block {
|
||||||
|
ln_1,
|
||||||
|
attn,
|
||||||
|
ln_2,
|
||||||
|
mlp,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn forward_t(&self, x: &Tensor, attention_mask: &Option<Tensor>, train: bool)
|
pub fn forward_t(
|
||||||
-> (Tensor, Option<Tensor>) {
|
&self,
|
||||||
|
x: &Tensor,
|
||||||
|
attention_mask: &Option<Tensor>,
|
||||||
|
train: bool,
|
||||||
|
) -> (Tensor, Option<Tensor>) {
|
||||||
let (output, _, attentions) = self.attn.forward_t(x, &None, attention_mask, train);
|
let (output, _, attentions) = self.attn.forward_t(x, &None, attention_mask, train);
|
||||||
let x = (x + output).apply(&self.ln_1);
|
let x = (x + output).apply(&self.ln_1);
|
||||||
let m = self.mlp.forward_t(&x, train);
|
let m = self.mlp.forward_t(&x, train);
|
||||||
let x = (x + m).apply(&self.ln_2);
|
let x = (x + m).apply(&self.ln_2);
|
||||||
(x, attentions)
|
(x, attentions)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -19,13 +19,13 @@
|
|||||||
//!
|
//!
|
||||||
use crate::bert::BertConfig;
|
use crate::bert::BertConfig;
|
||||||
use crate::distilbert::DistilBertConfig;
|
use crate::distilbert::DistilBertConfig;
|
||||||
use rust_tokenizers::{BertTokenizer, RobertaTokenizer, TokenizedInput, TruncationStrategy};
|
|
||||||
use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::Tokenizer;
|
|
||||||
use std::path::Path;
|
|
||||||
use crate::Config;
|
|
||||||
use std::collections::HashMap;
|
|
||||||
use serde::{Serialize, Deserialize};
|
|
||||||
use crate::electra::ElectraConfig;
|
use crate::electra::ElectraConfig;
|
||||||
|
use crate::Config;
|
||||||
|
use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::Tokenizer;
|
||||||
|
use rust_tokenizers::{BertTokenizer, RobertaTokenizer, TokenizedInput, TruncationStrategy};
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::path::Path;
|
||||||
|
|
||||||
#[derive(Clone, Copy, Serialize, Deserialize)]
|
#[derive(Clone, Copy, Serialize, Deserialize)]
|
||||||
/// # Identifies the type of model
|
/// # Identifies the type of model
|
||||||
@ -60,25 +60,42 @@ impl ConfigOption {
|
|||||||
match model_type {
|
match model_type {
|
||||||
ModelType::Bert | ModelType::Roberta => ConfigOption::Bert(BertConfig::from_file(path)),
|
ModelType::Bert | ModelType::Roberta => ConfigOption::Bert(BertConfig::from_file(path)),
|
||||||
ModelType::DistilBert => ConfigOption::DistilBert(DistilBertConfig::from_file(path)),
|
ModelType::DistilBert => ConfigOption::DistilBert(DistilBertConfig::from_file(path)),
|
||||||
ModelType::Electra => ConfigOption::Electra(ElectraConfig::from_file(path))
|
ModelType::Electra => ConfigOption::Electra(ElectraConfig::from_file(path)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_label_mapping(self) -> HashMap<i64, String> {
|
pub fn get_label_mapping(self) -> HashMap<i64, String> {
|
||||||
match self {
|
match self {
|
||||||
Self::Bert(config) => config.id2label.expect("No label dictionary (id2label) provided in configuration file"),
|
Self::Bert(config) => config
|
||||||
Self::DistilBert(config) => config.id2label.expect("No label dictionary (id2label) provided in configuration file"),
|
.id2label
|
||||||
Self::Electra(config) => config.id2label.expect("No label dictionary (id2label) provided in configuration file"),
|
.expect("No label dictionary (id2label) provided in configuration file"),
|
||||||
|
Self::DistilBert(config) => config
|
||||||
|
.id2label
|
||||||
|
.expect("No label dictionary (id2label) provided in configuration file"),
|
||||||
|
Self::Electra(config) => config
|
||||||
|
.id2label
|
||||||
|
.expect("No label dictionary (id2label) provided in configuration file"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl TokenizerOption {
|
impl TokenizerOption {
|
||||||
/// Interface method to load a tokenizer from file
|
/// Interface method to load a tokenizer from file
|
||||||
pub fn from_file(model_type: ModelType, vocab_path: &str, merges_path: Option<&str>, lower_case: bool) -> Self {
|
pub fn from_file(
|
||||||
|
model_type: ModelType,
|
||||||
|
vocab_path: &str,
|
||||||
|
merges_path: Option<&str>,
|
||||||
|
lower_case: bool,
|
||||||
|
) -> Self {
|
||||||
match model_type {
|
match model_type {
|
||||||
ModelType::Bert | ModelType::DistilBert | ModelType::Electra => TokenizerOption::Bert(BertTokenizer::from_file(vocab_path, lower_case)),
|
ModelType::Bert | ModelType::DistilBert | ModelType::Electra => {
|
||||||
ModelType::Roberta => TokenizerOption::Roberta(RobertaTokenizer::from_file(vocab_path, merges_path.expect("No merges specified!"), lower_case)),
|
TokenizerOption::Bert(BertTokenizer::from_file(vocab_path, lower_case))
|
||||||
|
}
|
||||||
|
ModelType::Roberta => TokenizerOption::Roberta(RobertaTokenizer::from_file(
|
||||||
|
vocab_path,
|
||||||
|
merges_path.expect("No merges specified!"),
|
||||||
|
lower_case,
|
||||||
|
)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -86,15 +103,25 @@ impl TokenizerOption {
|
|||||||
pub fn model_type(&self) -> ModelType {
|
pub fn model_type(&self) -> ModelType {
|
||||||
match *self {
|
match *self {
|
||||||
Self::Bert(_) => ModelType::Bert,
|
Self::Bert(_) => ModelType::Bert,
|
||||||
Self::Roberta(_) => ModelType::Roberta
|
Self::Roberta(_) => ModelType::Roberta,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Interface method
|
/// Interface method
|
||||||
pub fn encode_list(&self, text_list: Vec<&str>, max_len: usize, truncation_strategy: &TruncationStrategy, stride: usize) -> Vec<TokenizedInput> {
|
pub fn encode_list(
|
||||||
|
&self,
|
||||||
|
text_list: Vec<&str>,
|
||||||
|
max_len: usize,
|
||||||
|
truncation_strategy: &TruncationStrategy,
|
||||||
|
stride: usize,
|
||||||
|
) -> Vec<TokenizedInput> {
|
||||||
match *self {
|
match *self {
|
||||||
Self::Bert(ref tokenizer) => tokenizer.encode_list(text_list, max_len, truncation_strategy, stride),
|
Self::Bert(ref tokenizer) => {
|
||||||
Self::Roberta(ref tokenizer) => tokenizer.encode_list(text_list, max_len, truncation_strategy, stride)
|
tokenizer.encode_list(text_list, max_len, truncation_strategy, stride)
|
||||||
|
}
|
||||||
|
Self::Roberta(ref tokenizer) => {
|
||||||
|
tokenizer.encode_list(text_list, max_len, truncation_strategy, stride)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
File diff suppressed because it is too large
Load Diff
@ -6,32 +6,29 @@
|
|||||||
//! Extractive question answering from a given question and context. DistilBERT model finetuned on SQuAD (Stanford Question Answering Dataset)
|
//! Extractive question answering from a given question and context. DistilBERT model finetuned on SQuAD (Stanford Question Answering Dataset)
|
||||||
//!
|
//!
|
||||||
//! ```no_run
|
//! ```no_run
|
||||||
//! use rust_bert::pipelines::question_answering::{QuestionAnsweringModel, QaInput};
|
//! use rust_bert::pipelines::question_answering::{QaInput, QuestionAnsweringModel};
|
||||||
//!# fn main() -> failure::Fallible<()> {
|
//! # fn main() -> failure::Fallible<()> {
|
||||||
//! let qa_model = QuestionAnsweringModel::new(Default::default())?;
|
//! let qa_model = QuestionAnsweringModel::new(Default::default())?;
|
||||||
//!
|
//!
|
||||||
//! let question = String::from("Where does Amy live ?");
|
//! let question = String::from("Where does Amy live ?");
|
||||||
//! let context = String::from("Amy lives in Amsterdam");
|
//! let context = String::from("Amy lives in Amsterdam");
|
||||||
//!
|
//!
|
||||||
//! let answers = qa_model.predict(&vec!(QaInput { question, context }), 1, 32);
|
//! let answers = qa_model.predict(&vec![QaInput { question, context }], 1, 32);
|
||||||
//!# Ok(())
|
//! # Ok(())
|
||||||
//!# }
|
//! # }
|
||||||
//! ```
|
//! ```
|
||||||
//!
|
//!
|
||||||
//! Output: \
|
//! Output: \
|
||||||
//! ```no_run
|
//! ```no_run
|
||||||
//!# use rust_bert::pipelines::question_answering::Answer;
|
//! # use rust_bert::pipelines::question_answering::Answer;
|
||||||
//!# let output =
|
//! # let output =
|
||||||
//! [
|
//! [Answer {
|
||||||
//! Answer {
|
//! score: 0.9976,
|
||||||
//! score: 0.9976,
|
//! start: 13,
|
||||||
//! start: 13,
|
//! end: 21,
|
||||||
//! end: 21,
|
//! answer: "Amsterdam", //#### # .to_owned()
|
||||||
//! answer: "Amsterdam"
|
//! }]
|
||||||
//!# .to_owned()
|
//! # ;
|
||||||
//! }
|
|
||||||
//! ]
|
|
||||||
//!# ;
|
|
||||||
//! ```
|
//! ```
|
||||||
//!
|
//!
|
||||||
//! #### 2. Translation
|
//! #### 2. Translation
|
||||||
@ -46,74 +43,75 @@
|
|||||||
//! - English <-> Russian
|
//! - English <-> Russian
|
||||||
//! - French <-> German
|
//! - French <-> German
|
||||||
//! ```no_run
|
//! ```no_run
|
||||||
//!# fn main() -> failure::Fallible<()> {
|
//! # fn main() -> failure::Fallible<()> {
|
||||||
//!# use rust_bert::pipelines::generation::LanguageGenerator;
|
//! # use rust_bert::pipelines::generation::LanguageGenerator;
|
||||||
//! use rust_bert::pipelines::translation::{TranslationModel, TranslationConfig, Language};
|
//! use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
|
||||||
//! use tch::Device;
|
//! use tch::Device;
|
||||||
//! let translation_config = TranslationConfig::new(Language::EnglishToFrench, Device::cuda_if_available());
|
//! let translation_config =
|
||||||
|
//! TranslationConfig::new(Language::EnglishToFrench, Device::cuda_if_available());
|
||||||
//! let mut model = TranslationModel::new(translation_config)?;
|
//! let mut model = TranslationModel::new(translation_config)?;
|
||||||
//!
|
//!
|
||||||
//! let input = ["This is a sentence to be translated"];
|
//! let input = ["This is a sentence to be translated"];
|
||||||
//!
|
//!
|
||||||
//! let output = model.translate(&input);
|
//! let output = model.translate(&input);
|
||||||
//!# Ok(())
|
//! # Ok(())
|
||||||
//!# }
|
//! # }
|
||||||
//! ```
|
//! ```
|
||||||
//!
|
//!
|
||||||
//! Output: \
|
//! Output: \
|
||||||
//! ```no_run
|
//! ```no_run
|
||||||
//!# let output =
|
//! # let output =
|
||||||
//! "Il s'agit d'une phrase à traduire"
|
//! "Il s'agit d'une phrase à traduire"
|
||||||
//!# ;
|
//! # ;
|
||||||
//!```
|
//! ```
|
||||||
//!
|
//!
|
||||||
//! #### 3. Summarization
|
//! #### 3. Summarization
|
||||||
//! Abstractive summarization of texts based on the BART encoder-decoder architecture
|
//! Abstractive summarization of texts based on the BART encoder-decoder architecture
|
||||||
//! Include techniques such as beam search, top-k and nucleus sampling, temperature setting and repetition penalty.
|
//! Include techniques such as beam search, top-k and nucleus sampling, temperature setting and repetition penalty.
|
||||||
//!
|
//!
|
||||||
//! ```no_run
|
//! ```no_run
|
||||||
//!# fn main() -> failure::Fallible<()> {
|
//! # fn main() -> failure::Fallible<()> {
|
||||||
//!# use rust_bert::pipelines::generation::LanguageGenerator;
|
//! # use rust_bert::pipelines::generation::LanguageGenerator;
|
||||||
//! use rust_bert::pipelines::summarization::SummarizationModel;
|
//! use rust_bert::pipelines::summarization::SummarizationModel;
|
||||||
//!
|
//!
|
||||||
//! let mut model = SummarizationModel::new(Default::default())?;
|
//! let mut model = SummarizationModel::new(Default::default())?;
|
||||||
//!
|
//!
|
||||||
//! let input = ["In findings published Tuesday in Cornell University's arXiv by a team of scientists
|
//! let input = ["In findings published Tuesday in Cornell University's arXiv by a team of scientists
|
||||||
//!from the University of Montreal and a separate report published Wednesday in Nature Astronomy by a team
|
//! from the University of Montreal and a separate report published Wednesday in Nature Astronomy by a team
|
||||||
//!from University College London (UCL), the presence of water vapour was confirmed in the atmosphere of K2-18b,
|
//! from University College London (UCL), the presence of water vapour was confirmed in the atmosphere of K2-18b,
|
||||||
//!a planet circling a star in the constellation Leo. This is the first such discovery in a planet in its star's
|
//! a planet circling a star in the constellation Leo. This is the first such discovery in a planet in its star's
|
||||||
//!habitable zone — not too hot and not too cold for liquid water to exist. The Montreal team, led by Björn Benneke,
|
//! habitable zone — not too hot and not too cold for liquid water to exist. The Montreal team, led by Björn Benneke,
|
||||||
//!used data from the NASA's Hubble telescope to assess changes in the light coming from K2-18b's star as the planet
|
//! used data from the NASA's Hubble telescope to assess changes in the light coming from K2-18b's star as the planet
|
||||||
//!passed between it and Earth. They found that certain wavelengths of light, which are usually absorbed by water,
|
//! passed between it and Earth. They found that certain wavelengths of light, which are usually absorbed by water,
|
||||||
//!weakened when the planet was in the way, indicating not only does K2-18b have an atmosphere, but the atmosphere
|
//! weakened when the planet was in the way, indicating not only does K2-18b have an atmosphere, but the atmosphere
|
||||||
//!contains water in vapour form. The team from UCL then analyzed the Montreal team's data using their own software
|
//! contains water in vapour form. The team from UCL then analyzed the Montreal team's data using their own software
|
||||||
//!and confirmed their conclusion. This was not the first time scientists have found signs of water on an exoplanet,
|
//! and confirmed their conclusion. This was not the first time scientists have found signs of water on an exoplanet,
|
||||||
//!but previous discoveries were made on planets with high temperatures or other pronounced differences from Earth.
|
//! but previous discoveries were made on planets with high temperatures or other pronounced differences from Earth.
|
||||||
//!\"This is the first potentially habitable planet where the temperature is right and where we now know there is water,\"
|
//! \"This is the first potentially habitable planet where the temperature is right and where we now know there is water,\"
|
||||||
//!said UCL astronomer Angelos Tsiaras. \"It's the best candidate for habitability right now.\" \"It's a good sign\",
|
//! said UCL astronomer Angelos Tsiaras. \"It's the best candidate for habitability right now.\" \"It's a good sign\",
|
||||||
//!said Ryan Cloutier of the Harvard–Smithsonian Center for Astrophysics, who was not one of either study's authors.
|
//! said Ryan Cloutier of the Harvard–Smithsonian Center for Astrophysics, who was not one of either study's authors.
|
||||||
//!\"Overall,\" he continued, \"the presence of water in its atmosphere certainly improves the prospect of K2-18b being
|
//! \"Overall,\" he continued, \"the presence of water in its atmosphere certainly improves the prospect of K2-18b being
|
||||||
//!a potentially habitable planet, but further observations will be required to say for sure. \"
|
//! a potentially habitable planet, but further observations will be required to say for sure. \"
|
||||||
//!K2-18b was first identified in 2015 by the Kepler space telescope. It is about 110 light-years from Earth and larger
|
//! K2-18b was first identified in 2015 by the Kepler space telescope. It is about 110 light-years from Earth and larger
|
||||||
//!but less dense. Its star, a red dwarf, is cooler than the Sun, but the planet's orbit is much closer, such that a year
|
//! but less dense. Its star, a red dwarf, is cooler than the Sun, but the planet's orbit is much closer, such that a year
|
||||||
//!on K2-18b lasts 33 Earth days. According to The Guardian, astronomers were optimistic that NASA's James Webb space
|
//! on K2-18b lasts 33 Earth days. According to The Guardian, astronomers were optimistic that NASA's James Webb space
|
||||||
//!telescope — scheduled for launch in 2021 — and the European Space Agency's 2028 ARIEL program, could reveal more
|
//! telescope — scheduled for launch in 2021 — and the European Space Agency's 2028 ARIEL program, could reveal more
|
||||||
//!about exoplanets like K2-18b."];
|
//! about exoplanets like K2-18b."];
|
||||||
//!
|
//!
|
||||||
//! let output = model.summarize(&input);
|
//! let output = model.summarize(&input);
|
||||||
//!# Ok(())
|
//! # Ok(())
|
||||||
//!# }
|
//! # }
|
||||||
//! ```
|
//! ```
|
||||||
//! (example from: [WikiNews](https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b))
|
//! (example from: [WikiNews](https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b))
|
||||||
//!
|
//!
|
||||||
//! Example output: \
|
//! Example output: \
|
||||||
//! ```no_run
|
//! ```no_run
|
||||||
//!# let output =
|
//! # let output =
|
||||||
//! "Scientists have found water vapour on K2-18b, a planet 110 light-years from Earth.
|
//! "Scientists have found water vapour on K2-18b, a planet 110 light-years from Earth.
|
||||||
//! This is the first such discovery in a planet in its star's habitable zone.
|
//! This is the first such discovery in a planet in its star's habitable zone.
|
||||||
//! The planet is not too hot and not too cold for liquid water to exist."
|
//! The planet is not too hot and not too cold for liquid water to exist."
|
||||||
//!# ;
|
//! # ;
|
||||||
//!```
|
//! ```
|
||||||
//!
|
//!
|
||||||
//!
|
//!
|
||||||
//! #### 4. Natural Language Generation
|
//! #### 4. Natural Language Generation
|
||||||
@ -124,18 +122,18 @@
|
|||||||
//!
|
//!
|
||||||
//! ```no_run
|
//! ```no_run
|
||||||
//! use rust_bert::pipelines::generation::GPT2Generator;
|
//! use rust_bert::pipelines::generation::GPT2Generator;
|
||||||
//!# fn main() -> failure::Fallible<()> {
|
//! # fn main() -> failure::Fallible<()> {
|
||||||
//!# use rust_bert::pipelines::generation::LanguageGenerator;
|
//! # use rust_bert::pipelines::generation::LanguageGenerator;
|
||||||
//! let mut model = GPT2Generator::new(Default::default())?;
|
//! let mut model = GPT2Generator::new(Default::default())?;
|
||||||
//! let input_context_1 = "The dog";
|
//! let input_context_1 = "The dog";
|
||||||
//! let input_context_2 = "The cat was";
|
//! let input_context_2 = "The cat was";
|
||||||
//! let output = model.generate(Some(vec!(input_context_1, input_context_2)), None);
|
//! let output = model.generate(Some(vec![input_context_1, input_context_2]), None);
|
||||||
//!# Ok(())
|
//! # Ok(())
|
||||||
//!# }
|
//! # }
|
||||||
//! ```
|
//! ```
|
||||||
//! Example output: \
|
//! Example output: \
|
||||||
//! ```no_run
|
//! ```no_run
|
||||||
//!# let output =
|
//! # let output =
|
||||||
//! [
|
//! [
|
||||||
//! "The dog's owners, however, did not want to be named. According to the lawsuit, the animal's owner, a 29-year",
|
//! "The dog's owners, however, did not want to be named. According to the lawsuit, the animal's owner, a 29-year",
|
||||||
//! "The dog has always been part of the family. \"He was always going to be my dog and he was always looking out for me",
|
//! "The dog has always been part of the family. \"He was always going to be my dog and he was always looking out for me",
|
||||||
@ -144,14 +142,14 @@
|
|||||||
//! "The cat was pulled from the street by two-year-old Jazmine.\"I didn't know what to do,\" she said",
|
//! "The cat was pulled from the street by two-year-old Jazmine.\"I didn't know what to do,\" she said",
|
||||||
//! "The cat was attacked by two stray dogs and was taken to a hospital. Two other cats were also injured in the attack and are being treated."
|
//! "The cat was attacked by two stray dogs and was taken to a hospital. Two other cats were also injured in the attack and are being treated."
|
||||||
//! ]
|
//! ]
|
||||||
//!# ;
|
//! # ;
|
||||||
//!```
|
//! ```
|
||||||
//!
|
//!
|
||||||
//! #### 5. Sentiment analysis
|
//! #### 5. Sentiment analysis
|
||||||
//! Predicts the binary sentiment for a sentence. DistilBERT model finetuned on SST-2.
|
//! Predicts the binary sentiment for a sentence. DistilBERT model finetuned on SST-2.
|
||||||
//! ```no_run
|
//! ```no_run
|
||||||
//! use rust_bert::pipelines::sentiment::SentimentModel;
|
//! use rust_bert::pipelines::sentiment::SentimentModel;
|
||||||
//!# fn main() -> failure::Fallible<()> {
|
//! # fn main() -> failure::Fallible<()> {
|
||||||
//! let sentiment_model = SentimentModel::new(Default::default())?;
|
//! let sentiment_model = SentimentModel::new(Default::default())?;
|
||||||
//! let input = [
|
//! let input = [
|
||||||
//! "Probably my all-time favorite movie, a story of selflessness, sacrifice and dedication to a noble cause, but it's not preachy or boring.",
|
//! "Probably my all-time favorite movie, a story of selflessness, sacrifice and dedication to a noble cause, but it's not preachy or boring.",
|
||||||
@ -159,59 +157,84 @@
|
|||||||
//! "If you like original gut wrenching laughter you will like this movie. If you are young or old then you will love this movie, hell even my mom liked it.",
|
//! "If you like original gut wrenching laughter you will like this movie. If you are young or old then you will love this movie, hell even my mom liked it.",
|
||||||
//! ];
|
//! ];
|
||||||
//! let output = sentiment_model.predict(&input);
|
//! let output = sentiment_model.predict(&input);
|
||||||
//!# Ok(())
|
//! # Ok(())
|
||||||
//!# }
|
//! # }
|
||||||
//! ```
|
//! ```
|
||||||
//! (Example courtesy of [IMDb](http://www.imdb.com))
|
//! (Example courtesy of [IMDb](http://www.imdb.com))
|
||||||
//!
|
//!
|
||||||
//! Output: \
|
//! Output: \
|
||||||
//! ```no_run
|
//! ```no_run
|
||||||
//!# use rust_bert::pipelines::sentiment::Sentiment;
|
//! # use rust_bert::pipelines::sentiment::Sentiment;
|
||||||
//!# use rust_bert::pipelines::sentiment::SentimentPolarity::{Positive, Negative};
|
//! # use rust_bert::pipelines::sentiment::SentimentPolarity::{Positive, Negative};
|
||||||
//!# let output =
|
//! # let output =
|
||||||
//! [
|
//! [
|
||||||
//! Sentiment { polarity: Positive, score: 0.998 },
|
//! Sentiment {
|
||||||
//! Sentiment { polarity: Negative, score: 0.992 },
|
//! polarity: Positive,
|
||||||
//! Sentiment { polarity: Positive, score: 0.999 }
|
//! score: 0.998,
|
||||||
|
//! },
|
||||||
|
//! Sentiment {
|
||||||
|
//! polarity: Negative,
|
||||||
|
//! score: 0.992,
|
||||||
|
//! },
|
||||||
|
//! Sentiment {
|
||||||
|
//! polarity: Positive,
|
||||||
|
//! score: 0.999,
|
||||||
|
//! },
|
||||||
//! ]
|
//! ]
|
||||||
//!# ;
|
//! # ;
|
||||||
//! ```
|
//! ```
|
||||||
//!
|
//!
|
||||||
//! #### 6. Named Entity Recognition
|
//! #### 6. Named Entity Recognition
|
||||||
//! Extracts entities (Person, Location, Organization, Miscellaneous) from text. BERT cased large model finetuned on CoNNL03, contributed by the [MDZ Digital Library team at the Bavarian State Library](https://github.com/dbmdz)
|
//! Extracts entities (Person, Location, Organization, Miscellaneous) from text. BERT cased large model finetuned on CoNNL03, contributed by the [MDZ Digital Library team at the Bavarian State Library](https://github.com/dbmdz)
|
||||||
//! ```no_run
|
//! ```no_run
|
||||||
//! use rust_bert::pipelines::ner::NERModel;
|
//! use rust_bert::pipelines::ner::NERModel;
|
||||||
//!# fn main() -> failure::Fallible<()> {
|
//! # fn main() -> failure::Fallible<()> {
|
||||||
//! let ner_model = NERModel::new(Default::default())?;
|
//! let ner_model = NERModel::new(Default::default())?;
|
||||||
//! let input = [
|
//! let input = [
|
||||||
//! "My name is Amy. I live in Paris.",
|
//! "My name is Amy. I live in Paris.",
|
||||||
//! "Paris is a city in France."
|
//! "Paris is a city in France.",
|
||||||
//! ];
|
//! ];
|
||||||
//! let output = ner_model.predict(&input);
|
//! let output = ner_model.predict(&input);
|
||||||
//!# Ok(())
|
//! # Ok(())
|
||||||
//!# }
|
//! # }
|
||||||
//! ```
|
//! ```
|
||||||
//! Output: \
|
//! Output: \
|
||||||
//! ```no_run
|
//! ```no_run
|
||||||
//!# use rust_bert::pipelines::question_answering::Answer;
|
//! # use rust_bert::pipelines::question_answering::Answer;
|
||||||
//!# use rust_bert::pipelines::ner::Entity;
|
//! # use rust_bert::pipelines::ner::Entity;
|
||||||
//!# let output =
|
//! # let output =
|
||||||
//! [
|
//! [
|
||||||
//! Entity { word: String::from("Amy"), score: 0.9986, label: String::from("I-PER") },
|
//! Entity {
|
||||||
//! Entity { word: String::from("Paris"), score: 0.9985, label: String::from("I-LOC") },
|
//! word: String::from("Amy"),
|
||||||
//! Entity { word: String::from("Paris"), score: 0.9988, label: String::from("I-LOC") },
|
//! score: 0.9986,
|
||||||
//! Entity { word: String::from("France"), score: 0.9993, label: String::from("I-LOC") },
|
//! label: String::from("I-PER"),
|
||||||
|
//! },
|
||||||
|
//! Entity {
|
||||||
|
//! word: String::from("Paris"),
|
||||||
|
//! score: 0.9985,
|
||||||
|
//! label: String::from("I-LOC"),
|
||||||
|
//! },
|
||||||
|
//! Entity {
|
||||||
|
//! word: String::from("Paris"),
|
||||||
|
//! score: 0.9988,
|
||||||
|
//! label: String::from("I-LOC"),
|
||||||
|
//! },
|
||||||
|
//! Entity {
|
||||||
|
//! word: String::from("France"),
|
||||||
|
//! score: 0.9993,
|
||||||
|
//! label: String::from("I-LOC"),
|
||||||
|
//! },
|
||||||
//! ]
|
//! ]
|
||||||
//!# ;
|
//! # ;
|
||||||
//! ```
|
//! ```
|
||||||
//!
|
//!
|
||||||
|
|
||||||
pub mod sentiment;
|
|
||||||
pub mod common;
|
pub mod common;
|
||||||
pub mod token_classification;
|
pub mod generation;
|
||||||
pub mod sequence_classification;
|
|
||||||
pub mod ner;
|
pub mod ner;
|
||||||
pub mod question_answering;
|
pub mod question_answering;
|
||||||
pub mod generation;
|
pub mod sentiment;
|
||||||
|
pub mod sequence_classification;
|
||||||
pub mod summarization;
|
pub mod summarization;
|
||||||
|
pub mod token_classification;
|
||||||
pub mod translation;
|
pub mod translation;
|
||||||
|
@ -20,33 +20,48 @@
|
|||||||
//!
|
//!
|
||||||
//! ```no_run
|
//! ```no_run
|
||||||
//! use rust_bert::pipelines::ner::NERModel;
|
//! use rust_bert::pipelines::ner::NERModel;
|
||||||
//!# fn main() -> failure::Fallible<()> {
|
//! # fn main() -> failure::Fallible<()> {
|
||||||
//! let ner_model = NERModel::new(Default::default())?;
|
//! let ner_model = NERModel::new(Default::default())?;
|
||||||
//!
|
//!
|
||||||
//! let input = [
|
//! let input = [
|
||||||
//! "My name is Amy. I live in Paris.",
|
//! "My name is Amy. I live in Paris.",
|
||||||
//! "Paris is a city in France."
|
//! "Paris is a city in France.",
|
||||||
//! ];
|
//! ];
|
||||||
//! let output = ner_model.predict(&input);
|
//! let output = ner_model.predict(&input);
|
||||||
//!# Ok(())
|
//! # Ok(())
|
||||||
//!# }
|
//! # }
|
||||||
//! ```
|
//! ```
|
||||||
//! Output: \
|
//! Output: \
|
||||||
//! ```no_run
|
//! ```no_run
|
||||||
//!# use rust_bert::pipelines::question_answering::Answer;
|
//! # use rust_bert::pipelines::question_answering::Answer;
|
||||||
//!# use rust_bert::pipelines::ner::Entity;
|
//! # use rust_bert::pipelines::ner::Entity;
|
||||||
//!# let output =
|
//! # let output =
|
||||||
//! [
|
//! [
|
||||||
//! Entity { word: String::from("Amy"), score: 0.9986, label: String::from("I-PER") },
|
//! Entity {
|
||||||
//! Entity { word: String::from("Paris"), score: 0.9985, label: String::from("I-LOC") },
|
//! word: String::from("Amy"),
|
||||||
//! Entity { word: String::from("Paris"), score: 0.9988, label: String::from("I-LOC") },
|
//! score: 0.9986,
|
||||||
//! Entity { word: String::from("France"), score: 0.9993, label: String::from("I-LOC") },
|
//! label: String::from("I-PER"),
|
||||||
|
//! },
|
||||||
|
//! Entity {
|
||||||
|
//! word: String::from("Paris"),
|
||||||
|
//! score: 0.9985,
|
||||||
|
//! label: String::from("I-LOC"),
|
||||||
|
//! },
|
||||||
|
//! Entity {
|
||||||
|
//! word: String::from("Paris"),
|
||||||
|
//! score: 0.9988,
|
||||||
|
//! label: String::from("I-LOC"),
|
||||||
|
//! },
|
||||||
|
//! Entity {
|
||||||
|
//! word: String::from("France"),
|
||||||
|
//! score: 0.9993,
|
||||||
|
//! label: String::from("I-LOC"),
|
||||||
|
//! },
|
||||||
//! ]
|
//! ]
|
||||||
//!# ;
|
//! # ;
|
||||||
//! ```
|
//! ```
|
||||||
|
|
||||||
use crate::pipelines::token_classification::{TokenClassificationModel, TokenClassificationConfig};
|
use crate::pipelines::token_classification::{TokenClassificationConfig, TokenClassificationModel};
|
||||||
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
/// # Entity generated by a `NERModel`
|
/// # Entity generated by a `NERModel`
|
||||||
@ -64,7 +79,7 @@ type NERConfig = TokenClassificationConfig;
|
|||||||
|
|
||||||
/// # NERModel to extract named entities
|
/// # NERModel to extract named entities
|
||||||
pub struct NERModel {
|
pub struct NERModel {
|
||||||
token_classification_model: TokenClassificationModel
|
token_classification_model: TokenClassificationModel,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl NERModel {
|
impl NERModel {
|
||||||
@ -77,17 +92,18 @@ impl NERModel {
|
|||||||
/// # Example
|
/// # Example
|
||||||
///
|
///
|
||||||
/// ```no_run
|
/// ```no_run
|
||||||
///# fn main() -> failure::Fallible<()> {
|
/// # fn main() -> failure::Fallible<()> {
|
||||||
/// use rust_bert::pipelines::ner::NERModel;
|
/// use rust_bert::pipelines::ner::NERModel;
|
||||||
///
|
///
|
||||||
/// let ner_model = NERModel::new(Default::default())?;
|
/// let ner_model = NERModel::new(Default::default())?;
|
||||||
///# Ok(())
|
/// # Ok(())
|
||||||
///# }
|
/// # }
|
||||||
/// ```
|
/// ```
|
||||||
///
|
|
||||||
pub fn new(ner_config: NERConfig) -> failure::Fallible<NERModel> {
|
pub fn new(ner_config: NERConfig) -> failure::Fallible<NERModel> {
|
||||||
let model = TokenClassificationModel::new(ner_config)?;
|
let model = TokenClassificationModel::new(ner_config)?;
|
||||||
Ok(NERModel { token_classification_model: model })
|
Ok(NERModel {
|
||||||
|
token_classification_model: model,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Extract entities from a text
|
/// Extract entities from a text
|
||||||
@ -103,30 +119,28 @@ impl NERModel {
|
|||||||
/// # Example
|
/// # Example
|
||||||
///
|
///
|
||||||
/// ```no_run
|
/// ```no_run
|
||||||
///# fn main() -> failure::Fallible<()> {
|
/// # fn main() -> failure::Fallible<()> {
|
||||||
///# use rust_bert::pipelines::ner::NERModel;
|
/// # use rust_bert::pipelines::ner::NERModel;
|
||||||
///
|
///
|
||||||
/// let ner_model = NERModel::new(Default::default())?;
|
/// let ner_model = NERModel::new(Default::default())?;
|
||||||
/// let input = [
|
/// let input = [
|
||||||
/// "My name is Amy. I live in Paris.",
|
/// "My name is Amy. I live in Paris.",
|
||||||
/// "Paris is a city in France."
|
/// "Paris is a city in France.",
|
||||||
/// ];
|
/// ];
|
||||||
/// let output = ner_model.predict(&input);
|
/// let output = ner_model.predict(&input);
|
||||||
///# Ok(())
|
/// # Ok(())
|
||||||
///# }
|
/// # }
|
||||||
/// ```
|
/// ```
|
||||||
///
|
|
||||||
pub fn predict(&self, input: &[&str]) -> Vec<Entity> {
|
pub fn predict(&self, input: &[&str]) -> Vec<Entity> {
|
||||||
self.token_classification_model
|
self.token_classification_model
|
||||||
.predict(input, true, false)
|
.predict(input, true, false)
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.filter(|token| token.label != "O")
|
.filter(|token| token.label != "O")
|
||||||
.map(|token| {
|
.map(|token| Entity {
|
||||||
Entity {
|
word: token.text,
|
||||||
word: token.text,
|
score: token.score,
|
||||||
score: token.score,
|
label: token.label,
|
||||||
label: token.label,
|
})
|
||||||
}
|
.collect()
|
||||||
}).collect()
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -17,48 +17,48 @@
|
|||||||
//! The dependencies will be downloaded to the user's home directory, under ~/.cache/.rustbert/distilbert-qa
|
//! The dependencies will be downloaded to the user's home directory, under ~/.cache/.rustbert/distilbert-qa
|
||||||
//!
|
//!
|
||||||
//! ```no_run
|
//! ```no_run
|
||||||
//! use rust_bert::pipelines::question_answering::{QuestionAnsweringModel, QaInput};
|
//! use rust_bert::pipelines::question_answering::{QaInput, QuestionAnsweringModel};
|
||||||
//!
|
//!
|
||||||
//!# fn main() -> failure::Fallible<()> {
|
//! # fn main() -> failure::Fallible<()> {
|
||||||
//! let qa_model = QuestionAnsweringModel::new(Default::default())?;
|
//! let qa_model = QuestionAnsweringModel::new(Default::default())?;
|
||||||
//!
|
//!
|
||||||
//! let question = String::from("Where does Amy live ?");
|
//! let question = String::from("Where does Amy live ?");
|
||||||
//! let context = String::from("Amy lives in Amsterdam");
|
//! let context = String::from("Amy lives in Amsterdam");
|
||||||
//!
|
//!
|
||||||
//! let answers = qa_model.predict(&vec!(QaInput { question, context }), 1, 32);
|
//! let answers = qa_model.predict(&vec![QaInput { question, context }], 1, 32);
|
||||||
//!# Ok(())
|
//! # Ok(())
|
||||||
//!# }
|
//! # }
|
||||||
//! ```
|
//! ```
|
||||||
//!
|
//!
|
||||||
//! Output: \
|
//! Output: \
|
||||||
//! ```no_run
|
//! ```no_run
|
||||||
//!# use rust_bert::pipelines::question_answering::Answer;
|
//! # use rust_bert::pipelines::question_answering::Answer;
|
||||||
//!# let output =
|
//! # let output =
|
||||||
//! [
|
//! [Answer {
|
||||||
//! Answer {
|
//! score: 0.9976,
|
||||||
//! score: 0.9976,
|
//! start: 13,
|
||||||
//! start: 13,
|
//! end: 21,
|
||||||
//! end: 21,
|
//! answer: "Amsterdam", //#### # .to_owned()
|
||||||
//! answer: "Amsterdam"
|
//! }]
|
||||||
//!# .to_owned()
|
//! # ;
|
||||||
//! }
|
|
||||||
//! ]
|
|
||||||
//!# ;
|
|
||||||
//! ```
|
//! ```
|
||||||
|
|
||||||
use rust_tokenizers::{BertTokenizer, Tokenizer, TruncationStrategy, TokenizedInput};
|
use crate::common::resources::{download_resource, RemoteResource, Resource};
|
||||||
use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::Mask;
|
use crate::distilbert::{
|
||||||
use tch::{Device, Tensor, no_grad};
|
DistilBertConfig, DistilBertConfigResources, DistilBertForQuestionAnswering,
|
||||||
use std::path::PathBuf;
|
DistilBertModelResources, DistilBertVocabResources,
|
||||||
use rust_tokenizers::tokenization_utils::truncate_sequences;
|
};
|
||||||
use std::collections::HashMap;
|
|
||||||
use std::cmp::min;
|
|
||||||
use tch::nn::VarStore;
|
|
||||||
use tch::kind::Kind::Float;
|
|
||||||
use std::fs;
|
|
||||||
use crate::Config;
|
use crate::Config;
|
||||||
use crate::distilbert::{DistilBertForQuestionAnswering, DistilBertConfig, DistilBertConfigResources, DistilBertModelResources, DistilBertVocabResources};
|
use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::Mask;
|
||||||
use crate::common::resources::{Resource, RemoteResource, download_resource};
|
use rust_tokenizers::tokenization_utils::truncate_sequences;
|
||||||
|
use rust_tokenizers::{BertTokenizer, TokenizedInput, Tokenizer, TruncationStrategy};
|
||||||
|
use std::cmp::min;
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::fs;
|
||||||
|
use std::path::PathBuf;
|
||||||
|
use tch::kind::Kind::Float;
|
||||||
|
use tch::nn::VarStore;
|
||||||
|
use tch::{no_grad, Device, Tensor};
|
||||||
|
|
||||||
/// # Input for Question Answering
|
/// # Input for Question Answering
|
||||||
/// Includes a context (containing the answer) and question strings
|
/// Includes a context (containing the answer) and question strings
|
||||||
@ -84,7 +84,6 @@ struct QaFeature {
|
|||||||
pub token_to_orig_map: HashMap<i64, i64>,
|
pub token_to_orig_map: HashMap<i64, i64>,
|
||||||
pub p_mask: Vec<i8>,
|
pub p_mask: Vec<i8>,
|
||||||
pub example_index: i64,
|
pub example_index: i64,
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
@ -102,34 +101,38 @@ pub struct Answer {
|
|||||||
|
|
||||||
impl PartialEq for Answer {
|
impl PartialEq for Answer {
|
||||||
fn eq(&self, other: &Self) -> bool {
|
fn eq(&self, other: &Self) -> bool {
|
||||||
(self.start == other.start) &&
|
(self.start == other.start) && (self.end == other.end) && (self.answer == other.answer)
|
||||||
(self.end == other.end) &&
|
|
||||||
(self.answer == other.answer)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn remove_duplicates<T: PartialEq + Clone>(vector: &mut Vec<T>) -> &mut Vec<T> {
|
fn remove_duplicates<T: PartialEq + Clone>(vector: &mut Vec<T>) -> &mut Vec<T> {
|
||||||
let mut potential_duplicates = vec!();
|
let mut potential_duplicates = vec![];
|
||||||
vector.retain(|item| if potential_duplicates.contains(item) {
|
vector.retain(|item| {
|
||||||
false
|
if potential_duplicates.contains(item) {
|
||||||
} else {
|
false
|
||||||
potential_duplicates.push(item.clone());
|
} else {
|
||||||
true
|
potential_duplicates.push(item.clone());
|
||||||
|
true
|
||||||
|
}
|
||||||
});
|
});
|
||||||
vector
|
vector
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
impl QaExample {
|
impl QaExample {
|
||||||
pub fn new(question: &str, context: &str) -> QaExample {
|
pub fn new(question: &str, context: &str) -> QaExample {
|
||||||
let question = question.to_owned();
|
let question = question.to_owned();
|
||||||
let (doc_tokens, char_to_word_offset) = QaExample::split_context(context);
|
let (doc_tokens, char_to_word_offset) = QaExample::split_context(context);
|
||||||
QaExample { question, context: context.to_owned(), doc_tokens, char_to_word_offset }
|
QaExample {
|
||||||
|
question,
|
||||||
|
context: context.to_owned(),
|
||||||
|
doc_tokens,
|
||||||
|
char_to_word_offset,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn split_context(context: &str) -> (Vec<String>, Vec<i64>) {
|
fn split_context(context: &str) -> (Vec<String>, Vec<i64>) {
|
||||||
let mut doc_tokens: Vec<String> = vec!();
|
let mut doc_tokens: Vec<String> = vec![];
|
||||||
let mut char_to_word_offset: Vec<i64> = vec!();
|
let mut char_to_word_offset: Vec<i64> = vec![];
|
||||||
let max_length = context.len();
|
let max_length = context.len();
|
||||||
let mut current_word = String::with_capacity(max_length);
|
let mut current_word = String::with_capacity(max_length);
|
||||||
let mut previous_whitespace = false;
|
let mut previous_whitespace = false;
|
||||||
@ -158,11 +161,11 @@ impl QaExample {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn is_whitespace(character: &char) -> bool {
|
fn is_whitespace(character: &char) -> bool {
|
||||||
(character == &' ') |
|
(character == &' ')
|
||||||
(character == &'\t') |
|
| (character == &'\t')
|
||||||
(character == &'\r') |
|
| (character == &'\r')
|
||||||
(character == &'\n') |
|
| (character == &'\n')
|
||||||
(*character as u32 == 0x202F)
|
| (*character as u32 == 0x202F)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -182,9 +185,15 @@ pub struct QuestionAnsweringConfig {
|
|||||||
impl Default for QuestionAnsweringConfig {
|
impl Default for QuestionAnsweringConfig {
|
||||||
fn default() -> QuestionAnsweringConfig {
|
fn default() -> QuestionAnsweringConfig {
|
||||||
QuestionAnsweringConfig {
|
QuestionAnsweringConfig {
|
||||||
model_resource: Resource::Remote(RemoteResource::from_pretrained(DistilBertModelResources::DISTIL_BERT_SQUAD)),
|
model_resource: Resource::Remote(RemoteResource::from_pretrained(
|
||||||
config_resource: Resource::Remote(RemoteResource::from_pretrained(DistilBertConfigResources::DISTIL_BERT_SQUAD)),
|
DistilBertModelResources::DISTIL_BERT_SQUAD,
|
||||||
vocab_resource: Resource::Remote(RemoteResource::from_pretrained(DistilBertVocabResources::DISTIL_BERT_SQUAD)),
|
)),
|
||||||
|
config_resource: Resource::Remote(RemoteResource::from_pretrained(
|
||||||
|
DistilBertConfigResources::DISTIL_BERT_SQUAD,
|
||||||
|
)),
|
||||||
|
vocab_resource: Resource::Remote(RemoteResource::from_pretrained(
|
||||||
|
DistilBertVocabResources::DISTIL_BERT_SQUAD,
|
||||||
|
)),
|
||||||
device: Device::cuda_if_available(),
|
device: Device::cuda_if_available(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -213,26 +222,33 @@ impl QuestionAnsweringModel {
|
|||||||
/// # Example
|
/// # Example
|
||||||
///
|
///
|
||||||
/// ```no_run
|
/// ```no_run
|
||||||
///# fn main() -> failure::Fallible<()> {
|
/// # fn main() -> failure::Fallible<()> {
|
||||||
/// use rust_bert::pipelines::question_answering::QuestionAnsweringModel;
|
/// use rust_bert::pipelines::question_answering::QuestionAnsweringModel;
|
||||||
///
|
///
|
||||||
/// let qa_model = QuestionAnsweringModel::new(Default::default())?;
|
/// let qa_model = QuestionAnsweringModel::new(Default::default())?;
|
||||||
///# Ok(())
|
/// # Ok(())
|
||||||
///# }
|
/// # }
|
||||||
/// ```
|
/// ```
|
||||||
///
|
pub fn new(
|
||||||
pub fn new(question_answering_config: QuestionAnsweringConfig) -> failure::Fallible<QuestionAnsweringModel> {
|
question_answering_config: QuestionAnsweringConfig,
|
||||||
|
) -> failure::Fallible<QuestionAnsweringModel> {
|
||||||
let config_path = download_resource(&question_answering_config.config_resource)?;
|
let config_path = download_resource(&question_answering_config.config_resource)?;
|
||||||
let vocab_path = download_resource(&question_answering_config.vocab_resource)?;
|
let vocab_path = download_resource(&question_answering_config.vocab_resource)?;
|
||||||
let weights_path = download_resource(&question_answering_config.model_resource)?;
|
let weights_path = download_resource(&question_answering_config.model_resource)?;
|
||||||
let device = question_answering_config.device;
|
let device = question_answering_config.device;
|
||||||
|
|
||||||
let tokenizer = BertTokenizer::from_file(vocab_path.to_str().unwrap(), false);
|
let tokenizer = BertTokenizer::from_file(vocab_path.to_str().unwrap(), false);
|
||||||
let pad_idx = *Tokenizer::vocab(&tokenizer).special_values.get("[PAD]").expect("[PAD] token not found in vocabulary");
|
let pad_idx = *Tokenizer::vocab(&tokenizer)
|
||||||
let sep_idx = *Tokenizer::vocab(&tokenizer).special_values.get("[SEP]").expect("[SEP] token not found in vocabulary");
|
.special_values
|
||||||
|
.get("[PAD]")
|
||||||
|
.expect("[PAD] token not found in vocabulary");
|
||||||
|
let sep_idx = *Tokenizer::vocab(&tokenizer)
|
||||||
|
.special_values
|
||||||
|
.get("[SEP]")
|
||||||
|
.expect("[SEP] token not found in vocabulary");
|
||||||
let mut var_store = VarStore::new(device);
|
let mut var_store = VarStore::new(device);
|
||||||
let mut config = DistilBertConfig::from_file(config_path);
|
let mut config = DistilBertConfig::from_file(config_path);
|
||||||
// The config for the current pre-trained question answering model indicates position embeddings which does not seem accurate
|
// The config for the current pre-trained question answering model indicates position embeddings which does not seem accurate
|
||||||
config.sinusoidal_pos_embds = false;
|
config.sinusoidal_pos_embds = false;
|
||||||
let distilbert_qa = DistilBertForQuestionAnswering::new(&var_store.root(), &config);
|
let distilbert_qa = DistilBertForQuestionAnswering::new(&var_store.root(), &config);
|
||||||
var_store.load(weights_path)?;
|
var_store.load(weights_path)?;
|
||||||
@ -249,7 +265,6 @@ impl QuestionAnsweringModel {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/// Perform extractive question answering given a list of `QaInputs`
|
/// Perform extractive question answering given a list of `QaInputs`
|
||||||
///
|
///
|
||||||
/// # Arguments
|
/// # Arguments
|
||||||
@ -264,25 +279,35 @@ impl QuestionAnsweringModel {
|
|||||||
/// # Example
|
/// # Example
|
||||||
///
|
///
|
||||||
/// ```no_run
|
/// ```no_run
|
||||||
///# fn main() -> failure::Fallible<()> {
|
/// # fn main() -> failure::Fallible<()> {
|
||||||
/// use rust_bert::pipelines::question_answering::{QuestionAnsweringModel, QaInput};
|
/// use rust_bert::pipelines::question_answering::{QaInput, QuestionAnsweringModel};
|
||||||
///
|
///
|
||||||
/// let qa_model = QuestionAnsweringModel::new(Default::default())?;
|
/// let qa_model = QuestionAnsweringModel::new(Default::default())?;
|
||||||
///
|
///
|
||||||
/// let question_1 = String::from("Where does Amy live ?");
|
/// let question_1 = String::from("Where does Amy live ?");
|
||||||
/// let context_1 = String::from("Amy lives in Amsterdam");
|
/// let context_1 = String::from("Amy lives in Amsterdam");
|
||||||
/// let question_2 = String::from("Where does Eric live");
|
/// let question_2 = String::from("Where does Eric live");
|
||||||
/// let context_2 = String::from("While Amy lives in Amsterdam, Eric is in The Hague.");
|
/// let context_2 = String::from("While Amy lives in Amsterdam, Eric is in The Hague.");
|
||||||
///
|
///
|
||||||
/// let qa_input_1 = QaInput { question: question_1, context: context_1 };
|
/// let qa_input_1 = QaInput {
|
||||||
/// let qa_input_2 = QaInput { question: question_2, context: context_2 };
|
/// question: question_1,
|
||||||
|
/// context: context_1,
|
||||||
|
/// };
|
||||||
|
/// let qa_input_2 = QaInput {
|
||||||
|
/// question: question_2,
|
||||||
|
/// context: context_2,
|
||||||
|
/// };
|
||||||
/// let answers = qa_model.predict(&[qa_input_1, qa_input_2], 1, 32);
|
/// let answers = qa_model.predict(&[qa_input_1, qa_input_2], 1, 32);
|
||||||
///
|
///
|
||||||
///# Ok(())
|
/// # Ok(())
|
||||||
///# }
|
/// # }
|
||||||
/// ```
|
/// ```
|
||||||
///
|
pub fn predict(
|
||||||
pub fn predict(&self, qa_inputs: &[QaInput], top_k: i64, batch_size: usize) -> Vec<Vec<Answer>> {
|
&self,
|
||||||
|
qa_inputs: &[QaInput],
|
||||||
|
top_k: i64,
|
||||||
|
batch_size: usize,
|
||||||
|
) -> Vec<Vec<Answer>> {
|
||||||
let examples: Vec<QaExample> = qa_inputs
|
let examples: Vec<QaExample> = qa_inputs
|
||||||
.iter()
|
.iter()
|
||||||
.map(|qa_input| QaExample::new(&qa_input.question, &qa_input.context))
|
.map(|qa_input| QaExample::new(&qa_input.question, &qa_input.context))
|
||||||
@ -290,7 +315,15 @@ impl QuestionAnsweringModel {
|
|||||||
let features: Vec<QaFeature> = examples
|
let features: Vec<QaFeature> = examples
|
||||||
.iter()
|
.iter()
|
||||||
.enumerate()
|
.enumerate()
|
||||||
.map(|(example_index, qa_example)| self.generate_features(&qa_example, self.max_seq_len, self.doc_stride, self.max_query_length, example_index as i64))
|
.map(|(example_index, qa_example)| {
|
||||||
|
self.generate_features(
|
||||||
|
&qa_example,
|
||||||
|
self.max_seq_len,
|
||||||
|
self.doc_stride,
|
||||||
|
self.max_query_length,
|
||||||
|
example_index as i64,
|
||||||
|
)
|
||||||
|
})
|
||||||
.flatten()
|
.flatten()
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
@ -310,28 +343,36 @@ impl QuestionAnsweringModel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
let input_ids = Tensor::stack(&input_ids, 0).to(self.var_store.device());
|
let input_ids = Tensor::stack(&input_ids, 0).to(self.var_store.device());
|
||||||
let attention_masks = Tensor::stack(&attention_masks, 0).to(self.var_store.device());
|
let attention_masks =
|
||||||
|
Tensor::stack(&attention_masks, 0).to(self.var_store.device());
|
||||||
|
|
||||||
let (start_logits, end_logits, _, _) = self.distilbert_qa.forward_t(Some(input_ids), Some(attention_masks), None, false).unwrap();
|
let (start_logits, end_logits, _, _) = self
|
||||||
|
.distilbert_qa
|
||||||
|
.forward_t(Some(input_ids), Some(attention_masks), None, false)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
let start_logits = start_logits.detach();
|
let start_logits = start_logits.detach();
|
||||||
let end_logits = end_logits.detach();
|
let end_logits = end_logits.detach();
|
||||||
let example_index_to_feature_end_position: Vec<(usize, i64)> = batch_features
|
let example_index_to_feature_end_position: Vec<(usize, i64)> = batch_features
|
||||||
.iter()
|
.iter()
|
||||||
.enumerate()
|
.enumerate()
|
||||||
.map(|(feature_index, feature)| (feature.example_index as usize, feature_index as i64 + 1))
|
.map(|(feature_index, feature)| {
|
||||||
|
(feature.example_index as usize, feature_index as i64 + 1)
|
||||||
|
})
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
let mut feature_id_start = 0;
|
let mut feature_id_start = 0;
|
||||||
|
|
||||||
for (example_id, max_feature_id) in example_index_to_feature_end_position {
|
for (example_id, max_feature_id) in example_index_to_feature_end_position {
|
||||||
let mut answers: Vec<Answer> = vec!();
|
let mut answers: Vec<Answer> = vec![];
|
||||||
let example = &examples[example_id];
|
let example = &examples[example_id];
|
||||||
for feature_idx in feature_id_start..max_feature_id {
|
for feature_idx in feature_id_start..max_feature_id {
|
||||||
let feature = &batch_features[feature_idx as usize];
|
let feature = &batch_features[feature_idx as usize];
|
||||||
let start = start_logits.get(feature_idx);
|
let start = start_logits.get(feature_idx);
|
||||||
let end = end_logits.get(feature_idx);
|
let end = end_logits.get(feature_idx);
|
||||||
let p_mask = (Tensor::of_slice(&feature.p_mask) - 1).abs().to_device(start.device());
|
let p_mask = (Tensor::of_slice(&feature.p_mask) - 1)
|
||||||
|
.abs()
|
||||||
|
.to_device(start.device());
|
||||||
|
|
||||||
let start: Tensor = start.exp() / start.exp().sum(Float) * &p_mask;
|
let start: Tensor = start.exp() / start.exp().sum(Float) * &p_mask;
|
||||||
let end: Tensor = end.exp() / end.exp().sum(Float) * &p_mask;
|
let end: Tensor = end.exp() / end.exp().sum(Float) * &p_mask;
|
||||||
@ -343,33 +384,42 @@ impl QuestionAnsweringModel {
|
|||||||
let end_pos = feature.token_to_orig_map[&ends[idx]] as usize;
|
let end_pos = feature.token_to_orig_map[&ends[idx]] as usize;
|
||||||
let answer = example.doc_tokens[start_pos..end_pos + 1].join(" ");
|
let answer = example.doc_tokens[start_pos..end_pos + 1].join(" ");
|
||||||
|
|
||||||
let start = example.char_to_word_offset
|
let start = example
|
||||||
|
.char_to_word_offset
|
||||||
.iter()
|
.iter()
|
||||||
.position(|&v| v as usize == start_pos)
|
.position(|&v| v as usize == start_pos)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let end = example.char_to_word_offset
|
let end = example
|
||||||
|
.char_to_word_offset
|
||||||
.iter()
|
.iter()
|
||||||
.rposition(|&v| v as usize == end_pos)
|
.rposition(|&v| v as usize == end_pos)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
answers.push(Answer { score: scores[idx], start, end, answer });
|
answers.push(Answer {
|
||||||
|
score: scores[idx],
|
||||||
|
start,
|
||||||
|
end,
|
||||||
|
answer,
|
||||||
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
feature_id_start = max_feature_id;
|
feature_id_start = max_feature_id;
|
||||||
let example_answers = example_top_k_answers_map.entry(example_id).or_insert(vec!());
|
let example_answers = example_top_k_answers_map
|
||||||
|
.entry(example_id)
|
||||||
|
.or_insert(vec![]);
|
||||||
example_answers.extend(answers);
|
example_answers.extend(answers);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
start = end;
|
start = end;
|
||||||
}
|
}
|
||||||
let mut all_answers = vec!();
|
let mut all_answers = vec![];
|
||||||
for example_id in 0..examples.len() {
|
for example_id in 0..examples.len() {
|
||||||
if let Some(answers) = example_top_k_answers_map.get_mut(&example_id) {
|
if let Some(answers) = example_top_k_answers_map.get_mut(&example_id) {
|
||||||
remove_duplicates(answers).sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
|
remove_duplicates(answers).sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
|
||||||
all_answers.push(answers[..min(answers.len(), top_k as usize)].to_vec());
|
all_answers.push(answers[..min(answers.len(), top_k as usize)].to_vec());
|
||||||
} else {
|
} else {
|
||||||
all_answers.push(vec!());
|
all_answers.push(vec![]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
all_answers
|
all_answers
|
||||||
@ -379,7 +429,10 @@ impl QuestionAnsweringModel {
|
|||||||
let outer = start.unsqueeze(-1).matmul(&end.unsqueeze(0));
|
let outer = start.unsqueeze(-1).matmul(&end.unsqueeze(0));
|
||||||
let start_dim = start.size()[0];
|
let start_dim = start.size()[0];
|
||||||
let end_dim = end.size()[0];
|
let end_dim = end.size()[0];
|
||||||
let candidates = outer.triu(0).tril(self.max_answer_len as i64 - 1).flatten(0, -1);
|
let candidates = outer
|
||||||
|
.triu(0)
|
||||||
|
.tril(self.max_answer_len as i64 - 1)
|
||||||
|
.flatten(0, -1);
|
||||||
let idx_sort = if top_k == 1 {
|
let idx_sort = if top_k == 1 {
|
||||||
candidates.argmax(0, true)
|
candidates.argmax(0, true)
|
||||||
} else if candidates.size()[0] < top_k {
|
} else if candidates.size()[0] < top_k {
|
||||||
@ -387,9 +440,9 @@ impl QuestionAnsweringModel {
|
|||||||
} else {
|
} else {
|
||||||
candidates.argsort(0, true).slice(0, 0, top_k, 1)
|
candidates.argsort(0, true).slice(0, 0, top_k, 1)
|
||||||
};
|
};
|
||||||
let mut start: Vec<i64> = vec!();
|
let mut start: Vec<i64> = vec![];
|
||||||
let mut end: Vec<i64> = vec!();
|
let mut end: Vec<i64> = vec![];
|
||||||
let mut scores: Vec<f64> = vec!();
|
let mut scores: Vec<f64> = vec![];
|
||||||
for flat_index_position in 0..idx_sort.size()[0] {
|
for flat_index_position in 0..idx_sort.size()[0] {
|
||||||
let flat_index = idx_sort.int64_value(&[flat_index_position]);
|
let flat_index = idx_sort.int64_value(&[flat_index_position]);
|
||||||
scores.push(candidates.double_value(&[flat_index]));
|
scores.push(candidates.double_value(&[flat_index]));
|
||||||
@ -399,10 +452,16 @@ impl QuestionAnsweringModel {
|
|||||||
(start, end, scores)
|
(start, end, scores)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn generate_features(
|
||||||
fn generate_features(&self, qa_example: &QaExample, max_seq_length: usize, doc_stride: usize, max_query_length: usize, example_index: i64) -> Vec<QaFeature> {
|
&self,
|
||||||
let mut tok_to_orig_index: Vec<i64> = vec!();
|
qa_example: &QaExample,
|
||||||
let mut all_doc_tokens: Vec<String> = vec!();
|
max_seq_length: usize,
|
||||||
|
doc_stride: usize,
|
||||||
|
max_query_length: usize,
|
||||||
|
example_index: i64,
|
||||||
|
) -> Vec<QaFeature> {
|
||||||
|
let mut tok_to_orig_index: Vec<i64> = vec![];
|
||||||
|
let mut all_doc_tokens: Vec<String> = vec![];
|
||||||
|
|
||||||
for (idx, token) in qa_example.doc_tokens.iter().enumerate() {
|
for (idx, token) in qa_example.doc_tokens.iter().enumerate() {
|
||||||
let sub_tokens = self.tokenizer.tokenize(token);
|
let sub_tokens = self.tokenizer.tokenize(token);
|
||||||
@ -414,28 +473,61 @@ impl QuestionAnsweringModel {
|
|||||||
|
|
||||||
let truncated_query = self.prepare_query(&qa_example.question, max_query_length);
|
let truncated_query = self.prepare_query(&qa_example.question, max_query_length);
|
||||||
|
|
||||||
let sequence_added_tokens = self.tokenizer.build_input_with_special_tokens(vec!(), None, vec!(), None, vec!(), None, vec!(), None).0.len();
|
let sequence_added_tokens = self
|
||||||
let sequence_pair_added_tokens = self.tokenizer.build_input_with_special_tokens(vec!(), Some(vec!()), vec!(), Some(vec!()), vec!(), Some(vec!()), vec!(), Some(vec!())).0.len();
|
.tokenizer
|
||||||
|
.build_input_with_special_tokens(vec![], None, vec![], None, vec![], None, vec![], None)
|
||||||
|
.0
|
||||||
|
.len();
|
||||||
|
let sequence_pair_added_tokens = self
|
||||||
|
.tokenizer
|
||||||
|
.build_input_with_special_tokens(
|
||||||
|
vec![],
|
||||||
|
Some(vec![]),
|
||||||
|
vec![],
|
||||||
|
Some(vec![]),
|
||||||
|
vec![],
|
||||||
|
Some(vec![]),
|
||||||
|
vec![],
|
||||||
|
Some(vec![]),
|
||||||
|
)
|
||||||
|
.0
|
||||||
|
.len();
|
||||||
|
|
||||||
let mut spans: Vec<QaFeature> = vec!();
|
let mut spans: Vec<QaFeature> = vec![];
|
||||||
|
|
||||||
let mut remaining_tokens = self.tokenizer.convert_tokens_to_ids(&all_doc_tokens);
|
let mut remaining_tokens = self.tokenizer.convert_tokens_to_ids(&all_doc_tokens);
|
||||||
while (spans.len() * doc_stride as usize) < all_doc_tokens.len() {
|
while (spans.len() * doc_stride as usize) < all_doc_tokens.len() {
|
||||||
let (encoded_span, attention_mask) = self.encode_qa_pair(&truncated_query, &remaining_tokens, max_seq_length, doc_stride, sequence_pair_added_tokens);
|
let (encoded_span, attention_mask) = self.encode_qa_pair(
|
||||||
|
&truncated_query,
|
||||||
|
&remaining_tokens,
|
||||||
|
max_seq_length,
|
||||||
|
doc_stride,
|
||||||
|
sequence_pair_added_tokens,
|
||||||
|
);
|
||||||
|
|
||||||
let paragraph_len = min(
|
let paragraph_len = min(
|
||||||
all_doc_tokens.len() - spans.len() * doc_stride,
|
all_doc_tokens.len() - spans.len() * doc_stride,
|
||||||
max_seq_length - truncated_query.len() - sequence_pair_added_tokens);
|
max_seq_length - truncated_query.len() - sequence_pair_added_tokens,
|
||||||
|
);
|
||||||
|
|
||||||
let mut token_to_orig_map = HashMap::new();
|
let mut token_to_orig_map = HashMap::new();
|
||||||
for i in 0..paragraph_len {
|
for i in 0..paragraph_len {
|
||||||
let index = truncated_query.len() + sequence_added_tokens + i;
|
let index = truncated_query.len() + sequence_added_tokens + i;
|
||||||
token_to_orig_map.insert(index as i64, tok_to_orig_index[spans.len() * doc_stride + i] as i64);
|
token_to_orig_map.insert(
|
||||||
|
index as i64,
|
||||||
|
tok_to_orig_index[spans.len() * doc_stride + i] as i64,
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
let p_mask = self.get_mask(&encoded_span);
|
let p_mask = self.get_mask(&encoded_span);
|
||||||
|
|
||||||
let qa_feature = QaFeature { input_ids: encoded_span.token_ids, attention_mask, token_to_orig_map, p_mask, example_index };
|
let qa_feature = QaFeature {
|
||||||
|
input_ids: encoded_span.token_ids,
|
||||||
|
attention_mask,
|
||||||
|
token_to_orig_map,
|
||||||
|
p_mask,
|
||||||
|
example_index,
|
||||||
|
};
|
||||||
|
|
||||||
spans.push(qa_feature);
|
spans.push(qa_feature);
|
||||||
if encoded_span.num_truncated_tokens == 0 {
|
if encoded_span.num_truncated_tokens == 0 {
|
||||||
@ -447,59 +539,81 @@ impl QuestionAnsweringModel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn prepare_query(&self, query: &str, max_query_length: usize) -> Vec<i64> {
|
fn prepare_query(&self, query: &str, max_query_length: usize) -> Vec<i64> {
|
||||||
let truncated_query = self.tokenizer.convert_tokens_to_ids(&self.tokenizer.tokenize(&query));
|
let truncated_query = self
|
||||||
let num_query_tokens_to_remove = if truncated_query.len() > max_query_length as usize { truncated_query.len() - max_query_length } else { 0 };
|
.tokenizer
|
||||||
let (truncated_query, _, _, _, _, _, _, _, _, _) = truncate_sequences(truncated_query,
|
.convert_tokens_to_ids(&self.tokenizer.tokenize(&query));
|
||||||
None,
|
let num_query_tokens_to_remove = if truncated_query.len() > max_query_length as usize {
|
||||||
vec!(),
|
truncated_query.len() - max_query_length
|
||||||
None,
|
} else {
|
||||||
vec!(),
|
0
|
||||||
None,
|
};
|
||||||
vec!(),
|
let (truncated_query, _, _, _, _, _, _, _, _, _) = truncate_sequences(
|
||||||
None,
|
truncated_query,
|
||||||
num_query_tokens_to_remove,
|
None,
|
||||||
&TruncationStrategy::OnlyFirst,
|
vec![],
|
||||||
0).unwrap();
|
None,
|
||||||
|
vec![],
|
||||||
|
None,
|
||||||
|
vec![],
|
||||||
|
None,
|
||||||
|
num_query_tokens_to_remove,
|
||||||
|
&TruncationStrategy::OnlyFirst,
|
||||||
|
0,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
truncated_query
|
truncated_query
|
||||||
}
|
}
|
||||||
|
|
||||||
fn encode_qa_pair(&self,
|
fn encode_qa_pair(
|
||||||
truncated_query: &Vec<i64>,
|
&self,
|
||||||
spans_token_ids: &Vec<i64>,
|
truncated_query: &Vec<i64>,
|
||||||
max_seq_length: usize,
|
spans_token_ids: &Vec<i64>,
|
||||||
doc_stride: usize,
|
max_seq_length: usize,
|
||||||
sequence_pair_added_tokens: usize) -> (TokenizedInput, Vec<i64>) {
|
doc_stride: usize,
|
||||||
|
sequence_pair_added_tokens: usize,
|
||||||
|
) -> (TokenizedInput, Vec<i64>) {
|
||||||
let len_1 = truncated_query.len();
|
let len_1 = truncated_query.len();
|
||||||
let len_2 = spans_token_ids.len();
|
let len_2 = spans_token_ids.len();
|
||||||
let total_len = len_1 + len_2 + sequence_pair_added_tokens;
|
let total_len = len_1 + len_2 + sequence_pair_added_tokens;
|
||||||
let num_truncated_tokens = if total_len > max_seq_length { total_len - max_seq_length } else { 0 };
|
let num_truncated_tokens = if total_len > max_seq_length {
|
||||||
|
total_len - max_seq_length
|
||||||
|
} else {
|
||||||
|
0
|
||||||
|
};
|
||||||
|
|
||||||
let (truncated_query, truncated_context, _, _, _, _, _, _, overflowing_tokens, _)
|
let (truncated_query, truncated_context, _, _, _, _, _, _, overflowing_tokens, _) =
|
||||||
= truncate_sequences(truncated_query.clone(),
|
truncate_sequences(
|
||||||
Some(spans_token_ids.clone()),
|
truncated_query.clone(),
|
||||||
vec!(),
|
Some(spans_token_ids.clone()),
|
||||||
None,
|
vec![],
|
||||||
vec!(),
|
None,
|
||||||
None,
|
vec![],
|
||||||
vec!(),
|
None,
|
||||||
None,
|
vec![],
|
||||||
num_truncated_tokens,
|
None,
|
||||||
&TruncationStrategy::OnlySecond,
|
num_truncated_tokens,
|
||||||
max_seq_length - doc_stride - len_1 - sequence_pair_added_tokens).unwrap();
|
&TruncationStrategy::OnlySecond,
|
||||||
|
max_seq_length - doc_stride - len_1 - sequence_pair_added_tokens,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
let (mut token_ids,
|
let (
|
||||||
|
mut token_ids,
|
||||||
mut segment_ids,
|
mut segment_ids,
|
||||||
special_tokens_mask,
|
special_tokens_mask,
|
||||||
mut token_offsets,
|
mut token_offsets,
|
||||||
mut reference_offsets,
|
mut reference_offsets,
|
||||||
mut mask) = self.tokenizer.build_input_with_special_tokens(truncated_query,
|
mut mask,
|
||||||
truncated_context,
|
) = self.tokenizer.build_input_with_special_tokens(
|
||||||
vec!(),
|
truncated_query,
|
||||||
None,
|
truncated_context,
|
||||||
vec!(),
|
vec![],
|
||||||
None,
|
None,
|
||||||
vec!(),
|
vec![],
|
||||||
None);
|
None,
|
||||||
|
vec![],
|
||||||
|
None,
|
||||||
|
);
|
||||||
let mut attention_mask = vec![1; token_ids.len()];
|
let mut attention_mask = vec![1; token_ids.len()];
|
||||||
if token_ids.len() < max_seq_length {
|
if token_ids.len() < max_seq_length {
|
||||||
token_ids.append(&mut vec![self.pad_idx; max_seq_length - token_ids.len()]);
|
token_ids.append(&mut vec![self.pad_idx; max_seq_length - token_ids.len()]);
|
||||||
@ -509,18 +623,32 @@ impl QuestionAnsweringModel {
|
|||||||
reference_offsets.append(&mut vec![vec!(); max_seq_length - token_offsets.len()]);
|
reference_offsets.append(&mut vec![vec!(); max_seq_length - token_offsets.len()]);
|
||||||
mask.append(&mut vec![Mask::Special; max_seq_length - mask.len()]);
|
mask.append(&mut vec![Mask::Special; max_seq_length - mask.len()]);
|
||||||
}
|
}
|
||||||
(TokenizedInput { token_ids, segment_ids, special_tokens_mask, overflowing_tokens, num_truncated_tokens, token_offsets, reference_offsets, mask }, attention_mask)
|
(
|
||||||
|
TokenizedInput {
|
||||||
|
token_ids,
|
||||||
|
segment_ids,
|
||||||
|
special_tokens_mask,
|
||||||
|
overflowing_tokens,
|
||||||
|
num_truncated_tokens,
|
||||||
|
token_offsets,
|
||||||
|
reference_offsets,
|
||||||
|
mask,
|
||||||
|
},
|
||||||
|
attention_mask,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_mask(&self, encoded_span: &TokenizedInput) -> Vec<i8> {
|
fn get_mask(&self, encoded_span: &TokenizedInput) -> Vec<i8> {
|
||||||
let sep_indices: Vec<usize> = encoded_span.token_ids
|
let sep_indices: Vec<usize> = encoded_span
|
||||||
|
.token_ids
|
||||||
.iter()
|
.iter()
|
||||||
.enumerate()
|
.enumerate()
|
||||||
.filter(|(_, &value)| value == self.sep_idx)
|
.filter(|(_, &value)| value == self.sep_idx)
|
||||||
.map(|(position, _)| position)
|
.map(|(position, _)| position)
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
let mut p_mask: Vec<i8> = encoded_span.segment_ids
|
let mut p_mask: Vec<i8> = encoded_span
|
||||||
|
.segment_ids
|
||||||
.iter()
|
.iter()
|
||||||
.map(|v| min(v, &1i8))
|
.map(|v| min(v, &1i8))
|
||||||
.map(|&v| 1i8 - v)
|
.map(|&v| 1i8 - v)
|
||||||
@ -534,10 +662,13 @@ impl QuestionAnsweringModel {
|
|||||||
|
|
||||||
pub fn squad_processor(file_path: PathBuf) -> Vec<QaInput> {
|
pub fn squad_processor(file_path: PathBuf) -> Vec<QaInput> {
|
||||||
let file = fs::File::open(file_path).expect("unable to open file");
|
let file = fs::File::open(file_path).expect("unable to open file");
|
||||||
let json: serde_json::Value = serde_json::from_reader(file).expect("JSON not properly formatted");
|
let json: serde_json::Value =
|
||||||
|
serde_json::from_reader(file).expect("JSON not properly formatted");
|
||||||
let data = json
|
let data = json
|
||||||
.get("data").expect("SQuAD file does not contain data field")
|
.get("data")
|
||||||
.as_array().expect("Data array not properly formatted");
|
.expect("SQuAD file does not contain data field")
|
||||||
|
.as_array()
|
||||||
|
.expect("Data array not properly formatted");
|
||||||
|
|
||||||
let mut qa_inputs: Vec<QaInput> = Vec::with_capacity(data.len());
|
let mut qa_inputs: Vec<QaInput> = Vec::with_capacity(data.len());
|
||||||
for qa_input in data.iter() {
|
for qa_input in data.iter() {
|
||||||
@ -548,8 +679,17 @@ pub fn squad_processor(file_path: PathBuf) -> Vec<QaInput> {
|
|||||||
let context = paragraph.get("context").unwrap().as_str().unwrap();
|
let context = paragraph.get("context").unwrap().as_str().unwrap();
|
||||||
let qas = paragraph.get("qas").unwrap().as_array().unwrap();
|
let qas = paragraph.get("qas").unwrap().as_array().unwrap();
|
||||||
for qa in qas.iter() {
|
for qa in qas.iter() {
|
||||||
let question = qa.as_object().unwrap().get("question").unwrap().as_str().unwrap();
|
let question = qa
|
||||||
qa_inputs.push(QaInput { question: question.to_owned(), context: context.to_owned() });
|
.as_object()
|
||||||
|
.unwrap()
|
||||||
|
.get("question")
|
||||||
|
.unwrap()
|
||||||
|
.as_str()
|
||||||
|
.unwrap();
|
||||||
|
qa_inputs.push(QaInput {
|
||||||
|
question: question.to_owned(),
|
||||||
|
context: context.to_owned(),
|
||||||
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -19,7 +19,7 @@
|
|||||||
//! ```no_run
|
//! ```no_run
|
||||||
//! use rust_bert::pipelines::sentiment::SentimentModel;
|
//! use rust_bert::pipelines::sentiment::SentimentModel;
|
||||||
//!
|
//!
|
||||||
//!# fn main() -> failure::Fallible<()> {
|
//! # fn main() -> failure::Fallible<()> {
|
||||||
//! let sentiment_classifier = SentimentModel::new(Default::default())?;
|
//! let sentiment_classifier = SentimentModel::new(Default::default())?;
|
||||||
//! let input = [
|
//! let input = [
|
||||||
//! "Probably my all-time favorite movie, a story of selflessness, sacrifice and dedication to a noble cause, but it's not preachy or boring.",
|
//! "Probably my all-time favorite movie, a story of selflessness, sacrifice and dedication to a noble cause, but it's not preachy or boring.",
|
||||||
@ -27,29 +27,40 @@
|
|||||||
//! "If you like original gut wrenching laughter you will like this movie. If you are young or old then you will love this movie, hell even my mom liked it.",
|
//! "If you like original gut wrenching laughter you will like this movie. If you are young or old then you will love this movie, hell even my mom liked it.",
|
||||||
//! ];
|
//! ];
|
||||||
//! let output = sentiment_classifier.predict(&input);
|
//! let output = sentiment_classifier.predict(&input);
|
||||||
//!# Ok(())
|
//! # Ok(())
|
||||||
//!# }
|
//! # }
|
||||||
//! ```
|
//! ```
|
||||||
//! (Example courtesy of [IMDb](http://www.imdb.com))
|
//! (Example courtesy of [IMDb](http://www.imdb.com))
|
||||||
//!
|
//!
|
||||||
//! Output: \
|
//! Output: \
|
||||||
//! ```no_run
|
//! ```no_run
|
||||||
//!# use rust_bert::pipelines::sentiment::Sentiment;
|
//! # use rust_bert::pipelines::sentiment::Sentiment;
|
||||||
//!# use rust_bert::pipelines::sentiment::SentimentPolarity::{Positive, Negative};
|
//! # use rust_bert::pipelines::sentiment::SentimentPolarity::{Positive, Negative};
|
||||||
//!# let output =
|
//! # let output =
|
||||||
//! [
|
//! [
|
||||||
//! Sentiment { polarity: Positive, score: 0.998 },
|
//! Sentiment {
|
||||||
//! Sentiment { polarity: Negative, score: 0.992 },
|
//! polarity: Positive,
|
||||||
//! Sentiment { polarity: Positive, score: 0.999 }
|
//! score: 0.998,
|
||||||
|
//! },
|
||||||
|
//! Sentiment {
|
||||||
|
//! polarity: Negative,
|
||||||
|
//! score: 0.992,
|
||||||
|
//! },
|
||||||
|
//! Sentiment {
|
||||||
|
//! polarity: Positive,
|
||||||
|
//! score: 0.999,
|
||||||
|
//! },
|
||||||
//! ]
|
//! ]
|
||||||
//!# ;
|
//! # ;
|
||||||
//! ```
|
//! ```
|
||||||
|
|
||||||
use std::path::PathBuf;
|
use crate::pipelines::sequence_classification::{
|
||||||
use std::fs;
|
SequenceClassificationConfig, SequenceClassificationModel,
|
||||||
|
};
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
use std::error::Error;
|
use std::error::Error;
|
||||||
use crate::pipelines::sequence_classification::{SequenceClassificationConfig, SequenceClassificationModel};
|
use std::fs;
|
||||||
|
use std::path::PathBuf;
|
||||||
|
|
||||||
#[derive(Debug, PartialEq)]
|
#[derive(Debug, PartialEq)]
|
||||||
/// Enum with the possible sentiment polarities. Note that the pre-trained SST2 model does not include neutral sentiment.
|
/// Enum with the possible sentiment polarities. Note that the pre-trained SST2 model does not include neutral sentiment.
|
||||||
@ -71,7 +82,7 @@ type SentimentConfig = SequenceClassificationConfig;
|
|||||||
|
|
||||||
/// # SentimentClassifier to perform sentiment analysis
|
/// # SentimentClassifier to perform sentiment analysis
|
||||||
pub struct SentimentModel {
|
pub struct SentimentModel {
|
||||||
sequence_classification_model: SequenceClassificationModel
|
sequence_classification_model: SequenceClassificationModel,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl SentimentModel {
|
impl SentimentModel {
|
||||||
@ -84,17 +95,18 @@ impl SentimentModel {
|
|||||||
/// # Example
|
/// # Example
|
||||||
///
|
///
|
||||||
/// ```no_run
|
/// ```no_run
|
||||||
///# fn main() -> failure::Fallible<()> {
|
/// # fn main() -> failure::Fallible<()> {
|
||||||
/// use rust_bert::pipelines::sentiment::SentimentModel;
|
/// use rust_bert::pipelines::sentiment::SentimentModel;
|
||||||
///
|
///
|
||||||
/// let sentiment_model = SentimentModel::new(Default::default())?;
|
/// let sentiment_model = SentimentModel::new(Default::default())?;
|
||||||
///# Ok(())
|
/// # Ok(())
|
||||||
///# }
|
/// # }
|
||||||
/// ```
|
/// ```
|
||||||
///
|
|
||||||
pub fn new(sentiment_config: SentimentConfig) -> failure::Fallible<SentimentModel> {
|
pub fn new(sentiment_config: SentimentConfig) -> failure::Fallible<SentimentModel> {
|
||||||
let sequence_classification_model = SequenceClassificationModel::new(sentiment_config)?;
|
let sequence_classification_model = SequenceClassificationModel::new(sentiment_config)?;
|
||||||
Ok(SentimentModel { sequence_classification_model })
|
Ok(SentimentModel {
|
||||||
|
sequence_classification_model,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Extract sentiment form an array of text inputs
|
/// Extract sentiment form an array of text inputs
|
||||||
@ -109,7 +121,7 @@ impl SentimentModel {
|
|||||||
/// # Example
|
/// # Example
|
||||||
///
|
///
|
||||||
/// ```no_run
|
/// ```no_run
|
||||||
///# fn main() -> failure::Fallible<()> {
|
/// # fn main() -> failure::Fallible<()> {
|
||||||
/// use rust_bert::pipelines::sentiment::SentimentModel;
|
/// use rust_bert::pipelines::sentiment::SentimentModel;
|
||||||
///
|
///
|
||||||
/// let sentiment_classifier = SentimentModel::new(Default::default())?;
|
/// let sentiment_classifier = SentimentModel::new(Default::default())?;
|
||||||
@ -121,17 +133,23 @@ impl SentimentModel {
|
|||||||
/// ];
|
/// ];
|
||||||
///
|
///
|
||||||
/// let output = sentiment_classifier.predict(&input);
|
/// let output = sentiment_classifier.predict(&input);
|
||||||
///# Ok(())
|
/// # Ok(())
|
||||||
///# }
|
/// # }
|
||||||
/// ```
|
/// ```
|
||||||
///
|
|
||||||
pub fn predict(&self, input: &[&str]) -> Vec<Sentiment> {
|
pub fn predict(&self, input: &[&str]) -> Vec<Sentiment> {
|
||||||
let labels = self.sequence_classification_model.predict(input);
|
let labels = self.sequence_classification_model.predict(input);
|
||||||
let mut sentiments = Vec::with_capacity(labels.len());
|
let mut sentiments = Vec::with_capacity(labels.len());
|
||||||
for label in labels {
|
for label in labels {
|
||||||
let polarity = if label.id == 1 { SentimentPolarity::Positive } else { SentimentPolarity::Negative };
|
let polarity = if label.id == 1 {
|
||||||
sentiments.push(Sentiment { polarity, score: label.score })
|
SentimentPolarity::Positive
|
||||||
};
|
} else {
|
||||||
|
SentimentPolarity::Negative
|
||||||
|
};
|
||||||
|
sentiments.push(Sentiment {
|
||||||
|
polarity,
|
||||||
|
score: label.score,
|
||||||
|
})
|
||||||
|
}
|
||||||
sentiments
|
sentiments
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -154,4 +172,4 @@ pub fn ss2_processor(file_path: PathBuf) -> Result<Vec<String>, Box<dyn Error>>
|
|||||||
records.push(record.sentence);
|
records.push(record.sentence);
|
||||||
}
|
}
|
||||||
Ok(records)
|
Ok(records)
|
||||||
}
|
}
|
||||||
|
@ -19,7 +19,7 @@
|
|||||||
//! use rust_bert::distilbert::{DistilBertModelResources, DistilBertVocabResources, DistilBertConfigResources};
|
//! use rust_bert::distilbert::{DistilBertModelResources, DistilBertVocabResources, DistilBertConfigResources};
|
||||||
//! use rust_bert::pipelines::sequence_classification::SequenceClassificationModel;
|
//! use rust_bert::pipelines::sequence_classification::SequenceClassificationModel;
|
||||||
//! use rust_bert::pipelines::common::ModelType;
|
//! use rust_bert::pipelines::common::ModelType;
|
||||||
//!# fn main() -> failure::Fallible<()> {
|
//! # fn main() -> failure::Fallible<()> {
|
||||||
//!
|
//!
|
||||||
//! //Load a configuration
|
//! //Load a configuration
|
||||||
//! let config = SequenceClassificationConfig::new(ModelType::DistilBert,
|
//! let config = SequenceClassificationConfig::new(ModelType::DistilBert,
|
||||||
@ -39,34 +39,38 @@
|
|||||||
//! "If you like original gut wrenching laughter you will like this movie. If you are young or old then you will love this movie, hell even my mom liked it.",
|
//! "If you like original gut wrenching laughter you will like this movie. If you are young or old then you will love this movie, hell even my mom liked it.",
|
||||||
//! ];
|
//! ];
|
||||||
//! let output = sequence_classification_model.predict(&input);
|
//! let output = sequence_classification_model.predict(&input);
|
||||||
//!# Ok(())
|
//! # Ok(())
|
||||||
//!# }
|
//! # }
|
||||||
//! ```
|
//! ```
|
||||||
//! (Example courtesy of [IMDb](http://www.imdb.com))
|
//! (Example courtesy of [IMDb](http://www.imdb.com))
|
||||||
//!
|
//!
|
||||||
//! Output: \
|
//! Output: \
|
||||||
//! ```no_run
|
//! ```no_run
|
||||||
//!# use rust_bert::pipelines::sequence_classification::Label;
|
//! # use rust_bert::pipelines::sequence_classification::Label;
|
||||||
//! let output =
|
//! let output =
|
||||||
//! [
|
//! [
|
||||||
//! Label { text: String::from("POSITIVE"), score: 0.9986, id: 1, sentence: 0},
|
//! Label { text: String::from("POSITIVE"), score: 0.9986, id: 1, sentence: 0},
|
||||||
//! Label { text: String::from("NEGATIVE"), score: 0.9985, id: 0, sentence: 1},
|
//! Label { text: String::from("NEGATIVE"), score: 0.9985, id: 0, sentence: 1},
|
||||||
//! Label { text: String::from("POSITIVE"), score: 0.9988, id: 1, sentence: 12},
|
//! Label { text: String::from("POSITIVE"), score: 0.9988, id: 1, sentence: 12},
|
||||||
//! ]
|
//! ]
|
||||||
//!# ;
|
//! # ;
|
||||||
//! ```
|
//! ```
|
||||||
//!
|
//!
|
||||||
use tch::nn::VarStore;
|
|
||||||
use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::{TokenizedInput, TruncationStrategy};
|
|
||||||
use std::collections::HashMap;
|
|
||||||
use tch::{Tensor, no_grad, Device, Kind};
|
|
||||||
use crate::bert::BertForSequenceClassification;
|
use crate::bert::BertForSequenceClassification;
|
||||||
|
use crate::common::resources::{download_resource, RemoteResource, Resource};
|
||||||
|
use crate::distilbert::{
|
||||||
|
DistilBertConfigResources, DistilBertModelClassifier, DistilBertModelResources,
|
||||||
|
DistilBertVocabResources,
|
||||||
|
};
|
||||||
|
use crate::pipelines::common::{ConfigOption, ModelType, TokenizerOption};
|
||||||
use crate::roberta::RobertaForSequenceClassification;
|
use crate::roberta::RobertaForSequenceClassification;
|
||||||
use crate::distilbert::{DistilBertModelResources, DistilBertConfigResources, DistilBertVocabResources, DistilBertModelClassifier};
|
use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::{
|
||||||
use crate::common::resources::{Resource, RemoteResource, download_resource};
|
TokenizedInput, TruncationStrategy,
|
||||||
use serde::{Serialize, Deserialize};
|
};
|
||||||
use crate::pipelines::common::{ModelType, ConfigOption, TokenizerOption};
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use tch::nn::VarStore;
|
||||||
|
use tch::{no_grad, Device, Kind, Tensor};
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
/// # Label generated by a `SequenceClassificationModel`
|
/// # Label generated by a `SequenceClassificationModel`
|
||||||
@ -112,8 +116,14 @@ impl SequenceClassificationConfig {
|
|||||||
/// * vocab - The `Resource' pointing to the tokenizer's vocabulary to load (e.g. vocab.txt/vocab.json)
|
/// * vocab - The `Resource' pointing to the tokenizer's vocabulary to load (e.g. vocab.txt/vocab.json)
|
||||||
/// * vocab - An optional `Resource` tuple (`Option<Resource>`) pointing to the tokenizer's merge file to load (e.g. merges.txt), needed only for Roberta.
|
/// * vocab - An optional `Resource` tuple (`Option<Resource>`) pointing to the tokenizer's merge file to load (e.g. merges.txt), needed only for Roberta.
|
||||||
/// * lower_case - A `bool' indicating whether the tokeniser should lower case all input (in case of a lower-cased model)
|
/// * lower_case - A `bool' indicating whether the tokeniser should lower case all input (in case of a lower-cased model)
|
||||||
///
|
pub fn new(
|
||||||
pub fn new(model_type: ModelType, model_resource: Resource, config_resource: Resource, vocab_resource: Resource, merges_resource: Option<Resource>, lower_case: bool) -> SequenceClassificationConfig {
|
model_type: ModelType,
|
||||||
|
model_resource: Resource,
|
||||||
|
config_resource: Resource,
|
||||||
|
vocab_resource: Resource,
|
||||||
|
merges_resource: Option<Resource>,
|
||||||
|
lower_case: bool,
|
||||||
|
) -> SequenceClassificationConfig {
|
||||||
SequenceClassificationConfig {
|
SequenceClassificationConfig {
|
||||||
model_type,
|
model_type,
|
||||||
model_resource,
|
model_resource,
|
||||||
@ -131,9 +141,15 @@ impl Default for SequenceClassificationConfig {
|
|||||||
fn default() -> SequenceClassificationConfig {
|
fn default() -> SequenceClassificationConfig {
|
||||||
SequenceClassificationConfig {
|
SequenceClassificationConfig {
|
||||||
model_type: ModelType::DistilBert,
|
model_type: ModelType::DistilBert,
|
||||||
model_resource: Resource::Remote(RemoteResource::from_pretrained(DistilBertModelResources::DISTIL_BERT_SST2)),
|
model_resource: Resource::Remote(RemoteResource::from_pretrained(
|
||||||
config_resource: Resource::Remote(RemoteResource::from_pretrained(DistilBertConfigResources::DISTIL_BERT_SST2)),
|
DistilBertModelResources::DISTIL_BERT_SST2,
|
||||||
vocab_resource: Resource::Remote(RemoteResource::from_pretrained(DistilBertVocabResources::DISTIL_BERT_SST2)),
|
)),
|
||||||
|
config_resource: Resource::Remote(RemoteResource::from_pretrained(
|
||||||
|
DistilBertConfigResources::DISTIL_BERT_SST2,
|
||||||
|
)),
|
||||||
|
vocab_resource: Resource::Remote(RemoteResource::from_pretrained(
|
||||||
|
DistilBertVocabResources::DISTIL_BERT_SST2,
|
||||||
|
)),
|
||||||
merges_resource: None,
|
merges_resource: None,
|
||||||
lower_case: true,
|
lower_case: true,
|
||||||
device: Device::cuda_if_available(),
|
device: Device::cuda_if_available(),
|
||||||
@ -160,31 +176,38 @@ impl SequenceClassificationOption {
|
|||||||
/// * `p` - `tch::nn::Path` path to the model file to load (e.g. model.ot)
|
/// * `p` - `tch::nn::Path` path to the model file to load (e.g. model.ot)
|
||||||
/// * `config` - A configuration (the model type of the configuration must be compatible with the value for
|
/// * `config` - A configuration (the model type of the configuration must be compatible with the value for
|
||||||
/// `model_type`)
|
/// `model_type`)
|
||||||
///
|
|
||||||
pub fn new(model_type: ModelType, p: &tch::nn::Path, config: &ConfigOption) -> Self {
|
pub fn new(model_type: ModelType, p: &tch::nn::Path, config: &ConfigOption) -> Self {
|
||||||
match model_type {
|
match model_type {
|
||||||
ModelType::Bert => {
|
ModelType::Bert => {
|
||||||
if let ConfigOption::Bert(config) = config {
|
if let ConfigOption::Bert(config) = config {
|
||||||
SequenceClassificationOption::Bert(BertForSequenceClassification::new(p, config))
|
SequenceClassificationOption::Bert(BertForSequenceClassification::new(
|
||||||
|
p, config,
|
||||||
|
))
|
||||||
} else {
|
} else {
|
||||||
panic!("You can only supply a BertConfig for Bert!");
|
panic!("You can only supply a BertConfig for Bert!");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ModelType::DistilBert => {
|
ModelType::DistilBert => {
|
||||||
if let ConfigOption::DistilBert(config) = config {
|
if let ConfigOption::DistilBert(config) = config {
|
||||||
SequenceClassificationOption::DistilBert(DistilBertModelClassifier::new(p, config))
|
SequenceClassificationOption::DistilBert(DistilBertModelClassifier::new(
|
||||||
|
p, config,
|
||||||
|
))
|
||||||
} else {
|
} else {
|
||||||
panic!("You can only supply a DistilBertConfig for DistilBert!");
|
panic!("You can only supply a DistilBertConfig for DistilBert!");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ModelType::Roberta => {
|
ModelType::Roberta => {
|
||||||
if let ConfigOption::Bert(config) = config {
|
if let ConfigOption::Bert(config) = config {
|
||||||
SequenceClassificationOption::Roberta(RobertaForSequenceClassification::new(p, config))
|
SequenceClassificationOption::Roberta(RobertaForSequenceClassification::new(
|
||||||
|
p, config,
|
||||||
|
))
|
||||||
} else {
|
} else {
|
||||||
panic!("You can only supply a BertConfig for Roberta!");
|
panic!("You can only supply a BertConfig for Roberta!");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ModelType::Electra => { panic!("SequenceClassification not implemented for Electra!"); }
|
ModelType::Electra => {
|
||||||
|
panic!("SequenceClassification not implemented for Electra!");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -193,27 +216,44 @@ impl SequenceClassificationOption {
|
|||||||
match *self {
|
match *self {
|
||||||
Self::Bert(_) => ModelType::Bert,
|
Self::Bert(_) => ModelType::Bert,
|
||||||
Self::Roberta(_) => ModelType::Roberta,
|
Self::Roberta(_) => ModelType::Roberta,
|
||||||
Self::DistilBert(_) => ModelType::DistilBert
|
Self::DistilBert(_) => ModelType::DistilBert,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Interface method to forward_t() of the particular models.
|
/// Interface method to forward_t() of the particular models.
|
||||||
pub fn forward_t(&self,
|
pub fn forward_t(
|
||||||
input_ids: Option<Tensor>,
|
&self,
|
||||||
mask: Option<Tensor>,
|
input_ids: Option<Tensor>,
|
||||||
token_type_ids: Option<Tensor>,
|
mask: Option<Tensor>,
|
||||||
position_ids: Option<Tensor>,
|
token_type_ids: Option<Tensor>,
|
||||||
input_embeds: Option<Tensor>,
|
position_ids: Option<Tensor>,
|
||||||
train: bool) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
|
input_embeds: Option<Tensor>,
|
||||||
|
train: bool,
|
||||||
|
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
|
||||||
match *self {
|
match *self {
|
||||||
Self::Bert(ref model) => model.forward_t(input_ids, mask, token_type_ids, position_ids, input_embeds, train),
|
Self::Bert(ref model) => model.forward_t(
|
||||||
Self::DistilBert(ref model) => model.forward_t(input_ids, mask, input_embeds, train).expect("Error in distilbert forward_t"),
|
input_ids,
|
||||||
Self::Roberta(ref model) => model.forward_t(input_ids, mask, token_type_ids, position_ids, input_embeds, train),
|
mask,
|
||||||
|
token_type_ids,
|
||||||
|
position_ids,
|
||||||
|
input_embeds,
|
||||||
|
train,
|
||||||
|
),
|
||||||
|
Self::DistilBert(ref model) => model
|
||||||
|
.forward_t(input_ids, mask, input_embeds, train)
|
||||||
|
.expect("Error in distilbert forward_t"),
|
||||||
|
Self::Roberta(ref model) => model.forward_t(
|
||||||
|
input_ids,
|
||||||
|
mask,
|
||||||
|
token_type_ids,
|
||||||
|
position_ids,
|
||||||
|
input_embeds,
|
||||||
|
train,
|
||||||
|
),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/// # SequenceClassificationModel for Classification (e.g. Sentiment Analysis)
|
/// # SequenceClassificationModel for Classification (e.g. Sentiment Analysis)
|
||||||
pub struct SequenceClassificationModel {
|
pub struct SequenceClassificationModel {
|
||||||
tokenizer: TokenizerOption,
|
tokenizer: TokenizerOption,
|
||||||
@ -232,15 +272,16 @@ impl SequenceClassificationModel {
|
|||||||
/// # Example
|
/// # Example
|
||||||
///
|
///
|
||||||
/// ```no_run
|
/// ```no_run
|
||||||
///# fn main() -> failure::Fallible<()> {
|
/// # fn main() -> failure::Fallible<()> {
|
||||||
/// use rust_bert::pipelines::sequence_classification::SequenceClassificationModel;
|
/// use rust_bert::pipelines::sequence_classification::SequenceClassificationModel;
|
||||||
///
|
///
|
||||||
/// let model = SequenceClassificationModel::new(Default::default())?;
|
/// let model = SequenceClassificationModel::new(Default::default())?;
|
||||||
///# Ok(())
|
/// # Ok(())
|
||||||
///# }
|
/// # }
|
||||||
/// ```
|
/// ```
|
||||||
///
|
pub fn new(
|
||||||
pub fn new(config: SequenceClassificationConfig) -> failure::Fallible<SequenceClassificationModel> {
|
config: SequenceClassificationConfig,
|
||||||
|
) -> failure::Fallible<SequenceClassificationModel> {
|
||||||
let config_path = download_resource(&config.config_resource)?;
|
let config_path = download_resource(&config.config_resource)?;
|
||||||
let vocab_path = download_resource(&config.vocab_resource)?;
|
let vocab_path = download_resource(&config.vocab_resource)?;
|
||||||
let weights_path = download_resource(&config.model_resource)?;
|
let weights_path = download_resource(&config.model_resource)?;
|
||||||
@ -251,31 +292,44 @@ impl SequenceClassificationModel {
|
|||||||
};
|
};
|
||||||
let device = config.device;
|
let device = config.device;
|
||||||
|
|
||||||
let tokenizer = TokenizerOption::from_file(config.model_type, vocab_path.to_str().unwrap(), merges_path.map(|path| path.to_str().unwrap()), config.lower_case);
|
let tokenizer = TokenizerOption::from_file(
|
||||||
|
config.model_type,
|
||||||
|
vocab_path.to_str().unwrap(),
|
||||||
|
merges_path.map(|path| path.to_str().unwrap()),
|
||||||
|
config.lower_case,
|
||||||
|
);
|
||||||
let mut var_store = VarStore::new(device);
|
let mut var_store = VarStore::new(device);
|
||||||
let model_config = ConfigOption::from_file(config.model_type, config_path);
|
let model_config = ConfigOption::from_file(config.model_type, config_path);
|
||||||
let sequence_classifier = SequenceClassificationOption::new(config.model_type, &var_store.root(), &model_config);
|
let sequence_classifier =
|
||||||
|
SequenceClassificationOption::new(config.model_type, &var_store.root(), &model_config);
|
||||||
let label_mapping = model_config.get_label_mapping();
|
let label_mapping = model_config.get_label_mapping();
|
||||||
var_store.load(weights_path)?;
|
var_store.load(weights_path)?;
|
||||||
Ok(SequenceClassificationModel { tokenizer, sequence_classifier, label_mapping, var_store })
|
Ok(SequenceClassificationModel {
|
||||||
|
tokenizer,
|
||||||
|
sequence_classifier,
|
||||||
|
label_mapping,
|
||||||
|
var_store,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn prepare_for_model(&self, input: Vec<&str>) -> Tensor {
|
fn prepare_for_model(&self, input: Vec<&str>) -> Tensor {
|
||||||
let tokenized_input: Vec<TokenizedInput> = self.tokenizer.encode_list(input.to_vec(),
|
let tokenized_input: Vec<TokenizedInput> =
|
||||||
128,
|
self.tokenizer
|
||||||
&TruncationStrategy::LongestFirst,
|
.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
||||||
0);
|
let max_len = tokenized_input
|
||||||
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
|
.iter()
|
||||||
let tokenized_input_tensors: Vec<tch::Tensor> = tokenized_input.
|
.map(|input| input.token_ids.len())
|
||||||
iter().
|
.max()
|
||||||
map(|input| input.token_ids.clone()).
|
.unwrap();
|
||||||
map(|mut input| {
|
let tokenized_input_tensors: Vec<tch::Tensor> = tokenized_input
|
||||||
|
.iter()
|
||||||
|
.map(|input| input.token_ids.clone())
|
||||||
|
.map(|mut input| {
|
||||||
input.extend(vec![0; max_len - input.len()]);
|
input.extend(vec![0; max_len - input.len()]);
|
||||||
input
|
input
|
||||||
}).
|
})
|
||||||
map(|input|
|
.map(|input| Tensor::of_slice(&(input)))
|
||||||
Tensor::of_slice(&(input))).
|
.collect::<Vec<_>>();
|
||||||
collect::<Vec<_>>();
|
|
||||||
Tensor::stack(tokenized_input_tensors.as_slice(), 0).to(self.var_store.device())
|
Tensor::stack(tokenized_input_tensors.as_slice(), 0).to(self.var_store.device())
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -292,8 +346,8 @@ impl SequenceClassificationModel {
|
|||||||
/// # Example
|
/// # Example
|
||||||
///
|
///
|
||||||
/// ```no_run
|
/// ```no_run
|
||||||
///# fn main() -> failure::Fallible<()> {
|
/// # fn main() -> failure::Fallible<()> {
|
||||||
///# use rust_bert::pipelines::sequence_classification::SequenceClassificationModel;
|
/// # use rust_bert::pipelines::sequence_classification::SequenceClassificationModel;
|
||||||
///
|
///
|
||||||
/// let sequence_classification_model = SequenceClassificationModel::new(Default::default())?;
|
/// let sequence_classification_model = SequenceClassificationModel::new(Default::default())?;
|
||||||
/// let input = [
|
/// let input = [
|
||||||
@ -302,29 +356,36 @@ impl SequenceClassificationModel {
|
|||||||
/// "If you like original gut wrenching laughter you will like this movie. If you are young or old then you will love this movie, hell even my mom liked it.",
|
/// "If you like original gut wrenching laughter you will like this movie. If you are young or old then you will love this movie, hell even my mom liked it.",
|
||||||
/// ];
|
/// ];
|
||||||
/// let output = sequence_classification_model.predict(&input);
|
/// let output = sequence_classification_model.predict(&input);
|
||||||
///# Ok(())
|
/// # Ok(())
|
||||||
///# }
|
/// # }
|
||||||
/// ```
|
/// ```
|
||||||
pub fn predict(&self, input: &[&str]) -> Vec<Label> {
|
pub fn predict(&self, input: &[&str]) -> Vec<Label> {
|
||||||
let input_tensor = self.prepare_for_model(input.to_vec());
|
let input_tensor = self.prepare_for_model(input.to_vec());
|
||||||
let output = no_grad(|| {
|
let output = no_grad(|| {
|
||||||
let (output, _, _) = self.sequence_classifier
|
let (output, _, _) = self.sequence_classifier.forward_t(
|
||||||
.forward_t(Some(input_tensor.copy()),
|
Some(input_tensor.copy()),
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
false);
|
false,
|
||||||
|
);
|
||||||
output.softmax(-1, Kind::Float).detach().to(Device::Cpu)
|
output.softmax(-1, Kind::Float).detach().to(Device::Cpu)
|
||||||
});
|
});
|
||||||
let label_indices = output.as_ref().argmax(-1, true).squeeze1(1);
|
let label_indices = output.as_ref().argmax(-1, true).squeeze1(1);
|
||||||
let scores = output.gather(1, &label_indices.unsqueeze(-1), false).squeeze1(1);
|
let scores = output
|
||||||
|
.gather(1, &label_indices.unsqueeze(-1), false)
|
||||||
|
.squeeze1(1);
|
||||||
let label_indices = label_indices.iter::<i64>().unwrap().collect::<Vec<i64>>();
|
let label_indices = label_indices.iter::<i64>().unwrap().collect::<Vec<i64>>();
|
||||||
let scores = scores.iter::<f64>().unwrap().collect::<Vec<f64>>();
|
let scores = scores.iter::<f64>().unwrap().collect::<Vec<f64>>();
|
||||||
|
|
||||||
let mut labels: Vec<Label> = vec!();
|
let mut labels: Vec<Label> = vec![];
|
||||||
for sentence_idx in 0..label_indices.len() {
|
for sentence_idx in 0..label_indices.len() {
|
||||||
let label_string = self.label_mapping.get(&label_indices[sentence_idx]).unwrap().clone();
|
let label_string = self
|
||||||
|
.label_mapping
|
||||||
|
.get(&label_indices[sentence_idx])
|
||||||
|
.unwrap()
|
||||||
|
.clone();
|
||||||
let label = Label {
|
let label = Label {
|
||||||
text: label_string,
|
text: label_string,
|
||||||
score: scores[sentence_idx],
|
score: scores[sentence_idx],
|
||||||
@ -350,8 +411,8 @@ impl SequenceClassificationModel {
|
|||||||
/// # Example
|
/// # Example
|
||||||
///
|
///
|
||||||
/// ```no_run
|
/// ```no_run
|
||||||
///# fn main() -> failure::Fallible<()> {
|
/// # fn main() -> failure::Fallible<()> {
|
||||||
///# use rust_bert::pipelines::sequence_classification::SequenceClassificationModel;
|
/// # use rust_bert::pipelines::sequence_classification::SequenceClassificationModel;
|
||||||
///
|
///
|
||||||
/// let sequence_classification_model = SequenceClassificationModel::new(Default::default())?;
|
/// let sequence_classification_model = SequenceClassificationModel::new(Default::default())?;
|
||||||
/// let input = [
|
/// let input = [
|
||||||
@ -360,34 +421,37 @@ impl SequenceClassificationModel {
|
|||||||
/// "If you like original gut wrenching laughter you will like this movie. If you are young or old then you will love this movie, hell even my mom liked it.",
|
/// "If you like original gut wrenching laughter you will like this movie. If you are young or old then you will love this movie, hell even my mom liked it.",
|
||||||
/// ];
|
/// ];
|
||||||
/// let output = sequence_classification_model.predict_multilabel(&input, 0.5);
|
/// let output = sequence_classification_model.predict_multilabel(&input, 0.5);
|
||||||
///# Ok(())
|
/// # Ok(())
|
||||||
///# }
|
/// # }
|
||||||
/// ```
|
/// ```
|
||||||
pub fn predict_multilabel(&self, input: &[&str], threshold: f64) -> Vec<Vec<Label>> {
|
pub fn predict_multilabel(&self, input: &[&str], threshold: f64) -> Vec<Vec<Label>> {
|
||||||
let input_tensor = self.prepare_for_model(input.to_vec());
|
let input_tensor = self.prepare_for_model(input.to_vec());
|
||||||
let output = no_grad(|| {
|
let output = no_grad(|| {
|
||||||
let (output, _, _) = self.sequence_classifier
|
let (output, _, _) = self.sequence_classifier.forward_t(
|
||||||
.forward_t(Some(input_tensor.copy()),
|
Some(input_tensor.copy()),
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
false);
|
false,
|
||||||
|
);
|
||||||
output.sigmoid().detach().to(Device::Cpu)
|
output.sigmoid().detach().to(Device::Cpu)
|
||||||
});
|
});
|
||||||
let label_indices = output.as_ref().ge(threshold).nonzero();
|
let label_indices = output.as_ref().ge(threshold).nonzero();
|
||||||
|
|
||||||
let mut labels: Vec<Vec<Label>> = vec!();
|
let mut labels: Vec<Vec<Label>> = vec![];
|
||||||
let mut sequence_labels: Vec<Label> = vec!();
|
let mut sequence_labels: Vec<Label> = vec![];
|
||||||
|
|
||||||
for sentence_idx in 0..label_indices.size()[0] {
|
for sentence_idx in 0..label_indices.size()[0] {
|
||||||
|
|
||||||
let label_index_tensor = label_indices.get(sentence_idx);
|
let label_index_tensor = label_indices.get(sentence_idx);
|
||||||
let sentence_label = label_index_tensor.iter::<i64>().unwrap().collect::<Vec<i64>>();
|
let sentence_label = label_index_tensor
|
||||||
|
.iter::<i64>()
|
||||||
|
.unwrap()
|
||||||
|
.collect::<Vec<i64>>();
|
||||||
let (sentence, id) = (sentence_label[0], sentence_label[1]);
|
let (sentence, id) = (sentence_label[0], sentence_label[1]);
|
||||||
if sentence as usize > labels.len() {
|
if sentence as usize > labels.len() {
|
||||||
labels.push(sequence_labels);
|
labels.push(sequence_labels);
|
||||||
sequence_labels = vec!();
|
sequence_labels = vec![];
|
||||||
}
|
}
|
||||||
let score = output.double_value(sentence_label.as_slice());
|
let score = output.double_value(sentence_label.as_slice());
|
||||||
let label_string = self.label_mapping.get(&id).unwrap().to_owned();
|
let label_string = self.label_mapping.get(&id).unwrap().to_owned();
|
||||||
@ -398,7 +462,6 @@ impl SequenceClassificationModel {
|
|||||||
sentence: sentence as usize,
|
sentence: sentence as usize,
|
||||||
};
|
};
|
||||||
sequence_labels.push(label);
|
sequence_labels.push(label);
|
||||||
|
|
||||||
}
|
}
|
||||||
if sequence_labels.len() > 0 {
|
if sequence_labels.len() > 0 {
|
||||||
labels.push(sequence_labels);
|
labels.push(sequence_labels);
|
||||||
|
@ -11,7 +11,6 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
//! # Summarization pipeline
|
//! # Summarization pipeline
|
||||||
//! Abstractive summarization of texts based on the BART encoder-decoder architecture
|
//! Abstractive summarization of texts based on the BART encoder-decoder architecture
|
||||||
//! Include techniques such as beam search, top-k and nucleus sampling, temperature setting and repetition penalty.
|
//! Include techniques such as beam search, top-k and nucleus sampling, temperature setting and repetition penalty.
|
||||||
@ -21,52 +20,54 @@
|
|||||||
//!
|
//!
|
||||||
//!
|
//!
|
||||||
//! ```no_run
|
//! ```no_run
|
||||||
//!# fn main() -> failure::Fallible<()> {
|
//! # fn main() -> failure::Fallible<()> {
|
||||||
//!# use rust_bert::pipelines::generation::LanguageGenerator;
|
//! # use rust_bert::pipelines::generation::LanguageGenerator;
|
||||||
//! use rust_bert::pipelines::summarization::SummarizationModel;
|
//! use rust_bert::pipelines::summarization::SummarizationModel;
|
||||||
//! let mut model = SummarizationModel::new(Default::default())?;
|
//! let mut model = SummarizationModel::new(Default::default())?;
|
||||||
//!
|
//!
|
||||||
//! let input = ["In findings published Tuesday in Cornell University's arXiv by a team of scientists
|
//! let input = ["In findings published Tuesday in Cornell University's arXiv by a team of scientists
|
||||||
//!from the University of Montreal and a separate report published Wednesday in Nature Astronomy by a team
|
//! from the University of Montreal and a separate report published Wednesday in Nature Astronomy by a team
|
||||||
//!from University College London (UCL), the presence of water vapour was confirmed in the atmosphere of K2-18b,
|
//! from University College London (UCL), the presence of water vapour was confirmed in the atmosphere of K2-18b,
|
||||||
//!a planet circling a star in the constellation Leo. This is the first such discovery in a planet in its star's
|
//! a planet circling a star in the constellation Leo. This is the first such discovery in a planet in its star's
|
||||||
//!habitable zone — not too hot and not too cold for liquid water to exist. The Montreal team, led by Björn Benneke,
|
//! habitable zone — not too hot and not too cold for liquid water to exist. The Montreal team, led by Björn Benneke,
|
||||||
//!used data from the NASA's Hubble telescope to assess changes in the light coming from K2-18b's star as the planet
|
//! used data from the NASA's Hubble telescope to assess changes in the light coming from K2-18b's star as the planet
|
||||||
//!passed between it and Earth. They found that certain wavelengths of light, which are usually absorbed by water,
|
//! passed between it and Earth. They found that certain wavelengths of light, which are usually absorbed by water,
|
||||||
//!weakened when the planet was in the way, indicating not only does K2-18b have an atmosphere, but the atmosphere
|
//! weakened when the planet was in the way, indicating not only does K2-18b have an atmosphere, but the atmosphere
|
||||||
//!contains water in vapour form. The team from UCL then analyzed the Montreal team's data using their own software
|
//! contains water in vapour form. The team from UCL then analyzed the Montreal team's data using their own software
|
||||||
//!and confirmed their conclusion. This was not the first time scientists have found signs of water on an exoplanet,
|
//! and confirmed their conclusion. This was not the first time scientists have found signs of water on an exoplanet,
|
||||||
//!but previous discoveries were made on planets with high temperatures or other pronounced differences from Earth.
|
//! but previous discoveries were made on planets with high temperatures or other pronounced differences from Earth.
|
||||||
//!\"This is the first potentially habitable planet where the temperature is right and where we now know there is water,\"
|
//! \"This is the first potentially habitable planet where the temperature is right and where we now know there is water,\"
|
||||||
//!said UCL astronomer Angelos Tsiaras. \"It's the best candidate for habitability right now.\" \"It's a good sign\",
|
//! said UCL astronomer Angelos Tsiaras. \"It's the best candidate for habitability right now.\" \"It's a good sign\",
|
||||||
//!said Ryan Cloutier of the Harvard–Smithsonian Center for Astrophysics, who was not one of either study's authors.
|
//! said Ryan Cloutier of the Harvard–Smithsonian Center for Astrophysics, who was not one of either study's authors.
|
||||||
//!\"Overall,\" he continued, \"the presence of water in its atmosphere certainly improves the prospect of K2-18b being
|
//! \"Overall,\" he continued, \"the presence of water in its atmosphere certainly improves the prospect of K2-18b being
|
||||||
//!a potentially habitable planet, but further observations will be required to say for sure. \"
|
//! a potentially habitable planet, but further observations will be required to say for sure. \"
|
||||||
//!K2-18b was first identified in 2015 by the Kepler space telescope. It is about 110 light-years from Earth and larger
|
//! K2-18b was first identified in 2015 by the Kepler space telescope. It is about 110 light-years from Earth and larger
|
||||||
//!but less dense. Its star, a red dwarf, is cooler than the Sun, but the planet's orbit is much closer, such that a year
|
//! but less dense. Its star, a red dwarf, is cooler than the Sun, but the planet's orbit is much closer, such that a year
|
||||||
//!on K2-18b lasts 33 Earth days. According to The Guardian, astronomers were optimistic that NASA's James Webb space
|
//! on K2-18b lasts 33 Earth days. According to The Guardian, astronomers were optimistic that NASA's James Webb space
|
||||||
//!telescope — scheduled for launch in 2021 — and the European Space Agency's 2028 ARIEL program, could reveal more
|
//! telescope — scheduled for launch in 2021 — and the European Space Agency's 2028 ARIEL program, could reveal more
|
||||||
//!about exoplanets like K2-18b."];
|
//! about exoplanets like K2-18b."];
|
||||||
//!
|
//!
|
||||||
//! let output = model.summarize(&input);
|
//! let output = model.summarize(&input);
|
||||||
//!# Ok(())
|
//! # Ok(())
|
||||||
//!# }
|
//! # }
|
||||||
//! ```
|
//! ```
|
||||||
//! (New sample credits: [WikiNews](https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b))
|
//! (New sample credits: [WikiNews](https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b))
|
||||||
//!
|
//!
|
||||||
//! Example output: \
|
//! Example output: \
|
||||||
//! ```no_run
|
//! ```no_run
|
||||||
//!# let output =
|
//! # let output =
|
||||||
//! "Scientists have found water vapour on K2-18b, a planet 110 light-years from Earth.
|
//! "Scientists have found water vapour on K2-18b, a planet 110 light-years from Earth.
|
||||||
//! This is the first such discovery in a planet in its star's habitable zone.
|
//! This is the first such discovery in a planet in its star's habitable zone.
|
||||||
//! The planet is not too hot and not too cold for liquid water to exist."
|
//! The planet is not too hot and not too cold for liquid water to exist."
|
||||||
//!# ;
|
//! # ;
|
||||||
//!```
|
//! ```
|
||||||
|
|
||||||
|
use crate::bart::{
|
||||||
|
BartConfigResources, BartMergesResources, BartModelResources, BartVocabResources,
|
||||||
|
};
|
||||||
|
use crate::common::resources::{RemoteResource, Resource};
|
||||||
use crate::pipelines::generation::{BartGenerator, GenerateConfig, LanguageGenerator};
|
use crate::pipelines::generation::{BartGenerator, GenerateConfig, LanguageGenerator};
|
||||||
use tch::Device;
|
use tch::Device;
|
||||||
use crate::common::resources::{Resource, RemoteResource};
|
|
||||||
use crate::bart::{BartModelResources, BartConfigResources, BartVocabResources, BartMergesResources};
|
|
||||||
|
|
||||||
/// # Configuration for text summarization
|
/// # Configuration for text summarization
|
||||||
/// Contains information regarding the model to load, mirrors the GenerationConfig, with a
|
/// Contains information regarding the model to load, mirrors the GenerationConfig, with a
|
||||||
@ -111,10 +112,18 @@ pub struct SummarizationConfig {
|
|||||||
impl Default for SummarizationConfig {
|
impl Default for SummarizationConfig {
|
||||||
fn default() -> SummarizationConfig {
|
fn default() -> SummarizationConfig {
|
||||||
SummarizationConfig {
|
SummarizationConfig {
|
||||||
model_resource: Resource::Remote(RemoteResource::from_pretrained(BartModelResources::BART_CNN)),
|
model_resource: Resource::Remote(RemoteResource::from_pretrained(
|
||||||
config_resource: Resource::Remote(RemoteResource::from_pretrained(BartConfigResources::BART_CNN)),
|
BartModelResources::BART_CNN,
|
||||||
vocab_resource: Resource::Remote(RemoteResource::from_pretrained(BartVocabResources::BART_CNN)),
|
)),
|
||||||
merges_resource: Resource::Remote(RemoteResource::from_pretrained(BartMergesResources::BART_CNN)),
|
config_resource: Resource::Remote(RemoteResource::from_pretrained(
|
||||||
|
BartConfigResources::BART_CNN,
|
||||||
|
)),
|
||||||
|
vocab_resource: Resource::Remote(RemoteResource::from_pretrained(
|
||||||
|
BartVocabResources::BART_CNN,
|
||||||
|
)),
|
||||||
|
merges_resource: Resource::Remote(RemoteResource::from_pretrained(
|
||||||
|
BartMergesResources::BART_CNN,
|
||||||
|
)),
|
||||||
min_length: 56,
|
min_length: 56,
|
||||||
max_length: 142,
|
max_length: 142,
|
||||||
do_sample: false,
|
do_sample: false,
|
||||||
@ -134,7 +143,7 @@ impl Default for SummarizationConfig {
|
|||||||
|
|
||||||
/// # SummarizationModel to perform summarization
|
/// # SummarizationModel to perform summarization
|
||||||
pub struct SummarizationModel {
|
pub struct SummarizationModel {
|
||||||
model: BartGenerator
|
model: BartGenerator,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl SummarizationModel {
|
impl SummarizationModel {
|
||||||
@ -147,16 +156,14 @@ impl SummarizationModel {
|
|||||||
/// # Example
|
/// # Example
|
||||||
///
|
///
|
||||||
/// ```no_run
|
/// ```no_run
|
||||||
///# fn main() -> failure::Fallible<()> {
|
/// # fn main() -> failure::Fallible<()> {
|
||||||
/// use rust_bert::pipelines::summarization::SummarizationModel;
|
/// use rust_bert::pipelines::summarization::SummarizationModel;
|
||||||
///
|
///
|
||||||
/// let mut summarization_model = SummarizationModel::new(Default::default())?;
|
/// let mut summarization_model = SummarizationModel::new(Default::default())?;
|
||||||
///# Ok(())
|
/// # Ok(())
|
||||||
///# }
|
/// # }
|
||||||
/// ```
|
/// ```
|
||||||
///
|
pub fn new(summarization_config: SummarizationConfig) -> failure::Fallible<SummarizationModel> {
|
||||||
pub fn new(summarization_config: SummarizationConfig)
|
|
||||||
-> failure::Fallible<SummarizationModel> {
|
|
||||||
let generate_config = GenerateConfig {
|
let generate_config = GenerateConfig {
|
||||||
model_resource: summarization_config.model_resource,
|
model_resource: summarization_config.model_resource,
|
||||||
config_resource: summarization_config.config_resource,
|
config_resource: summarization_config.config_resource,
|
||||||
@ -194,40 +201,39 @@ impl SummarizationModel {
|
|||||||
/// # Example
|
/// # Example
|
||||||
///
|
///
|
||||||
/// ```no_run
|
/// ```no_run
|
||||||
///# fn main() -> failure::Fallible<()> {
|
/// # fn main() -> failure::Fallible<()> {
|
||||||
/// use rust_bert::pipelines::generation::LanguageGenerator;
|
/// use rust_bert::pipelines::generation::LanguageGenerator;
|
||||||
/// use rust_bert::pipelines::summarization::SummarizationModel;
|
/// use rust_bert::pipelines::summarization::SummarizationModel;
|
||||||
/// let model = SummarizationModel::new(Default::default())?;
|
/// let model = SummarizationModel::new(Default::default())?;
|
||||||
///
|
///
|
||||||
/// let input = ["In findings published Tuesday in Cornell University's arXiv by a team of scientists
|
/// let input = ["In findings published Tuesday in Cornell University's arXiv by a team of scientists
|
||||||
///from the University of Montreal and a separate report published Wednesday in Nature Astronomy by a team
|
/// from the University of Montreal and a separate report published Wednesday in Nature Astronomy by a team
|
||||||
///from University College London (UCL), the presence of water vapour was confirmed in the atmosphere of K2-18b,
|
/// from University College London (UCL), the presence of water vapour was confirmed in the atmosphere of K2-18b,
|
||||||
///a planet circling a star in the constellation Leo. This is the first such discovery in a planet in its star's
|
/// a planet circling a star in the constellation Leo. This is the first such discovery in a planet in its star's
|
||||||
///habitable zone — not too hot and not too cold for liquid water to exist. The Montreal team, led by Björn Benneke,
|
/// habitable zone — not too hot and not too cold for liquid water to exist. The Montreal team, led by Björn Benneke,
|
||||||
///used data from the NASA's Hubble telescope to assess changes in the light coming from K2-18b's star as the planet
|
/// used data from the NASA's Hubble telescope to assess changes in the light coming from K2-18b's star as the planet
|
||||||
///passed between it and Earth. They found that certain wavelengths of light, which are usually absorbed by water,
|
/// passed between it and Earth. They found that certain wavelengths of light, which are usually absorbed by water,
|
||||||
///weakened when the planet was in the way, indicating not only does K2-18b have an atmosphere, but the atmosphere
|
/// weakened when the planet was in the way, indicating not only does K2-18b have an atmosphere, but the atmosphere
|
||||||
///contains water in vapour form. The team from UCL then analyzed the Montreal team's data using their own software
|
/// contains water in vapour form. The team from UCL then analyzed the Montreal team's data using their own software
|
||||||
///and confirmed their conclusion. This was not the first time scientists have found signs of water on an exoplanet,
|
/// and confirmed their conclusion. This was not the first time scientists have found signs of water on an exoplanet,
|
||||||
///but previous discoveries were made on planets with high temperatures or other pronounced differences from Earth.
|
/// but previous discoveries were made on planets with high temperatures or other pronounced differences from Earth.
|
||||||
///\"This is the first potentially habitable planet where the temperature is right and where we now know there is water,\"
|
/// \"This is the first potentially habitable planet where the temperature is right and where we now know there is water,\"
|
||||||
///said UCL astronomer Angelos Tsiaras. \"It's the best candidate for habitability right now.\" \"It's a good sign\",
|
/// said UCL astronomer Angelos Tsiaras. \"It's the best candidate for habitability right now.\" \"It's a good sign\",
|
||||||
///said Ryan Cloutier of the Harvard–Smithsonian Center for Astrophysics, who was not one of either study's authors.
|
/// said Ryan Cloutier of the Harvard–Smithsonian Center for Astrophysics, who was not one of either study's authors.
|
||||||
///\"Overall,\" he continued, \"the presence of water in its atmosphere certainly improves the prospect of K2-18b being
|
/// \"Overall,\" he continued, \"the presence of water in its atmosphere certainly improves the prospect of K2-18b being
|
||||||
///a potentially habitable planet, but further observations will be required to say for sure. \"
|
/// a potentially habitable planet, but further observations will be required to say for sure. \"
|
||||||
///K2-18b was first identified in 2015 by the Kepler space telescope. It is about 110 light-years from Earth and larger
|
/// K2-18b was first identified in 2015 by the Kepler space telescope. It is about 110 light-years from Earth and larger
|
||||||
///but less dense. Its star, a red dwarf, is cooler than the Sun, but the planet's orbit is much closer, such that a year
|
/// but less dense. Its star, a red dwarf, is cooler than the Sun, but the planet's orbit is much closer, such that a year
|
||||||
///on K2-18b lasts 33 Earth days. According to The Guardian, astronomers were optimistic that NASA's James Webb space
|
/// on K2-18b lasts 33 Earth days. According to The Guardian, astronomers were optimistic that NASA's James Webb space
|
||||||
///telescope — scheduled for launch in 2021 — and the European Space Agency's 2028 ARIEL program, could reveal more
|
/// telescope — scheduled for launch in 2021 — and the European Space Agency's 2028 ARIEL program, could reveal more
|
||||||
///about exoplanets like K2-18b."];
|
/// about exoplanets like K2-18b."];
|
||||||
///
|
///
|
||||||
/// let output = model.summarize(&input);
|
/// let output = model.summarize(&input);
|
||||||
///# Ok(())
|
/// # Ok(())
|
||||||
///# }
|
/// # }
|
||||||
/// ```
|
/// ```
|
||||||
/// (New sample credits: [WikiNews](https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b))
|
/// (New sample credits: [WikiNews](https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b))
|
||||||
///
|
|
||||||
pub fn summarize(&self, texts: &[&str]) -> Vec<String> {
|
pub fn summarize(&self, texts: &[&str]) -> Vec<String> {
|
||||||
self.model.generate(Some(texts.to_vec()), None)
|
self.model.generate(Some(texts.to_vec()), None)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -19,7 +19,7 @@
|
|||||||
//! use rust_bert::resources::{Resource,RemoteResource};
|
//! use rust_bert::resources::{Resource,RemoteResource};
|
||||||
//! use rust_bert::bert::{BertModelResources, BertVocabResources, BertConfigResources};
|
//! use rust_bert::bert::{BertModelResources, BertVocabResources, BertConfigResources};
|
||||||
//! use rust_bert::pipelines::common::ModelType;
|
//! use rust_bert::pipelines::common::ModelType;
|
||||||
//!# fn main() -> failure::Fallible<()> {
|
//! # fn main() -> failure::Fallible<()> {
|
||||||
//!
|
//!
|
||||||
//! //Load a configuration
|
//! //Load a configuration
|
||||||
//! use rust_bert::pipelines::token_classification::LabelAggregationOption;
|
//! use rust_bert::pipelines::token_classification::LabelAggregationOption;
|
||||||
@ -40,41 +40,94 @@
|
|||||||
//! "Paris is a city in France."
|
//! "Paris is a city in France."
|
||||||
//! ];
|
//! ];
|
||||||
//! let output = token_classification_model.predict(&input, true, true); // consolidate_sub_tokens = true, return_special = true
|
//! let output = token_classification_model.predict(&input, true, true); // consolidate_sub_tokens = true, return_special = true
|
||||||
//!# Ok(())
|
//! # Ok(())
|
||||||
//!# }
|
//! # }
|
||||||
//! ```
|
//! ```
|
||||||
//! Output: \
|
//! Output: \
|
||||||
//! ```no_run
|
//! ```no_run
|
||||||
//!# use rust_bert::pipelines::token_classification::Token;
|
//! # use rust_bert::pipelines::token_classification::Token;
|
||||||
//! use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::Mask::Special;
|
//! use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::Mask::Special;
|
||||||
//! use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::{Offset, Mask};
|
//! use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::{Mask, Offset};
|
||||||
//!# let output =
|
//! # let output =
|
||||||
//! [
|
//! [
|
||||||
//! Token { text: String::from("[CLS]"), score: 0.9995001554489136, label: String::from("O"), label_index: 0, sentence: 0, index: 0, word_index: 0, offset: None, mask: Special },
|
//! Token {
|
||||||
//! Token { text: String::from("My"), score: 0.9980450868606567, label: String::from("O"), label_index: 0, sentence: 0, index: 1, word_index: 1, offset: Some(Offset { begin: 0, end: 2 }), mask: Mask::None },
|
//! text: String::from("[CLS]"),
|
||||||
//! Token { text: String::from("name"), score: 0.9995062351226807, label: String::from("O"), label_index: 0, sentence: 0, index: 2, word_index: 2, offset: Some(Offset { begin: 3, end: 7 }), mask: Mask::None },
|
//! score: 0.9995001554489136,
|
||||||
//! Token { text: String::from("is"), score: 0.9997343420982361, label: String::from("O"), label_index: 0, sentence: 0, index: 3, word_index: 3, offset: Some(Offset { begin: 8, end: 10 }), mask: Mask::None },
|
//! label: String::from("O"),
|
||||||
//! Token { text: String::from("Amélie"), score: 0.9913727683112525, label: String::from("I-PER"), label_index: 4, sentence: 0, index: 4, word_index: 4, offset: Some(Offset { begin: 11, end: 17 }), mask: Mask::None }
|
//! label_index: 0,
|
||||||
//! // ...
|
//! sentence: 0,
|
||||||
|
//! index: 0,
|
||||||
|
//! word_index: 0,
|
||||||
|
//! offset: None,
|
||||||
|
//! mask: Special,
|
||||||
|
//! },
|
||||||
|
//! Token {
|
||||||
|
//! text: String::from("My"),
|
||||||
|
//! score: 0.9980450868606567,
|
||||||
|
//! label: String::from("O"),
|
||||||
|
//! label_index: 0,
|
||||||
|
//! sentence: 0,
|
||||||
|
//! index: 1,
|
||||||
|
//! word_index: 1,
|
||||||
|
//! offset: Some(Offset { begin: 0, end: 2 }),
|
||||||
|
//! mask: Mask::None,
|
||||||
|
//! },
|
||||||
|
//! Token {
|
||||||
|
//! text: String::from("name"),
|
||||||
|
//! score: 0.9995062351226807,
|
||||||
|
//! label: String::from("O"),
|
||||||
|
//! label_index: 0,
|
||||||
|
//! sentence: 0,
|
||||||
|
//! index: 2,
|
||||||
|
//! word_index: 2,
|
||||||
|
//! offset: Some(Offset { begin: 3, end: 7 }),
|
||||||
|
//! mask: Mask::None,
|
||||||
|
//! },
|
||||||
|
//! Token {
|
||||||
|
//! text: String::from("is"),
|
||||||
|
//! score: 0.9997343420982361,
|
||||||
|
//! label: String::from("O"),
|
||||||
|
//! label_index: 0,
|
||||||
|
//! sentence: 0,
|
||||||
|
//! index: 3,
|
||||||
|
//! word_index: 3,
|
||||||
|
//! offset: Some(Offset { begin: 8, end: 10 }),
|
||||||
|
//! mask: Mask::None,
|
||||||
|
//! },
|
||||||
|
//! Token {
|
||||||
|
//! text: String::from("Amélie"),
|
||||||
|
//! score: 0.9913727683112525,
|
||||||
|
//! label: String::from("I-PER"),
|
||||||
|
//! label_index: 4,
|
||||||
|
//! sentence: 0,
|
||||||
|
//! index: 4,
|
||||||
|
//! word_index: 4,
|
||||||
|
//! offset: Some(Offset { begin: 11, end: 17 }),
|
||||||
|
//! mask: Mask::None,
|
||||||
|
//! }, // ...
|
||||||
//! ]
|
//! ]
|
||||||
//!# ;
|
//! # ;
|
||||||
//! ```
|
//! ```
|
||||||
|
|
||||||
use tch::nn::VarStore;
|
use crate::bert::{
|
||||||
use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::{Tokenizer, TokenizedInput, TruncationStrategy, Mask, Offset, ConsolidatableTokens, ConsolidatedTokenIterator, TokenTrait};
|
BertConfigResources, BertForTokenClassification, BertModelResources, BertVocabResources,
|
||||||
use std::collections::HashMap;
|
};
|
||||||
use tch::{Tensor, no_grad, Device};
|
use crate::common::resources::{download_resource, RemoteResource, Resource};
|
||||||
use tch::kind::Kind::Float;
|
|
||||||
use crate::bert::{BertForTokenClassification, BertModelResources, BertConfigResources, BertVocabResources};
|
|
||||||
use crate::roberta::RobertaForTokenClassification;
|
|
||||||
use crate::distilbert::DistilBertForTokenClassification;
|
use crate::distilbert::DistilBertForTokenClassification;
|
||||||
use crate::common::resources::{Resource, RemoteResource, download_resource};
|
|
||||||
use crate::pipelines::common::{ModelType, ConfigOption, TokenizerOption};
|
|
||||||
use crate::electra::ElectraForTokenClassification;
|
use crate::electra::ElectraForTokenClassification;
|
||||||
|
use crate::pipelines::common::{ConfigOption, ModelType, TokenizerOption};
|
||||||
|
use crate::roberta::RobertaForTokenClassification;
|
||||||
use itertools::Itertools;
|
use itertools::Itertools;
|
||||||
|
use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::{
|
||||||
|
ConsolidatableTokens, ConsolidatedTokenIterator, Mask, Offset, TokenTrait, TokenizedInput,
|
||||||
|
Tokenizer, TruncationStrategy,
|
||||||
|
};
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
use std::cmp::min;
|
use std::cmp::min;
|
||||||
use serde::{Serialize, Deserialize};
|
use std::collections::HashMap;
|
||||||
|
use tch::kind::Kind::Float;
|
||||||
|
use tch::nn::VarStore;
|
||||||
|
use tch::{no_grad, Device, Tensor};
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
/// # Token generated by a `TokenClassificationModel`
|
/// # Token generated by a `TokenClassificationModel`
|
||||||
@ -140,7 +193,6 @@ pub enum LabelAggregationOption {
|
|||||||
Custom(Box<dyn Fn(&[Token]) -> (i64, String)>),
|
Custom(Box<dyn Fn(&[Token]) -> (i64, String)>),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/// # Configuration for TokenClassificationModel
|
/// # Configuration for TokenClassificationModel
|
||||||
/// Contains information regarding the model to load and device to place the model on.
|
/// Contains information regarding the model to load and device to place the model on.
|
||||||
pub struct TokenClassificationConfig {
|
pub struct TokenClassificationConfig {
|
||||||
@ -173,14 +225,15 @@ impl TokenClassificationConfig {
|
|||||||
/// * vocab - The `Resource` pointing to the tokenizer's vocabulary to load (e.g. vocab.txt/vocab.json)
|
/// * vocab - The `Resource` pointing to the tokenizer's vocabulary to load (e.g. vocab.txt/vocab.json)
|
||||||
/// * merges - An optional `Resource` (`Option<Resource>`) pointing to the tokenizer's merges file to load (e.g. merges.txt), needed only for Roberta.
|
/// * merges - An optional `Resource` (`Option<Resource>`) pointing to the tokenizer's merges file to load (e.g. merges.txt), needed only for Roberta.
|
||||||
/// * lower_case - A `bool` indicating whether the tokenizer should lower case all input (in case of a lower-cased model)
|
/// * lower_case - A `bool` indicating whether the tokenizer should lower case all input (in case of a lower-cased model)
|
||||||
///
|
pub fn new(
|
||||||
pub fn new(model_type: ModelType,
|
model_type: ModelType,
|
||||||
model_resource: Resource,
|
model_resource: Resource,
|
||||||
config_resource: Resource,
|
config_resource: Resource,
|
||||||
vocab_resource: Resource,
|
vocab_resource: Resource,
|
||||||
merges_resource: Option<Resource>,
|
merges_resource: Option<Resource>,
|
||||||
lower_case: bool,
|
lower_case: bool,
|
||||||
label_aggregation_function: LabelAggregationOption) -> TokenClassificationConfig {
|
label_aggregation_function: LabelAggregationOption,
|
||||||
|
) -> TokenClassificationConfig {
|
||||||
TokenClassificationConfig {
|
TokenClassificationConfig {
|
||||||
model_type,
|
model_type,
|
||||||
model_resource,
|
model_resource,
|
||||||
@ -199,9 +252,15 @@ impl Default for TokenClassificationConfig {
|
|||||||
fn default() -> TokenClassificationConfig {
|
fn default() -> TokenClassificationConfig {
|
||||||
TokenClassificationConfig {
|
TokenClassificationConfig {
|
||||||
model_type: ModelType::Bert,
|
model_type: ModelType::Bert,
|
||||||
model_resource: Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT_NER)),
|
model_resource: Resource::Remote(RemoteResource::from_pretrained(
|
||||||
config_resource: Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT_NER)),
|
BertModelResources::BERT_NER,
|
||||||
vocab_resource: Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT_NER)),
|
)),
|
||||||
|
config_resource: Resource::Remote(RemoteResource::from_pretrained(
|
||||||
|
BertConfigResources::BERT_NER,
|
||||||
|
)),
|
||||||
|
vocab_resource: Resource::Remote(RemoteResource::from_pretrained(
|
||||||
|
BertVocabResources::BERT_NER,
|
||||||
|
)),
|
||||||
merges_resource: None,
|
merges_resource: None,
|
||||||
lower_case: false,
|
lower_case: false,
|
||||||
device: Device::cuda_if_available(),
|
device: Device::cuda_if_available(),
|
||||||
@ -231,7 +290,6 @@ impl TokenClassificationOption {
|
|||||||
/// * `p` - `tch::nn::Path` path to the model file to load (e.g. model.ot)
|
/// * `p` - `tch::nn::Path` path to the model file to load (e.g. model.ot)
|
||||||
/// * `config` - A configuration (the model type of the configuration must be compatible with the value for
|
/// * `config` - A configuration (the model type of the configuration must be compatible with the value for
|
||||||
/// `model_type`)
|
/// `model_type`)
|
||||||
///
|
|
||||||
pub fn new(model_type: ModelType, p: &tch::nn::Path, config: &ConfigOption) -> Self {
|
pub fn new(model_type: ModelType, p: &tch::nn::Path, config: &ConfigOption) -> Self {
|
||||||
match model_type {
|
match model_type {
|
||||||
ModelType::Bert => {
|
ModelType::Bert => {
|
||||||
@ -243,21 +301,27 @@ impl TokenClassificationOption {
|
|||||||
}
|
}
|
||||||
ModelType::DistilBert => {
|
ModelType::DistilBert => {
|
||||||
if let ConfigOption::DistilBert(config) = config {
|
if let ConfigOption::DistilBert(config) = config {
|
||||||
TokenClassificationOption::DistilBert(DistilBertForTokenClassification::new(p, config))
|
TokenClassificationOption::DistilBert(DistilBertForTokenClassification::new(
|
||||||
|
p, config,
|
||||||
|
))
|
||||||
} else {
|
} else {
|
||||||
panic!("You can only supply a DistilBertConfig for DistilBert!");
|
panic!("You can only supply a DistilBertConfig for DistilBert!");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ModelType::Roberta => {
|
ModelType::Roberta => {
|
||||||
if let ConfigOption::Bert(config) = config {
|
if let ConfigOption::Bert(config) = config {
|
||||||
TokenClassificationOption::Roberta(RobertaForTokenClassification::new(p, config))
|
TokenClassificationOption::Roberta(RobertaForTokenClassification::new(
|
||||||
|
p, config,
|
||||||
|
))
|
||||||
} else {
|
} else {
|
||||||
panic!("You can only supply a BertConfig for Roberta!");
|
panic!("You can only supply a BertConfig for Roberta!");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ModelType::Electra => {
|
ModelType::Electra => {
|
||||||
if let ConfigOption::Electra(config) = config {
|
if let ConfigOption::Electra(config) = config {
|
||||||
TokenClassificationOption::Electra(ElectraForTokenClassification::new(p, config))
|
TokenClassificationOption::Electra(ElectraForTokenClassification::new(
|
||||||
|
p, config,
|
||||||
|
))
|
||||||
} else {
|
} else {
|
||||||
panic!("You can only supply a BertConfig for Roberta!");
|
panic!("You can only supply a BertConfig for Roberta!");
|
||||||
}
|
}
|
||||||
@ -271,27 +335,51 @@ impl TokenClassificationOption {
|
|||||||
Self::Bert(_) => ModelType::Bert,
|
Self::Bert(_) => ModelType::Bert,
|
||||||
Self::Roberta(_) => ModelType::Roberta,
|
Self::Roberta(_) => ModelType::Roberta,
|
||||||
Self::DistilBert(_) => ModelType::DistilBert,
|
Self::DistilBert(_) => ModelType::DistilBert,
|
||||||
Self::Electra(_) => ModelType::Electra
|
Self::Electra(_) => ModelType::Electra,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn forward_t(&self,
|
fn forward_t(
|
||||||
input_ids: Option<Tensor>,
|
&self,
|
||||||
mask: Option<Tensor>,
|
input_ids: Option<Tensor>,
|
||||||
token_type_ids: Option<Tensor>,
|
mask: Option<Tensor>,
|
||||||
position_ids: Option<Tensor>,
|
token_type_ids: Option<Tensor>,
|
||||||
input_embeds: Option<Tensor>,
|
position_ids: Option<Tensor>,
|
||||||
train: bool) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
|
input_embeds: Option<Tensor>,
|
||||||
|
train: bool,
|
||||||
|
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
|
||||||
match *self {
|
match *self {
|
||||||
Self::Bert(ref model) => model.forward_t(input_ids, mask, token_type_ids, position_ids, input_embeds, train),
|
Self::Bert(ref model) => model.forward_t(
|
||||||
Self::DistilBert(ref model) => model.forward_t(input_ids, mask, input_embeds, train).expect("Error in distilbert forward_t"),
|
input_ids,
|
||||||
Self::Roberta(ref model) => model.forward_t(input_ids, mask, token_type_ids, position_ids, input_embeds, train),
|
mask,
|
||||||
Self::Electra(ref model) => model.forward_t(input_ids, mask, token_type_ids, position_ids, input_embeds, train),
|
token_type_ids,
|
||||||
|
position_ids,
|
||||||
|
input_embeds,
|
||||||
|
train,
|
||||||
|
),
|
||||||
|
Self::DistilBert(ref model) => model
|
||||||
|
.forward_t(input_ids, mask, input_embeds, train)
|
||||||
|
.expect("Error in distilbert forward_t"),
|
||||||
|
Self::Roberta(ref model) => model.forward_t(
|
||||||
|
input_ids,
|
||||||
|
mask,
|
||||||
|
token_type_ids,
|
||||||
|
position_ids,
|
||||||
|
input_embeds,
|
||||||
|
train,
|
||||||
|
),
|
||||||
|
Self::Electra(ref model) => model.forward_t(
|
||||||
|
input_ids,
|
||||||
|
mask,
|
||||||
|
token_type_ids,
|
||||||
|
position_ids,
|
||||||
|
input_embeds,
|
||||||
|
train,
|
||||||
|
),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/// # TokenClassificationModel for Named Entity Recognition or Part-of-Speech tagging
|
/// # TokenClassificationModel for Named Entity Recognition or Part-of-Speech tagging
|
||||||
pub struct TokenClassificationModel {
|
pub struct TokenClassificationModel {
|
||||||
tokenizer: TokenizerOption,
|
tokenizer: TokenizerOption,
|
||||||
@ -311,14 +399,13 @@ impl TokenClassificationModel {
|
|||||||
/// # Example
|
/// # Example
|
||||||
///
|
///
|
||||||
/// ```no_run
|
/// ```no_run
|
||||||
///# fn main() -> failure::Fallible<()> {
|
/// # fn main() -> failure::Fallible<()> {
|
||||||
/// use rust_bert::pipelines::token_classification::TokenClassificationModel;
|
/// use rust_bert::pipelines::token_classification::TokenClassificationModel;
|
||||||
///
|
///
|
||||||
/// let model = TokenClassificationModel::new(Default::default())?;
|
/// let model = TokenClassificationModel::new(Default::default())?;
|
||||||
///# Ok(())
|
/// # Ok(())
|
||||||
///# }
|
/// # }
|
||||||
/// ```
|
/// ```
|
||||||
///
|
|
||||||
pub fn new(config: TokenClassificationConfig) -> failure::Fallible<TokenClassificationModel> {
|
pub fn new(config: TokenClassificationConfig) -> failure::Fallible<TokenClassificationModel> {
|
||||||
let config_path = download_resource(&config.config_resource)?;
|
let config_path = download_resource(&config.config_resource)?;
|
||||||
let vocab_path = download_resource(&config.vocab_resource)?;
|
let vocab_path = download_resource(&config.vocab_resource)?;
|
||||||
@ -331,32 +418,49 @@ impl TokenClassificationModel {
|
|||||||
let device = config.device;
|
let device = config.device;
|
||||||
let label_aggregation_function = config.label_aggregation_function;
|
let label_aggregation_function = config.label_aggregation_function;
|
||||||
|
|
||||||
let tokenizer = TokenizerOption::from_file(config.model_type, vocab_path.to_str().unwrap(), merges_path.map(|path| path.to_str().unwrap()), config.lower_case);
|
let tokenizer = TokenizerOption::from_file(
|
||||||
|
config.model_type,
|
||||||
|
vocab_path.to_str().unwrap(),
|
||||||
|
merges_path.map(|path| path.to_str().unwrap()),
|
||||||
|
config.lower_case,
|
||||||
|
);
|
||||||
let mut var_store = VarStore::new(device);
|
let mut var_store = VarStore::new(device);
|
||||||
let model_config = ConfigOption::from_file(config.model_type, config_path);
|
let model_config = ConfigOption::from_file(config.model_type, config_path);
|
||||||
let token_sequence_classifier = TokenClassificationOption::new(config.model_type, &var_store.root(), &model_config);
|
let token_sequence_classifier =
|
||||||
|
TokenClassificationOption::new(config.model_type, &var_store.root(), &model_config);
|
||||||
let label_mapping = model_config.get_label_mapping();
|
let label_mapping = model_config.get_label_mapping();
|
||||||
var_store.load(weights_path)?;
|
var_store.load(weights_path)?;
|
||||||
Ok(TokenClassificationModel { tokenizer, token_sequence_classifier, label_mapping, var_store, label_aggregation_function })
|
Ok(TokenClassificationModel {
|
||||||
|
tokenizer,
|
||||||
|
token_sequence_classifier,
|
||||||
|
label_mapping,
|
||||||
|
var_store,
|
||||||
|
label_aggregation_function,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn prepare_for_model(&self, input: Vec<&str>) -> (Vec<TokenizedInput>, Tensor) {
|
fn prepare_for_model(&self, input: Vec<&str>) -> (Vec<TokenizedInput>, Tensor) {
|
||||||
let tokenized_input: Vec<TokenizedInput> = self.tokenizer.encode_list(input.to_vec(),
|
let tokenized_input: Vec<TokenizedInput> =
|
||||||
128,
|
self.tokenizer
|
||||||
&TruncationStrategy::LongestFirst,
|
.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
||||||
0);
|
let max_len = tokenized_input
|
||||||
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
|
.iter()
|
||||||
let tokenized_input_tensors: Vec<tch::Tensor> = tokenized_input.
|
.map(|input| input.token_ids.len())
|
||||||
iter().
|
.max()
|
||||||
map(|input| input.token_ids.clone()).
|
.unwrap();
|
||||||
map(|mut input| {
|
let tokenized_input_tensors: Vec<tch::Tensor> = tokenized_input
|
||||||
|
.iter()
|
||||||
|
.map(|input| input.token_ids.clone())
|
||||||
|
.map(|mut input| {
|
||||||
input.extend(vec![0; max_len - input.len()]);
|
input.extend(vec![0; max_len - input.len()]);
|
||||||
input
|
input
|
||||||
}).
|
})
|
||||||
map(|input|
|
.map(|input| Tensor::of_slice(&(input)))
|
||||||
Tensor::of_slice(&(input))).
|
.collect::<Vec<_>>();
|
||||||
collect::<Vec<_>>();
|
(
|
||||||
(tokenized_input, Tensor::stack(tokenized_input_tensors.as_slice(), 0).to(self.var_store.device()))
|
tokenized_input,
|
||||||
|
Tensor::stack(tokenized_input_tensors.as_slice(), 0).to(self.var_store.device()),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Classify tokens in a text sequence
|
/// Classify tokens in a text sequence
|
||||||
@ -374,33 +478,39 @@ impl TokenClassificationModel {
|
|||||||
/// # Example
|
/// # Example
|
||||||
///
|
///
|
||||||
/// ```no_run
|
/// ```no_run
|
||||||
///# fn main() -> failure::Fallible<()> {
|
/// # fn main() -> failure::Fallible<()> {
|
||||||
///# use rust_bert::pipelines::token_classification::TokenClassificationModel;
|
/// # use rust_bert::pipelines::token_classification::TokenClassificationModel;
|
||||||
///
|
///
|
||||||
/// let ner_model = TokenClassificationModel::new(Default::default())?;
|
/// let ner_model = TokenClassificationModel::new(Default::default())?;
|
||||||
/// let input = [
|
/// let input = [
|
||||||
/// "My name is Amy. I live in Paris.",
|
/// "My name is Amy. I live in Paris.",
|
||||||
/// "Paris is a city in France."
|
/// "Paris is a city in France.",
|
||||||
/// ];
|
/// ];
|
||||||
/// let output = ner_model.predict(&input, true, true);
|
/// let output = ner_model.predict(&input, true, true);
|
||||||
///# Ok(())
|
/// # Ok(())
|
||||||
///# }
|
/// # }
|
||||||
/// ```
|
/// ```
|
||||||
pub fn predict(&self, input: &[&str], consolidate_sub_tokens: bool, return_special: bool) -> Vec<Token> {
|
pub fn predict(
|
||||||
|
&self,
|
||||||
|
input: &[&str],
|
||||||
|
consolidate_sub_tokens: bool,
|
||||||
|
return_special: bool,
|
||||||
|
) -> Vec<Token> {
|
||||||
let (tokenized_input, input_tensor) = self.prepare_for_model(input.to_vec());
|
let (tokenized_input, input_tensor) = self.prepare_for_model(input.to_vec());
|
||||||
let (output, _, _) = no_grad(|| {
|
let (output, _, _) = no_grad(|| {
|
||||||
self.token_sequence_classifier
|
self.token_sequence_classifier.forward_t(
|
||||||
.forward_t(Some(input_tensor.copy()),
|
Some(input_tensor.copy()),
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
false)
|
false,
|
||||||
|
)
|
||||||
});
|
});
|
||||||
let output = output.detach().to(Device::Cpu);
|
let output = output.detach().to(Device::Cpu);
|
||||||
let score: Tensor = output.exp() / output.exp().sum1(&[-1], true, Float);
|
let score: Tensor = output.exp() / output.exp().sum1(&[-1], true, Float);
|
||||||
let labels_idx = &score.argmax(-1, true);
|
let labels_idx = &score.argmax(-1, true);
|
||||||
let mut tokens: Vec<Token> = vec!();
|
let mut tokens: Vec<Token> = vec![];
|
||||||
for sentence_idx in 0..labels_idx.size()[0] {
|
for sentence_idx in 0..labels_idx.size()[0] {
|
||||||
let labels = labels_idx.get(sentence_idx);
|
let labels = labels_idx.get(sentence_idx);
|
||||||
let sentence_tokens = &tokenized_input[sentence_idx as usize];
|
let sentence_tokens = &tokenized_input[sentence_idx as usize];
|
||||||
@ -415,7 +525,16 @@ impl TokenClassificationModel {
|
|||||||
word_idx += 1;
|
word_idx += 1;
|
||||||
}
|
}
|
||||||
let token = {
|
let token = {
|
||||||
self.decode_token(&original_chars, sentence_tokens, &input_tensor, &labels, &score, sentence_idx, position_idx as i64, word_idx - 1)
|
self.decode_token(
|
||||||
|
&original_chars,
|
||||||
|
sentence_tokens,
|
||||||
|
&input_tensor,
|
||||||
|
&labels,
|
||||||
|
&score,
|
||||||
|
sentence_idx,
|
||||||
|
position_idx as i64,
|
||||||
|
word_idx - 1,
|
||||||
|
)
|
||||||
};
|
};
|
||||||
tokens.push(token);
|
tokens.push(token);
|
||||||
}
|
}
|
||||||
@ -426,8 +545,17 @@ impl TokenClassificationModel {
|
|||||||
tokens
|
tokens
|
||||||
}
|
}
|
||||||
|
|
||||||
fn decode_token(&self, original_sentence_chars: &Vec<char>, sentence_tokens: &TokenizedInput, input_tensor: &Tensor,
|
fn decode_token(
|
||||||
labels: &Tensor, score: &Tensor, sentence_idx: i64, position_idx: i64, word_index: u16) -> Token {
|
&self,
|
||||||
|
original_sentence_chars: &Vec<char>,
|
||||||
|
sentence_tokens: &TokenizedInput,
|
||||||
|
input_tensor: &Tensor,
|
||||||
|
labels: &Tensor,
|
||||||
|
score: &Tensor,
|
||||||
|
sentence_idx: i64,
|
||||||
|
position_idx: i64,
|
||||||
|
word_index: u16,
|
||||||
|
) -> Token {
|
||||||
let label_id = labels.int64_value(&[position_idx as i64]);
|
let label_id = labels.int64_value(&[position_idx as i64]);
|
||||||
let token_id = input_tensor.int64_value(&[sentence_idx, position_idx as i64]);
|
let token_id = input_tensor.int64_value(&[sentence_idx, position_idx as i64]);
|
||||||
|
|
||||||
@ -435,13 +563,19 @@ impl TokenClassificationModel {
|
|||||||
|
|
||||||
let text = match offsets {
|
let text = match offsets {
|
||||||
None => match self.tokenizer {
|
None => match self.tokenizer {
|
||||||
TokenizerOption::Bert(ref tokenizer) => Tokenizer::decode(tokenizer, vec!(token_id), false, false),
|
TokenizerOption::Bert(ref tokenizer) => {
|
||||||
TokenizerOption::Roberta(ref tokenizer) => Tokenizer::decode(tokenizer, vec!(token_id), false, false),
|
Tokenizer::decode(tokenizer, vec![token_id], false, false)
|
||||||
|
}
|
||||||
|
TokenizerOption::Roberta(ref tokenizer) => {
|
||||||
|
Tokenizer::decode(tokenizer, vec![token_id], false, false)
|
||||||
|
}
|
||||||
},
|
},
|
||||||
Some(offsets) => {
|
Some(offsets) => {
|
||||||
let (start_char, end_char) = (offsets.begin as usize, offsets.end as usize);
|
let (start_char, end_char) = (offsets.begin as usize, offsets.end as usize);
|
||||||
let end_char = min(end_char, original_sentence_chars.len());
|
let end_char = min(end_char, original_sentence_chars.len());
|
||||||
let text = original_sentence_chars[start_char..end_char].iter().collect();
|
let text = original_sentence_chars[start_char..end_char]
|
||||||
|
.iter()
|
||||||
|
.collect();
|
||||||
text
|
text
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -449,7 +583,11 @@ impl TokenClassificationModel {
|
|||||||
Token {
|
Token {
|
||||||
text,
|
text,
|
||||||
score: score.double_value(&[sentence_idx, position_idx, label_id]),
|
score: score.double_value(&[sentence_idx, position_idx, label_id]),
|
||||||
label: self.label_mapping.get(&label_id).expect("Index out of vocabulary bounds.").to_owned(),
|
label: self
|
||||||
|
.label_mapping
|
||||||
|
.get(&label_id)
|
||||||
|
.expect("Index out of vocabulary bounds.")
|
||||||
|
.to_owned(),
|
||||||
label_index: label_id,
|
label_index: label_id,
|
||||||
sentence: sentence_idx as usize,
|
sentence: sentence_idx as usize,
|
||||||
index: position_idx as u16,
|
index: position_idx as u16,
|
||||||
@ -459,24 +597,29 @@ impl TokenClassificationModel {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn consolidate_tokens(&self, tokens: &mut Vec<Token>, label_aggregation_function: &LabelAggregationOption) {
|
fn consolidate_tokens(
|
||||||
let mut tokens_to_replace = vec!();
|
&self,
|
||||||
|
tokens: &mut Vec<Token>,
|
||||||
|
label_aggregation_function: &LabelAggregationOption,
|
||||||
|
) {
|
||||||
|
let mut tokens_to_replace = vec![];
|
||||||
let mut token_iter = tokens.iter_consolidate_tokens();
|
let mut token_iter = tokens.iter_consolidate_tokens();
|
||||||
let mut cursor = 0;
|
let mut cursor = 0;
|
||||||
|
|
||||||
while let Some(sub_tokens) = token_iter.next() {
|
while let Some(sub_tokens) = token_iter.next() {
|
||||||
if sub_tokens.len() > 1 {
|
if sub_tokens.len() > 1 {
|
||||||
let (label_index, label) = self.consolidate_labels(sub_tokens, label_aggregation_function);
|
let (label_index, label) =
|
||||||
|
self.consolidate_labels(sub_tokens, label_aggregation_function);
|
||||||
let sentence = (&sub_tokens[0]).sentence;
|
let sentence = (&sub_tokens[0]).sentence;
|
||||||
let index = (&sub_tokens[0]).index;
|
let index = (&sub_tokens[0]).index;
|
||||||
let word_index = (&sub_tokens[0]).word_index;
|
let word_index = (&sub_tokens[0]).word_index;
|
||||||
let offset_start = match &sub_tokens.first().unwrap().offset {
|
let offset_start = match &sub_tokens.first().unwrap().offset {
|
||||||
Some(offset) => Some(offset.begin),
|
Some(offset) => Some(offset.begin),
|
||||||
None => None
|
None => None,
|
||||||
};
|
};
|
||||||
let offset_end = match &sub_tokens.last().unwrap().offset {
|
let offset_end = match &sub_tokens.last().unwrap().offset {
|
||||||
Some(offset) => Some(offset.end),
|
Some(offset) => Some(offset.end),
|
||||||
None => None
|
None => None,
|
||||||
};
|
};
|
||||||
let offset = if offset_start.is_some() & offset_end.is_some() {
|
let offset = if offset_start.is_some() & offset_end.is_some() {
|
||||||
Some(Offset::new(offset_start.unwrap(), offset_end.unwrap()))
|
Some(Offset::new(offset_start.unwrap(), offset_end.unwrap()))
|
||||||
@ -513,7 +656,11 @@ impl TokenClassificationModel {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn consolidate_labels(&self, tokens: &[Token], aggregation: &LabelAggregationOption) -> (i64, String) {
|
fn consolidate_labels(
|
||||||
|
&self,
|
||||||
|
tokens: &[Token],
|
||||||
|
aggregation: &LabelAggregationOption,
|
||||||
|
) -> (i64, String) {
|
||||||
match aggregation {
|
match aggregation {
|
||||||
LabelAggregationOption::First => {
|
LabelAggregationOption::First => {
|
||||||
let token = tokens.first().unwrap();
|
let token = tokens.first().unwrap();
|
||||||
@ -524,22 +671,17 @@ impl TokenClassificationModel {
|
|||||||
(token.label_index, token.label.clone())
|
(token.label_index, token.label.clone())
|
||||||
}
|
}
|
||||||
LabelAggregationOption::Mode => {
|
LabelAggregationOption::Mode => {
|
||||||
let counts = tokens
|
let counts = tokens.iter().fold(HashMap::new(), |mut m, c| {
|
||||||
.iter()
|
*m.entry((c.label_index, c.label.as_str())).or_insert(0) += 1;
|
||||||
.fold(
|
m
|
||||||
HashMap::new(),
|
});
|
||||||
|mut m, c| {
|
|
||||||
*m.entry((c.label_index, c.label.as_str())).or_insert(0) += 1;
|
|
||||||
m
|
|
||||||
},
|
|
||||||
);
|
|
||||||
counts
|
counts
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.max_by(|a, b| a.1.cmp(&b.1))
|
.max_by(|a, b| a.1.cmp(&b.1))
|
||||||
.map(|((label_index, label), _)| (label_index, label.to_owned()))
|
.map(|((label_index, label), _)| (label_index, label.to_owned()))
|
||||||
.unwrap()
|
.unwrap()
|
||||||
}
|
}
|
||||||
LabelAggregationOption::Custom(function) => function(tokens)
|
LabelAggregationOption::Custom(function) => function(tokens),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -11,7 +11,6 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
//! # Translation pipeline
|
//! # Translation pipeline
|
||||||
//! Translation based on the Marian encoder-decoder architecture
|
//! Translation based on the Marian encoder-decoder architecture
|
||||||
//! Include techniques such as beam search, top-k and nucleus sampling, temperature setting and repetition penalty.
|
//! Include techniques such as beam search, top-k and nucleus sampling, temperature setting and repetition penalty.
|
||||||
@ -33,31 +32,35 @@
|
|||||||
//!
|
//!
|
||||||
//!
|
//!
|
||||||
//! ```no_run
|
//! ```no_run
|
||||||
//!# fn main() -> failure::Fallible<()> {
|
//! # fn main() -> failure::Fallible<()> {
|
||||||
//!# use rust_bert::pipelines::generation::LanguageGenerator;
|
//! # use rust_bert::pipelines::generation::LanguageGenerator;
|
||||||
//! use rust_bert::pipelines::translation::{TranslationModel, TranslationConfig, Language};
|
//! use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
|
||||||
//! use tch::Device;
|
//! use tch::Device;
|
||||||
//! let translation_config = TranslationConfig::new(Language::EnglishToFrench, Device::cuda_if_available());
|
//! let translation_config =
|
||||||
|
//! TranslationConfig::new(Language::EnglishToFrench, Device::cuda_if_available());
|
||||||
//! let mut model = TranslationModel::new(translation_config)?;
|
//! let mut model = TranslationModel::new(translation_config)?;
|
||||||
//!
|
//!
|
||||||
//! let input = ["This is a sentence to be translated"];
|
//! let input = ["This is a sentence to be translated"];
|
||||||
//!
|
//!
|
||||||
//! let output = model.translate(&input);
|
//! let output = model.translate(&input);
|
||||||
//!# Ok(())
|
//! # Ok(())
|
||||||
//!# }
|
//! # }
|
||||||
//! ```
|
//! ```
|
||||||
//!
|
//!
|
||||||
//! Output: \
|
//! Output: \
|
||||||
//! ```no_run
|
//! ```no_run
|
||||||
//!# let output =
|
//! # let output =
|
||||||
//! "Il s'agit d'une phrase à traduire"
|
//! "Il s'agit d'une phrase à traduire"
|
||||||
//!# ;
|
//! # ;
|
||||||
//!```
|
//! ```
|
||||||
|
|
||||||
use crate::pipelines::generation::{MarianGenerator, GenerateConfig, LanguageGenerator};
|
use crate::common::resources::{RemoteResource, Resource};
|
||||||
|
use crate::marian::{
|
||||||
|
MarianConfigResources, MarianModelResources, MarianPrefix, MarianSpmResources,
|
||||||
|
MarianVocabResources,
|
||||||
|
};
|
||||||
|
use crate::pipelines::generation::{GenerateConfig, LanguageGenerator, MarianGenerator};
|
||||||
use tch::Device;
|
use tch::Device;
|
||||||
use crate::common::resources::{Resource, RemoteResource};
|
|
||||||
use crate::marian::{MarianModelResources, MarianConfigResources, MarianVocabResources, MarianSpmResources, MarianPrefix};
|
|
||||||
|
|
||||||
/// Pretrained languages available for direct use
|
/// Pretrained languages available for direct use
|
||||||
pub enum Language {
|
pub enum Language {
|
||||||
@ -84,47 +87,244 @@ pub enum Language {
|
|||||||
struct RemoteTranslationResources;
|
struct RemoteTranslationResources;
|
||||||
|
|
||||||
impl RemoteTranslationResources {
|
impl RemoteTranslationResources {
|
||||||
pub const ENGLISH2FRENCH: ((&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), Option<&'static str>) =
|
pub const ENGLISH2FRENCH: (
|
||||||
(MarianModelResources::ENGLISH2ROMANCE, MarianConfigResources::ENGLISH2ROMANCE, MarianVocabResources::ENGLISH2ROMANCE, MarianSpmResources::ENGLISH2ROMANCE, MarianPrefix::ENGLISH2FRENCH);
|
(&'static str, &'static str),
|
||||||
pub const ENGLISH2CATALAN: ((&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), Option<&'static str>) =
|
(&'static str, &'static str),
|
||||||
(MarianModelResources::ENGLISH2ROMANCE, MarianConfigResources::ENGLISH2ROMANCE, MarianVocabResources::ENGLISH2ROMANCE, MarianSpmResources::ENGLISH2ROMANCE, MarianPrefix::ENGLISH2CATALAN);
|
(&'static str, &'static str),
|
||||||
pub const ENGLISH2SPANISH: ((&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), Option<&'static str>) =
|
(&'static str, &'static str),
|
||||||
(MarianModelResources::ENGLISH2ROMANCE, MarianConfigResources::ENGLISH2ROMANCE, MarianVocabResources::ENGLISH2ROMANCE, MarianSpmResources::ENGLISH2ROMANCE, MarianPrefix::ENGLISH2SPANISH);
|
Option<&'static str>,
|
||||||
pub const ENGLISH2PORTUGUESE: ((&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), Option<&'static str>) =
|
) = (
|
||||||
(MarianModelResources::ENGLISH2ROMANCE, MarianConfigResources::ENGLISH2ROMANCE, MarianVocabResources::ENGLISH2ROMANCE, MarianSpmResources::ENGLISH2ROMANCE, MarianPrefix::ENGLISH2PORTUGUESE);
|
MarianModelResources::ENGLISH2ROMANCE,
|
||||||
pub const ENGLISH2ITALIAN: ((&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), Option<&'static str>) =
|
MarianConfigResources::ENGLISH2ROMANCE,
|
||||||
(MarianModelResources::ENGLISH2ROMANCE, MarianConfigResources::ENGLISH2ROMANCE, MarianVocabResources::ENGLISH2ROMANCE, MarianSpmResources::ENGLISH2ROMANCE, MarianPrefix::ENGLISH2ITALIAN);
|
MarianVocabResources::ENGLISH2ROMANCE,
|
||||||
pub const ENGLISH2ROMANIAN: ((&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), Option<&'static str>) =
|
MarianSpmResources::ENGLISH2ROMANCE,
|
||||||
(MarianModelResources::ENGLISH2ROMANCE, MarianConfigResources::ENGLISH2ROMANCE, MarianVocabResources::ENGLISH2ROMANCE, MarianSpmResources::ENGLISH2ROMANCE, MarianPrefix::ENGLISH2ROMANIAN);
|
MarianPrefix::ENGLISH2FRENCH,
|
||||||
pub const ENGLISH2GERMAN: ((&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), Option<&'static str>) =
|
);
|
||||||
(MarianModelResources::ENGLISH2GERMAN, MarianConfigResources::ENGLISH2GERMAN, MarianVocabResources::ENGLISH2GERMAN, MarianSpmResources::ENGLISH2GERMAN, MarianPrefix::ENGLISH2GERMAN);
|
pub const ENGLISH2CATALAN: (
|
||||||
pub const ENGLISH2RUSSIAN: ((&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), Option<&'static str>) =
|
(&'static str, &'static str),
|
||||||
(MarianModelResources::ENGLISH2RUSSIAN, MarianConfigResources::ENGLISH2RUSSIAN, MarianVocabResources::ENGLISH2RUSSIAN, MarianSpmResources::ENGLISH2RUSSIAN, MarianPrefix::ENGLISH2RUSSIAN);
|
(&'static str, &'static str),
|
||||||
|
(&'static str, &'static str),
|
||||||
|
(&'static str, &'static str),
|
||||||
|
Option<&'static str>,
|
||||||
|
) = (
|
||||||
|
MarianModelResources::ENGLISH2ROMANCE,
|
||||||
|
MarianConfigResources::ENGLISH2ROMANCE,
|
||||||
|
MarianVocabResources::ENGLISH2ROMANCE,
|
||||||
|
MarianSpmResources::ENGLISH2ROMANCE,
|
||||||
|
MarianPrefix::ENGLISH2CATALAN,
|
||||||
|
);
|
||||||
|
pub const ENGLISH2SPANISH: (
|
||||||
|
(&'static str, &'static str),
|
||||||
|
(&'static str, &'static str),
|
||||||
|
(&'static str, &'static str),
|
||||||
|
(&'static str, &'static str),
|
||||||
|
Option<&'static str>,
|
||||||
|
) = (
|
||||||
|
MarianModelResources::ENGLISH2ROMANCE,
|
||||||
|
MarianConfigResources::ENGLISH2ROMANCE,
|
||||||
|
MarianVocabResources::ENGLISH2ROMANCE,
|
||||||
|
MarianSpmResources::ENGLISH2ROMANCE,
|
||||||
|
MarianPrefix::ENGLISH2SPANISH,
|
||||||
|
);
|
||||||
|
pub const ENGLISH2PORTUGUESE: (
|
||||||
|
(&'static str, &'static str),
|
||||||
|
(&'static str, &'static str),
|
||||||
|
(&'static str, &'static str),
|
||||||
|
(&'static str, &'static str),
|
||||||
|
Option<&'static str>,
|
||||||
|
) = (
|
||||||
|
MarianModelResources::ENGLISH2ROMANCE,
|
||||||
|
MarianConfigResources::ENGLISH2ROMANCE,
|
||||||
|
MarianVocabResources::ENGLISH2ROMANCE,
|
||||||
|
MarianSpmResources::ENGLISH2ROMANCE,
|
||||||
|
MarianPrefix::ENGLISH2PORTUGUESE,
|
||||||
|
);
|
||||||
|
pub const ENGLISH2ITALIAN: (
|
||||||
|
(&'static str, &'static str),
|
||||||
|
(&'static str, &'static str),
|
||||||
|
(&'static str, &'static str),
|
||||||
|
(&'static str, &'static str),
|
||||||
|
Option<&'static str>,
|
||||||
|
) = (
|
||||||
|
MarianModelResources::ENGLISH2ROMANCE,
|
||||||
|
MarianConfigResources::ENGLISH2ROMANCE,
|
||||||
|
MarianVocabResources::ENGLISH2ROMANCE,
|
||||||
|
MarianSpmResources::ENGLISH2ROMANCE,
|
||||||
|
MarianPrefix::ENGLISH2ITALIAN,
|
||||||
|
);
|
||||||
|
pub const ENGLISH2ROMANIAN: (
|
||||||
|
(&'static str, &'static str),
|
||||||
|
(&'static str, &'static str),
|
||||||
|
(&'static str, &'static str),
|
||||||
|
(&'static str, &'static str),
|
||||||
|
Option<&'static str>,
|
||||||
|
) = (
|
||||||
|
MarianModelResources::ENGLISH2ROMANCE,
|
||||||
|
MarianConfigResources::ENGLISH2ROMANCE,
|
||||||
|
MarianVocabResources::ENGLISH2ROMANCE,
|
||||||
|
MarianSpmResources::ENGLISH2ROMANCE,
|
||||||
|
MarianPrefix::ENGLISH2ROMANIAN,
|
||||||
|
);
|
||||||
|
pub const ENGLISH2GERMAN: (
|
||||||
|
(&'static str, &'static str),
|
||||||
|
(&'static str, &'static str),
|
||||||
|
(&'static str, &'static str),
|
||||||
|
(&'static str, &'static str),
|
||||||
|
Option<&'static str>,
|
||||||
|
) = (
|
||||||
|
MarianModelResources::ENGLISH2GERMAN,
|
||||||
|
MarianConfigResources::ENGLISH2GERMAN,
|
||||||
|
MarianVocabResources::ENGLISH2GERMAN,
|
||||||
|
MarianSpmResources::ENGLISH2GERMAN,
|
||||||
|
MarianPrefix::ENGLISH2GERMAN,
|
||||||
|
);
|
||||||
|
pub const ENGLISH2RUSSIAN: (
|
||||||
|
(&'static str, &'static str),
|
||||||
|
(&'static str, &'static str),
|
||||||
|
(&'static str, &'static str),
|
||||||
|
(&'static str, &'static str),
|
||||||
|
Option<&'static str>,
|
||||||
|
) = (
|
||||||
|
MarianModelResources::ENGLISH2RUSSIAN,
|
||||||
|
MarianConfigResources::ENGLISH2RUSSIAN,
|
||||||
|
MarianVocabResources::ENGLISH2RUSSIAN,
|
||||||
|
MarianSpmResources::ENGLISH2RUSSIAN,
|
||||||
|
MarianPrefix::ENGLISH2RUSSIAN,
|
||||||
|
);
|
||||||
|
|
||||||
pub const FRENCH2ENGLISH: ((&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), Option<&'static str>) =
|
pub const FRENCH2ENGLISH: (
|
||||||
(MarianModelResources::ROMANCE2ENGLISH, MarianConfigResources::ROMANCE2ENGLISH, MarianVocabResources::ROMANCE2ENGLISH, MarianSpmResources::ROMANCE2ENGLISH, MarianPrefix::FRENCH2ENGLISH);
|
(&'static str, &'static str),
|
||||||
pub const CATALAN2ENGLISH: ((&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), Option<&'static str>) =
|
(&'static str, &'static str),
|
||||||
(MarianModelResources::ROMANCE2ENGLISH, MarianConfigResources::ROMANCE2ENGLISH, MarianVocabResources::ROMANCE2ENGLISH, MarianSpmResources::ROMANCE2ENGLISH, MarianPrefix::CATALAN2ENGLISH);
|
(&'static str, &'static str),
|
||||||
pub const SPANISH2ENGLISH: ((&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), Option<&'static str>) =
|
(&'static str, &'static str),
|
||||||
(MarianModelResources::ROMANCE2ENGLISH, MarianConfigResources::ROMANCE2ENGLISH, MarianVocabResources::ROMANCE2ENGLISH, MarianSpmResources::ROMANCE2ENGLISH, MarianPrefix::SPANISH2ENGLISH);
|
Option<&'static str>,
|
||||||
pub const PORTUGUESE2ENGLISH: ((&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), Option<&'static str>) =
|
) = (
|
||||||
(MarianModelResources::ROMANCE2ENGLISH, MarianConfigResources::ROMANCE2ENGLISH, MarianVocabResources::ROMANCE2ENGLISH, MarianSpmResources::ROMANCE2ENGLISH, MarianPrefix::PORTUGUESE2ENGLISH);
|
MarianModelResources::ROMANCE2ENGLISH,
|
||||||
pub const ITALIAN2ENGLISH: ((&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), Option<&'static str>) =
|
MarianConfigResources::ROMANCE2ENGLISH,
|
||||||
(MarianModelResources::ROMANCE2ENGLISH, MarianConfigResources::ROMANCE2ENGLISH, MarianVocabResources::ROMANCE2ENGLISH, MarianSpmResources::ROMANCE2ENGLISH, MarianPrefix::ITALIAN2ENGLISH);
|
MarianVocabResources::ROMANCE2ENGLISH,
|
||||||
pub const ROMANIAN2ENGLISH: ((&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), Option<&'static str>) =
|
MarianSpmResources::ROMANCE2ENGLISH,
|
||||||
(MarianModelResources::ROMANCE2ENGLISH, MarianConfigResources::ROMANCE2ENGLISH, MarianVocabResources::ROMANCE2ENGLISH, MarianSpmResources::ROMANCE2ENGLISH, MarianPrefix::ROMANIAN2ENGLISH);
|
MarianPrefix::FRENCH2ENGLISH,
|
||||||
pub const GERMAN2ENGLISH: ((&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), Option<&'static str>) =
|
);
|
||||||
(MarianModelResources::GERMAN2ENGLISH, MarianConfigResources::GERMAN2ENGLISH, MarianVocabResources::GERMAN2ENGLISH, MarianSpmResources::GERMAN2ENGLISH, MarianPrefix::GERMAN2ENGLISH);
|
pub const CATALAN2ENGLISH: (
|
||||||
pub const RUSSIAN2ENGLISH: ((&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), Option<&'static str>) =
|
(&'static str, &'static str),
|
||||||
(MarianModelResources::RUSSIAN2ENGLISH, MarianConfigResources::RUSSIAN2ENGLISH, MarianVocabResources::RUSSIAN2ENGLISH, MarianSpmResources::RUSSIAN2ENGLISH, MarianPrefix::RUSSIAN2ENGLISH);
|
(&'static str, &'static str),
|
||||||
|
(&'static str, &'static str),
|
||||||
|
(&'static str, &'static str),
|
||||||
|
Option<&'static str>,
|
||||||
|
) = (
|
||||||
|
MarianModelResources::ROMANCE2ENGLISH,
|
||||||
|
MarianConfigResources::ROMANCE2ENGLISH,
|
||||||
|
MarianVocabResources::ROMANCE2ENGLISH,
|
||||||
|
MarianSpmResources::ROMANCE2ENGLISH,
|
||||||
|
MarianPrefix::CATALAN2ENGLISH,
|
||||||
|
);
|
||||||
|
pub const SPANISH2ENGLISH: (
|
||||||
|
(&'static str, &'static str),
|
||||||
|
(&'static str, &'static str),
|
||||||
|
(&'static str, &'static str),
|
||||||
|
(&'static str, &'static str),
|
||||||
|
Option<&'static str>,
|
||||||
|
) = (
|
||||||
|
MarianModelResources::ROMANCE2ENGLISH,
|
||||||
|
MarianConfigResources::ROMANCE2ENGLISH,
|
||||||
|
MarianVocabResources::ROMANCE2ENGLISH,
|
||||||
|
MarianSpmResources::ROMANCE2ENGLISH,
|
||||||
|
MarianPrefix::SPANISH2ENGLISH,
|
||||||
|
);
|
||||||
|
pub const PORTUGUESE2ENGLISH: (
|
||||||
|
(&'static str, &'static str),
|
||||||
|
(&'static str, &'static str),
|
||||||
|
(&'static str, &'static str),
|
||||||
|
(&'static str, &'static str),
|
||||||
|
Option<&'static str>,
|
||||||
|
) = (
|
||||||
|
MarianModelResources::ROMANCE2ENGLISH,
|
||||||
|
MarianConfigResources::ROMANCE2ENGLISH,
|
||||||
|
MarianVocabResources::ROMANCE2ENGLISH,
|
||||||
|
MarianSpmResources::ROMANCE2ENGLISH,
|
||||||
|
MarianPrefix::PORTUGUESE2ENGLISH,
|
||||||
|
);
|
||||||
|
pub const ITALIAN2ENGLISH: (
|
||||||
|
(&'static str, &'static str),
|
||||||
|
(&'static str, &'static str),
|
||||||
|
(&'static str, &'static str),
|
||||||
|
(&'static str, &'static str),
|
||||||
|
Option<&'static str>,
|
||||||
|
) = (
|
||||||
|
MarianModelResources::ROMANCE2ENGLISH,
|
||||||
|
MarianConfigResources::ROMANCE2ENGLISH,
|
||||||
|
MarianVocabResources::ROMANCE2ENGLISH,
|
||||||
|
MarianSpmResources::ROMANCE2ENGLISH,
|
||||||
|
MarianPrefix::ITALIAN2ENGLISH,
|
||||||
|
);
|
||||||
|
pub const ROMANIAN2ENGLISH: (
|
||||||
|
(&'static str, &'static str),
|
||||||
|
(&'static str, &'static str),
|
||||||
|
(&'static str, &'static str),
|
||||||
|
(&'static str, &'static str),
|
||||||
|
Option<&'static str>,
|
||||||
|
) = (
|
||||||
|
MarianModelResources::ROMANCE2ENGLISH,
|
||||||
|
MarianConfigResources::ROMANCE2ENGLISH,
|
||||||
|
MarianVocabResources::ROMANCE2ENGLISH,
|
||||||
|
MarianSpmResources::ROMANCE2ENGLISH,
|
||||||
|
MarianPrefix::ROMANIAN2ENGLISH,
|
||||||
|
);
|
||||||
|
pub const GERMAN2ENGLISH: (
|
||||||
|
(&'static str, &'static str),
|
||||||
|
(&'static str, &'static str),
|
||||||
|
(&'static str, &'static str),
|
||||||
|
(&'static str, &'static str),
|
||||||
|
Option<&'static str>,
|
||||||
|
) = (
|
||||||
|
MarianModelResources::GERMAN2ENGLISH,
|
||||||
|
MarianConfigResources::GERMAN2ENGLISH,
|
||||||
|
MarianVocabResources::GERMAN2ENGLISH,
|
||||||
|
MarianSpmResources::GERMAN2ENGLISH,
|
||||||
|
MarianPrefix::GERMAN2ENGLISH,
|
||||||
|
);
|
||||||
|
pub const RUSSIAN2ENGLISH: (
|
||||||
|
(&'static str, &'static str),
|
||||||
|
(&'static str, &'static str),
|
||||||
|
(&'static str, &'static str),
|
||||||
|
(&'static str, &'static str),
|
||||||
|
Option<&'static str>,
|
||||||
|
) = (
|
||||||
|
MarianModelResources::RUSSIAN2ENGLISH,
|
||||||
|
MarianConfigResources::RUSSIAN2ENGLISH,
|
||||||
|
MarianVocabResources::RUSSIAN2ENGLISH,
|
||||||
|
MarianSpmResources::RUSSIAN2ENGLISH,
|
||||||
|
MarianPrefix::RUSSIAN2ENGLISH,
|
||||||
|
);
|
||||||
|
|
||||||
pub const FRENCH2GERMAN: ((&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), Option<&'static str>) =
|
pub const FRENCH2GERMAN: (
|
||||||
(MarianModelResources::FRENCH2GERMAN, MarianConfigResources::FRENCH2GERMAN, MarianVocabResources::FRENCH2GERMAN, MarianSpmResources::FRENCH2GERMAN, MarianPrefix::FRENCH2GERMAN);
|
(&'static str, &'static str),
|
||||||
pub const GERMAN2FRENCH: ((&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), (&'static str, &'static str), Option<&'static str>) =
|
(&'static str, &'static str),
|
||||||
(MarianModelResources::GERMAN2FRENCH, MarianConfigResources::GERMAN2FRENCH, MarianVocabResources::GERMAN2FRENCH, MarianSpmResources::GERMAN2FRENCH, MarianPrefix::GERMAN2FRENCH);
|
(&'static str, &'static str),
|
||||||
|
(&'static str, &'static str),
|
||||||
|
Option<&'static str>,
|
||||||
|
) = (
|
||||||
|
MarianModelResources::FRENCH2GERMAN,
|
||||||
|
MarianConfigResources::FRENCH2GERMAN,
|
||||||
|
MarianVocabResources::FRENCH2GERMAN,
|
||||||
|
MarianSpmResources::FRENCH2GERMAN,
|
||||||
|
MarianPrefix::FRENCH2GERMAN,
|
||||||
|
);
|
||||||
|
pub const GERMAN2FRENCH: (
|
||||||
|
(&'static str, &'static str),
|
||||||
|
(&'static str, &'static str),
|
||||||
|
(&'static str, &'static str),
|
||||||
|
(&'static str, &'static str),
|
||||||
|
Option<&'static str>,
|
||||||
|
) = (
|
||||||
|
MarianModelResources::GERMAN2FRENCH,
|
||||||
|
MarianConfigResources::GERMAN2FRENCH,
|
||||||
|
MarianVocabResources::GERMAN2FRENCH,
|
||||||
|
MarianSpmResources::GERMAN2FRENCH,
|
||||||
|
MarianPrefix::GERMAN2FRENCH,
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/// # Configuration for text translation
|
/// # Configuration for text translation
|
||||||
/// Contains information regarding the model to load, mirrors the GenerationConfig, with a
|
/// Contains information regarding the model to load, mirrors the GenerationConfig, with a
|
||||||
/// different set of default parameters and sets the device to place the model on.
|
/// different set of default parameters and sets the device to place the model on.
|
||||||
@ -178,45 +378,46 @@ impl TranslationConfig {
|
|||||||
/// # Example
|
/// # Example
|
||||||
///
|
///
|
||||||
/// ```no_run
|
/// ```no_run
|
||||||
///# fn main() -> failure::Fallible<()> {
|
/// # fn main() -> failure::Fallible<()> {
|
||||||
/// use rust_bert::pipelines::translation::{TranslationConfig, Language};
|
/// use rust_bert::pipelines::translation::{Language, TranslationConfig};
|
||||||
/// use tch::Device;
|
/// use tch::Device;
|
||||||
///
|
///
|
||||||
/// let translation_config = TranslationConfig::new(Language::FrenchToEnglish, Device::cuda_if_available());
|
/// let translation_config =
|
||||||
///# Ok(())
|
/// TranslationConfig::new(Language::FrenchToEnglish, Device::cuda_if_available());
|
||||||
///# }
|
/// # Ok(())
|
||||||
|
/// # }
|
||||||
/// ```
|
/// ```
|
||||||
///
|
|
||||||
pub fn new(language: Language, device: Device) -> TranslationConfig {
|
pub fn new(language: Language, device: Device) -> TranslationConfig {
|
||||||
let (model_resource, config_resource, vocab_resource, merges_resource, prefix) = match language {
|
let (model_resource, config_resource, vocab_resource, merges_resource, prefix) =
|
||||||
Language::EnglishToFrench => RemoteTranslationResources::ENGLISH2FRENCH,
|
match language {
|
||||||
Language::EnglishToCatalan => RemoteTranslationResources::ENGLISH2CATALAN,
|
Language::EnglishToFrench => RemoteTranslationResources::ENGLISH2FRENCH,
|
||||||
Language::EnglishToSpanish => RemoteTranslationResources::ENGLISH2SPANISH,
|
Language::EnglishToCatalan => RemoteTranslationResources::ENGLISH2CATALAN,
|
||||||
Language::EnglishToPortuguese => RemoteTranslationResources::ENGLISH2PORTUGUESE,
|
Language::EnglishToSpanish => RemoteTranslationResources::ENGLISH2SPANISH,
|
||||||
Language::EnglishToItalian => RemoteTranslationResources::ENGLISH2ITALIAN,
|
Language::EnglishToPortuguese => RemoteTranslationResources::ENGLISH2PORTUGUESE,
|
||||||
Language::EnglishToRomanian => RemoteTranslationResources::ENGLISH2ROMANIAN,
|
Language::EnglishToItalian => RemoteTranslationResources::ENGLISH2ITALIAN,
|
||||||
Language::EnglishToGerman => RemoteTranslationResources::ENGLISH2GERMAN,
|
Language::EnglishToRomanian => RemoteTranslationResources::ENGLISH2ROMANIAN,
|
||||||
Language::EnglishToRussian => RemoteTranslationResources::ENGLISH2RUSSIAN,
|
Language::EnglishToGerman => RemoteTranslationResources::ENGLISH2GERMAN,
|
||||||
|
Language::EnglishToRussian => RemoteTranslationResources::ENGLISH2RUSSIAN,
|
||||||
|
|
||||||
Language::FrenchToEnglish => RemoteTranslationResources::FRENCH2ENGLISH,
|
Language::FrenchToEnglish => RemoteTranslationResources::FRENCH2ENGLISH,
|
||||||
Language::CatalanToEnglish => RemoteTranslationResources::CATALAN2ENGLISH,
|
Language::CatalanToEnglish => RemoteTranslationResources::CATALAN2ENGLISH,
|
||||||
Language::SpanishToEnglish => RemoteTranslationResources::SPANISH2ENGLISH,
|
Language::SpanishToEnglish => RemoteTranslationResources::SPANISH2ENGLISH,
|
||||||
Language::PortugueseToEnglish => RemoteTranslationResources::PORTUGUESE2ENGLISH,
|
Language::PortugueseToEnglish => RemoteTranslationResources::PORTUGUESE2ENGLISH,
|
||||||
Language::ItalianToEnglish => RemoteTranslationResources::ITALIAN2ENGLISH,
|
Language::ItalianToEnglish => RemoteTranslationResources::ITALIAN2ENGLISH,
|
||||||
Language::RomanianToEnglish => RemoteTranslationResources::ROMANIAN2ENGLISH,
|
Language::RomanianToEnglish => RemoteTranslationResources::ROMANIAN2ENGLISH,
|
||||||
Language::GermanToEnglish => RemoteTranslationResources::GERMAN2ENGLISH,
|
Language::GermanToEnglish => RemoteTranslationResources::GERMAN2ENGLISH,
|
||||||
Language::RussianToEnglish => RemoteTranslationResources::RUSSIAN2ENGLISH,
|
Language::RussianToEnglish => RemoteTranslationResources::RUSSIAN2ENGLISH,
|
||||||
|
|
||||||
Language::FrenchToGerman => RemoteTranslationResources::FRENCH2GERMAN,
|
Language::FrenchToGerman => RemoteTranslationResources::FRENCH2GERMAN,
|
||||||
Language::GermanToFrench => RemoteTranslationResources::GERMAN2FRENCH,
|
Language::GermanToFrench => RemoteTranslationResources::GERMAN2FRENCH,
|
||||||
};
|
};
|
||||||
let model_resource = Resource::Remote(RemoteResource::from_pretrained(model_resource));
|
let model_resource = Resource::Remote(RemoteResource::from_pretrained(model_resource));
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(config_resource));
|
let config_resource = Resource::Remote(RemoteResource::from_pretrained(config_resource));
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(vocab_resource));
|
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(vocab_resource));
|
||||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(merges_resource));
|
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(merges_resource));
|
||||||
let prefix = match prefix {
|
let prefix = match prefix {
|
||||||
Some(value) => Some(value.to_string()),
|
Some(value) => Some(value.to_string()),
|
||||||
None => None
|
None => None,
|
||||||
};
|
};
|
||||||
TranslationConfig {
|
TranslationConfig {
|
||||||
model_resource,
|
model_resource,
|
||||||
@ -253,33 +454,44 @@ impl TranslationConfig {
|
|||||||
/// # Example
|
/// # Example
|
||||||
///
|
///
|
||||||
/// ```no_run
|
/// ```no_run
|
||||||
///# fn main() -> failure::Fallible<()> {
|
/// # fn main() -> failure::Fallible<()> {
|
||||||
/// use rust_bert::pipelines::translation::TranslationConfig;
|
/// use rust_bert::pipelines::translation::TranslationConfig;
|
||||||
/// use tch::Device;
|
/// use rust_bert::resources::{LocalResource, Resource};
|
||||||
/// use rust_bert::resources::{Resource, LocalResource};
|
|
||||||
/// use std::path::PathBuf;
|
/// use std::path::PathBuf;
|
||||||
|
/// use tch::Device;
|
||||||
///
|
///
|
||||||
/// let config_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/config.json") });
|
/// let config_resource = Resource::Local(LocalResource {
|
||||||
/// let model_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/model.ot") });
|
/// local_path: PathBuf::from("path/to/config.json"),
|
||||||
/// let vocab_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/vocab.json") });
|
/// });
|
||||||
/// let sentence_piece_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/spiece.model") });
|
/// let model_resource = Resource::Local(LocalResource {
|
||||||
|
/// local_path: PathBuf::from("path/to/model.ot"),
|
||||||
|
/// });
|
||||||
|
/// let vocab_resource = Resource::Local(LocalResource {
|
||||||
|
/// local_path: PathBuf::from("path/to/vocab.json"),
|
||||||
|
/// });
|
||||||
|
/// let sentence_piece_resource = Resource::Local(LocalResource {
|
||||||
|
/// local_path: PathBuf::from("path/to/spiece.model"),
|
||||||
|
/// });
|
||||||
///
|
///
|
||||||
/// let translation_config = TranslationConfig::new_from_resources(model_resource,
|
/// let translation_config = TranslationConfig::new_from_resources(
|
||||||
/// config_resource,
|
/// model_resource,
|
||||||
/// vocab_resource,
|
/// config_resource,
|
||||||
/// sentence_piece_resource,
|
/// vocab_resource,
|
||||||
/// Some(">>fr<<".to_string()),
|
/// sentence_piece_resource,
|
||||||
/// Device::cuda_if_available());
|
/// Some(">>fr<<".to_string()),
|
||||||
///# Ok(())
|
/// Device::cuda_if_available(),
|
||||||
///# }
|
/// );
|
||||||
|
/// # Ok(())
|
||||||
|
/// # }
|
||||||
/// ```
|
/// ```
|
||||||
///
|
pub fn new_from_resources(
|
||||||
pub fn new_from_resources(model_resource: Resource,
|
model_resource: Resource,
|
||||||
config_resource: Resource,
|
config_resource: Resource,
|
||||||
vocab_resource: Resource,
|
vocab_resource: Resource,
|
||||||
sentence_piece_resource: Resource,
|
sentence_piece_resource: Resource,
|
||||||
prefix: Option<String>,
|
prefix: Option<String>,
|
||||||
device: Device) -> TranslationConfig {
|
device: Device,
|
||||||
|
) -> TranslationConfig {
|
||||||
TranslationConfig {
|
TranslationConfig {
|
||||||
model_resource,
|
model_resource,
|
||||||
config_resource,
|
config_resource,
|
||||||
@ -319,18 +531,17 @@ impl TranslationModel {
|
|||||||
/// # Example
|
/// # Example
|
||||||
///
|
///
|
||||||
/// ```no_run
|
/// ```no_run
|
||||||
///# fn main() -> failure::Fallible<()> {
|
/// # fn main() -> failure::Fallible<()> {
|
||||||
/// use rust_bert::pipelines::translation::{TranslationModel, TranslationConfig, Language};
|
/// use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
|
||||||
/// use tch::Device;
|
/// use tch::Device;
|
||||||
///
|
///
|
||||||
/// let translation_config = TranslationConfig::new(Language::FrenchToEnglish, Device::cuda_if_available());
|
/// let translation_config =
|
||||||
/// let mut summarization_model = TranslationModel::new(translation_config)?;
|
/// TranslationConfig::new(Language::FrenchToEnglish, Device::cuda_if_available());
|
||||||
///# Ok(())
|
/// let mut summarization_model = TranslationModel::new(translation_config)?;
|
||||||
///# }
|
/// # Ok(())
|
||||||
|
/// # }
|
||||||
/// ```
|
/// ```
|
||||||
///
|
pub fn new(translation_config: TranslationConfig) -> failure::Fallible<TranslationModel> {
|
||||||
pub fn new(translation_config: TranslationConfig)
|
|
||||||
-> failure::Fallible<TranslationModel> {
|
|
||||||
let generate_config = GenerateConfig {
|
let generate_config = GenerateConfig {
|
||||||
model_resource: translation_config.model_resource,
|
model_resource: translation_config.model_resource,
|
||||||
config_resource: translation_config.config_resource,
|
config_resource: translation_config.config_resource,
|
||||||
@ -353,7 +564,10 @@ impl TranslationModel {
|
|||||||
|
|
||||||
let model = MarianGenerator::new(generate_config)?;
|
let model = MarianGenerator::new(generate_config)?;
|
||||||
|
|
||||||
Ok(TranslationModel { model, prefix: translation_config.prefix })
|
Ok(TranslationModel {
|
||||||
|
model,
|
||||||
|
prefix: translation_config.prefix,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Translates texts provided
|
/// Translates texts provided
|
||||||
@ -368,31 +582,32 @@ impl TranslationModel {
|
|||||||
/// # Example
|
/// # Example
|
||||||
///
|
///
|
||||||
/// ```no_run
|
/// ```no_run
|
||||||
///# fn main() -> failure::Fallible<()> {
|
/// # fn main() -> failure::Fallible<()> {
|
||||||
/// use rust_bert::pipelines::generation::LanguageGenerator;
|
/// use rust_bert::pipelines::generation::LanguageGenerator;
|
||||||
/// use rust_bert::pipelines::translation::{TranslationModel, TranslationConfig, Language};
|
/// use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
|
||||||
/// use tch::Device;
|
/// use tch::Device;
|
||||||
///
|
///
|
||||||
/// let translation_config = TranslationConfig::new(Language::EnglishToFrench, Device::cuda_if_available());
|
/// let translation_config =
|
||||||
|
/// TranslationConfig::new(Language::EnglishToFrench, Device::cuda_if_available());
|
||||||
/// let model = TranslationModel::new(translation_config)?;
|
/// let model = TranslationModel::new(translation_config)?;
|
||||||
///
|
///
|
||||||
/// let input = ["This is a sentence to be translated"];
|
/// let input = ["This is a sentence to be translated"];
|
||||||
///
|
///
|
||||||
/// let output = model.translate(&input);
|
/// let output = model.translate(&input);
|
||||||
///# Ok(())
|
/// # Ok(())
|
||||||
///# }
|
/// # }
|
||||||
/// ```
|
/// ```
|
||||||
///
|
|
||||||
pub fn translate(&self, texts: &[&str]) -> Vec<String> {
|
pub fn translate(&self, texts: &[&str]) -> Vec<String> {
|
||||||
match &self.prefix {
|
match &self.prefix {
|
||||||
Some(value) => {
|
Some(value) => {
|
||||||
let texts: Vec<String> = texts
|
let texts: Vec<String> = texts
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|&v| { format!("{} {}", value, v) })
|
.map(|&v| format!("{} {}", value, v))
|
||||||
.collect();
|
.collect();
|
||||||
self.model.generate(Some(texts.iter().map(AsRef::as_ref).collect()), None)
|
self.model
|
||||||
|
.generate(Some(texts.iter().map(AsRef::as_ref).collect()), None)
|
||||||
}
|
}
|
||||||
None => self.model.generate(Some(texts.to_vec()), None)
|
None => self.model.generate(Some(texts.to_vec()), None),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -11,10 +11,10 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
use tch::{nn, Tensor, Kind};
|
|
||||||
use crate::common::dropout::Dropout;
|
|
||||||
use tch::nn::{EmbeddingConfig, embedding};
|
|
||||||
use crate::bert::{BertConfig, BertEmbedding};
|
use crate::bert::{BertConfig, BertEmbedding};
|
||||||
|
use crate::common::dropout::Dropout;
|
||||||
|
use tch::nn::{embedding, EmbeddingConfig};
|
||||||
|
use tch::{nn, Kind, Tensor};
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
/// # BertEmbeddings implementation for RoBERTa model
|
/// # BertEmbeddings implementation for RoBERTa model
|
||||||
@ -36,8 +36,12 @@ impl RobertaEmbeddings {
|
|||||||
|
|
||||||
fn create_position_ids_from_embeddings(&self, x: &Tensor) -> Tensor {
|
fn create_position_ids_from_embeddings(&self, x: &Tensor) -> Tensor {
|
||||||
let input_shape = x.size();
|
let input_shape = x.size();
|
||||||
let input_shape = vec!(input_shape[0], input_shape[1]);
|
let input_shape = vec![input_shape[0], input_shape[1]];
|
||||||
let position_ids: Tensor = Tensor::arange1(self.padding_index + 1, input_shape[0], (Kind::Int64, x.device()));
|
let position_ids: Tensor = Tensor::arange1(
|
||||||
|
self.padding_index + 1,
|
||||||
|
input_shape[0],
|
||||||
|
(Kind::Int64, x.device()),
|
||||||
|
);
|
||||||
position_ids.unsqueeze(0).expand(&input_shape, true)
|
position_ids.unsqueeze(0).expand(&input_shape, true)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -54,10 +58,10 @@ impl BertEmbedding for RobertaEmbeddings {
|
|||||||
///
|
///
|
||||||
/// ```no_run
|
/// ```no_run
|
||||||
/// use rust_bert::bert::{BertConfig, BertEmbedding};
|
/// use rust_bert::bert::{BertConfig, BertEmbedding};
|
||||||
/// use tch::{nn, Device};
|
/// use rust_bert::roberta::RobertaEmbeddings;
|
||||||
/// use rust_bert::Config;
|
/// use rust_bert::Config;
|
||||||
/// use std::path::Path;
|
/// use std::path::Path;
|
||||||
/// use rust_bert::roberta::RobertaEmbeddings;
|
/// use tch::{nn, Device};
|
||||||
///
|
///
|
||||||
/// let config_path = Path::new("path/to/config.json");
|
/// let config_path = Path::new("path/to/config.json");
|
||||||
/// let device = Device::Cpu;
|
/// let device = Device::Cpu;
|
||||||
@ -65,29 +69,48 @@ impl BertEmbedding for RobertaEmbeddings {
|
|||||||
/// let config = BertConfig::from_file(config_path);
|
/// let config = BertConfig::from_file(config_path);
|
||||||
/// let robert_embeddings = RobertaEmbeddings::new(&(&p.root() / "bert_embeddings"), &config);
|
/// let robert_embeddings = RobertaEmbeddings::new(&(&p.root() / "bert_embeddings"), &config);
|
||||||
/// ```
|
/// ```
|
||||||
///
|
|
||||||
fn new(p: &nn::Path, config: &BertConfig) -> RobertaEmbeddings {
|
fn new(p: &nn::Path, config: &BertConfig) -> RobertaEmbeddings {
|
||||||
let embedding_config = EmbeddingConfig { padding_idx: 1, ..Default::default() };
|
let embedding_config = EmbeddingConfig {
|
||||||
|
padding_idx: 1,
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
|
||||||
let word_embeddings: nn::Embedding = embedding(p / "word_embeddings",
|
let word_embeddings: nn::Embedding = embedding(
|
||||||
config.vocab_size,
|
p / "word_embeddings",
|
||||||
config.hidden_size,
|
config.vocab_size,
|
||||||
embedding_config);
|
config.hidden_size,
|
||||||
|
embedding_config,
|
||||||
|
);
|
||||||
|
|
||||||
let position_embeddings: nn::Embedding = embedding(p / "position_embeddings",
|
let position_embeddings: nn::Embedding = embedding(
|
||||||
config.max_position_embeddings,
|
p / "position_embeddings",
|
||||||
config.hidden_size,
|
config.max_position_embeddings,
|
||||||
Default::default());
|
config.hidden_size,
|
||||||
|
Default::default(),
|
||||||
|
);
|
||||||
|
|
||||||
let token_type_embeddings: nn::Embedding = embedding(p / "token_type_embeddings",
|
let token_type_embeddings: nn::Embedding = embedding(
|
||||||
config.type_vocab_size,
|
p / "token_type_embeddings",
|
||||||
config.hidden_size,
|
config.type_vocab_size,
|
||||||
Default::default());
|
config.hidden_size,
|
||||||
|
Default::default(),
|
||||||
|
);
|
||||||
|
|
||||||
let layer_norm_config = nn::LayerNormConfig { eps: 1e-12, ..Default::default() };
|
let layer_norm_config = nn::LayerNormConfig {
|
||||||
let layer_norm: nn::LayerNorm = nn::layer_norm(p / "LayerNorm", vec![config.hidden_size], layer_norm_config);
|
eps: 1e-12,
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
let layer_norm: nn::LayerNorm =
|
||||||
|
nn::layer_norm(p / "LayerNorm", vec![config.hidden_size], layer_norm_config);
|
||||||
let dropout: Dropout = Dropout::new(config.hidden_dropout_prob);
|
let dropout: Dropout = Dropout::new(config.hidden_dropout_prob);
|
||||||
RobertaEmbeddings { word_embeddings, position_embeddings, token_type_embeddings, layer_norm, dropout, padding_index: 1 }
|
RobertaEmbeddings {
|
||||||
|
word_embeddings,
|
||||||
|
position_embeddings,
|
||||||
|
token_type_embeddings,
|
||||||
|
layer_norm,
|
||||||
|
dropout,
|
||||||
|
padding_index: 1,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Forward pass through the embedding layer.
|
/// Forward pass through the embedding layer.
|
||||||
@ -108,68 +131,82 @@ impl BertEmbedding for RobertaEmbeddings {
|
|||||||
/// # Example
|
/// # Example
|
||||||
///
|
///
|
||||||
/// ```no_run
|
/// ```no_run
|
||||||
///# use rust_bert::bert::{BertConfig, BertEmbedding};
|
/// # use rust_bert::bert::{BertConfig, BertEmbedding};
|
||||||
///# use tch::{nn, Device, Tensor, no_grad};
|
/// # use tch::{nn, Device, Tensor, no_grad};
|
||||||
///# use rust_bert::Config;
|
/// # use rust_bert::Config;
|
||||||
///# use std::path::Path;
|
/// # use std::path::Path;
|
||||||
///# use tch::kind::Kind::Int64;
|
/// # use tch::kind::Kind::Int64;
|
||||||
/// use rust_bert::roberta::RobertaEmbeddings;
|
/// use rust_bert::roberta::RobertaEmbeddings;
|
||||||
///# let config_path = Path::new("path/to/config.json");
|
/// # let config_path = Path::new("path/to/config.json");
|
||||||
///# let vocab_path = Path::new("path/to/vocab.txt");
|
/// # let vocab_path = Path::new("path/to/vocab.txt");
|
||||||
///# let device = Device::Cpu;
|
/// # let device = Device::Cpu;
|
||||||
///# let vs = nn::VarStore::new(device);
|
/// # let vs = nn::VarStore::new(device);
|
||||||
///# let config = BertConfig::from_file(config_path);
|
/// # let config = BertConfig::from_file(config_path);
|
||||||
///# let roberta_embeddings = RobertaEmbeddings::new(&vs.root(), &config);
|
/// # let roberta_embeddings = RobertaEmbeddings::new(&vs.root(), &config);
|
||||||
/// let (batch_size, sequence_length) = (64, 128);
|
/// let (batch_size, sequence_length) = (64, 128);
|
||||||
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
|
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
|
||||||
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
||||||
/// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true);
|
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
|
||||||
|
/// .expand(&[batch_size, sequence_length], true);
|
||||||
///
|
///
|
||||||
/// let embedded_output = no_grad(|| {
|
/// let embedded_output = no_grad(|| {
|
||||||
/// roberta_embeddings
|
/// roberta_embeddings
|
||||||
/// .forward_t(Some(input_tensor),
|
/// .forward_t(
|
||||||
/// Some(token_type_ids),
|
/// Some(input_tensor),
|
||||||
/// Some(position_ids),
|
/// Some(token_type_ids),
|
||||||
/// None,
|
/// Some(position_ids),
|
||||||
/// false).unwrap()
|
/// None,
|
||||||
/// });
|
/// false,
|
||||||
|
/// )
|
||||||
|
/// .unwrap()
|
||||||
|
/// });
|
||||||
/// ```
|
/// ```
|
||||||
///
|
fn forward_t(
|
||||||
fn forward_t(&self,
|
&self,
|
||||||
input_ids: Option<Tensor>,
|
input_ids: Option<Tensor>,
|
||||||
token_type_ids: Option<Tensor>,
|
token_type_ids: Option<Tensor>,
|
||||||
position_ids: Option<Tensor>,
|
position_ids: Option<Tensor>,
|
||||||
input_embeds: Option<Tensor>,
|
input_embeds: Option<Tensor>,
|
||||||
train: bool) -> Result<Tensor, &'static str> {
|
train: bool,
|
||||||
|
) -> Result<Tensor, &'static str> {
|
||||||
let (input_embeddings, input_shape) = match &input_ids {
|
let (input_embeddings, input_shape) = match &input_ids {
|
||||||
Some(input_value) => match &input_embeds {
|
Some(input_value) => match &input_embeds {
|
||||||
Some(_) => { return Err("Only one of input ids or input embeddings may be set"); }
|
Some(_) => {
|
||||||
None => (input_value.apply_t(&self.word_embeddings, train), input_value.size())
|
return Err("Only one of input ids or input embeddings may be set");
|
||||||
}
|
}
|
||||||
|
None => (
|
||||||
|
input_value.apply_t(&self.word_embeddings, train),
|
||||||
|
input_value.size(),
|
||||||
|
),
|
||||||
|
},
|
||||||
None => match &input_embeds {
|
None => match &input_embeds {
|
||||||
Some(embeds) => (embeds.copy(), vec!(embeds.size()[0], embeds.size()[1])),
|
Some(embeds) => (embeds.copy(), vec![embeds.size()[0], embeds.size()[1]]),
|
||||||
None => { return Err("Only one of input ids or input embeddings may be set"); }
|
None => {
|
||||||
}
|
return Err("Only one of input ids or input embeddings may be set");
|
||||||
|
}
|
||||||
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
let position_ids = match position_ids {
|
let position_ids = match position_ids {
|
||||||
Some(value) => value,
|
Some(value) => value,
|
||||||
None => match input_ids {
|
None => match input_ids {
|
||||||
Some(value) => self.create_position_ids_from_input_ids(&value),
|
Some(value) => self.create_position_ids_from_input_ids(&value),
|
||||||
None => self.create_position_ids_from_embeddings(&input_embeds.unwrap())
|
None => self.create_position_ids_from_embeddings(&input_embeds.unwrap()),
|
||||||
}
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
let token_type_ids = match token_type_ids {
|
let token_type_ids = match token_type_ids {
|
||||||
Some(value) => value,
|
Some(value) => value,
|
||||||
None => Tensor::zeros(&input_shape, (Kind::Int64, input_embeddings.device()))
|
None => Tensor::zeros(&input_shape, (Kind::Int64, input_embeddings.device())),
|
||||||
};
|
};
|
||||||
|
|
||||||
let position_embeddings = position_ids.apply(&self.position_embeddings);
|
let position_embeddings = position_ids.apply(&self.position_embeddings);
|
||||||
let token_type_embeddings = token_type_ids.apply(&self.token_type_embeddings);
|
let token_type_embeddings = token_type_ids.apply(&self.token_type_embeddings);
|
||||||
|
|
||||||
let input_embeddings: Tensor = input_embeddings + position_embeddings + token_type_embeddings;
|
let input_embeddings: Tensor =
|
||||||
Ok(input_embeddings.apply(&self.layer_norm).apply_t(&self.dropout, train))
|
input_embeddings + position_embeddings + token_type_embeddings;
|
||||||
|
Ok(input_embeddings
|
||||||
|
.apply(&self.layer_norm)
|
||||||
|
.apply_t(&self.dropout, train))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -19,20 +19,28 @@
|
|||||||
//! Pretrained models are available and can be downloaded using RemoteResources.
|
//! Pretrained models are available and can be downloaded using RemoteResources.
|
||||||
//!
|
//!
|
||||||
//! ```no_run
|
//! ```no_run
|
||||||
//!# fn main() -> failure::Fallible<()> {
|
//! # fn main() -> failure::Fallible<()> {
|
||||||
//!#
|
//! #
|
||||||
//! use rust_tokenizers::RobertaTokenizer;
|
//! use rust_tokenizers::RobertaTokenizer;
|
||||||
//! use tch::{nn, Device};
|
//! use tch::{nn, Device};
|
||||||
//!# use std::path::PathBuf;
|
//! # use std::path::PathBuf;
|
||||||
//! use rust_bert::bert::BertConfig;
|
//! use rust_bert::bert::BertConfig;
|
||||||
//! use rust_bert::Config;
|
//! use rust_bert::resources::{download_resource, LocalResource, Resource};
|
||||||
//! use rust_bert::roberta::RobertaForMaskedLM;
|
//! use rust_bert::roberta::RobertaForMaskedLM;
|
||||||
//! use rust_bert::resources::{Resource, download_resource, LocalResource};
|
//! use rust_bert::Config;
|
||||||
//!
|
//!
|
||||||
//! let config_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/config.json")});
|
//! let config_resource = Resource::Local(LocalResource {
|
||||||
//! let vocab_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/vocab.txt")});
|
//! local_path: PathBuf::from("path/to/config.json"),
|
||||||
//! let merges_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/vocab.txt")});
|
//! });
|
||||||
//! let weights_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/model.ot")});
|
//! let vocab_resource = Resource::Local(LocalResource {
|
||||||
|
//! local_path: PathBuf::from("path/to/vocab.txt"),
|
||||||
|
//! });
|
||||||
|
//! let merges_resource = Resource::Local(LocalResource {
|
||||||
|
//! local_path: PathBuf::from("path/to/vocab.txt"),
|
||||||
|
//! });
|
||||||
|
//! let weights_resource = Resource::Local(LocalResource {
|
||||||
|
//! local_path: PathBuf::from("path/to/model.ot"),
|
||||||
|
//! });
|
||||||
//! let config_path = download_resource(&config_resource)?;
|
//! let config_path = download_resource(&config_resource)?;
|
||||||
//! let vocab_path = download_resource(&vocab_resource)?;
|
//! let vocab_path = download_resource(&vocab_resource)?;
|
||||||
//! let merges_path = download_resource(&merges_resource)?;
|
//! let merges_path = download_resource(&merges_resource)?;
|
||||||
@ -40,19 +48,25 @@
|
|||||||
//!
|
//!
|
||||||
//! let device = Device::cuda_if_available();
|
//! let device = Device::cuda_if_available();
|
||||||
//! let mut vs = nn::VarStore::new(device);
|
//! let mut vs = nn::VarStore::new(device);
|
||||||
//! let tokenizer: RobertaTokenizer = RobertaTokenizer::from_file(vocab_path.to_str().unwrap(), merges_path.to_str().unwrap(), true);
|
//! let tokenizer: RobertaTokenizer = RobertaTokenizer::from_file(
|
||||||
|
//! vocab_path.to_str().unwrap(),
|
||||||
|
//! merges_path.to_str().unwrap(),
|
||||||
|
//! true,
|
||||||
|
//! );
|
||||||
//! let config = BertConfig::from_file(config_path);
|
//! let config = BertConfig::from_file(config_path);
|
||||||
//! let bert_model = RobertaForMaskedLM::new(&vs.root(), &config);
|
//! let bert_model = RobertaForMaskedLM::new(&vs.root(), &config);
|
||||||
//! vs.load(weights_path)?;
|
//! vs.load(weights_path)?;
|
||||||
//!
|
//!
|
||||||
//!# Ok(())
|
//! # Ok(())
|
||||||
//!# }
|
//! # }
|
||||||
//! ```
|
//! ```
|
||||||
|
|
||||||
|
|
||||||
mod embeddings;
|
mod embeddings;
|
||||||
mod roberta;
|
mod roberta;
|
||||||
|
|
||||||
pub use roberta::{RobertaModelResources, RobertaConfigResources, RobertaVocabResources, RobertaMergesResources,
|
pub use embeddings::RobertaEmbeddings;
|
||||||
RobertaForMaskedLM, RobertaForMultipleChoice, RobertaForTokenClassification, RobertaForQuestionAnswering, RobertaForSequenceClassification};
|
pub use roberta::{
|
||||||
pub use embeddings::RobertaEmbeddings;
|
RobertaConfigResources, RobertaForMaskedLM, RobertaForMultipleChoice,
|
||||||
|
RobertaForQuestionAnswering, RobertaForSequenceClassification, RobertaForTokenClassification,
|
||||||
|
RobertaMergesResources, RobertaModelResources, RobertaVocabResources,
|
||||||
|
};
|
||||||
|
@ -11,13 +11,13 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
use tch::{nn, Tensor};
|
|
||||||
use crate::common::linear::{linear_no_bias, LinearNoBias};
|
|
||||||
use tch::nn::Init;
|
|
||||||
use crate::common::activations::_gelu;
|
|
||||||
use crate::roberta::embeddings::RobertaEmbeddings;
|
|
||||||
use crate::common::dropout::Dropout;
|
|
||||||
use crate::bert::{BertConfig, BertModel};
|
use crate::bert::{BertConfig, BertModel};
|
||||||
|
use crate::common::activations::_gelu;
|
||||||
|
use crate::common::dropout::Dropout;
|
||||||
|
use crate::common::linear::{linear_no_bias, LinearNoBias};
|
||||||
|
use crate::roberta::embeddings::RobertaEmbeddings;
|
||||||
|
use tch::nn::Init;
|
||||||
|
use tch::{nn, Tensor};
|
||||||
|
|
||||||
/// # RoBERTa Pretrained model weight files
|
/// # RoBERTa Pretrained model weight files
|
||||||
pub struct RobertaModelResources;
|
pub struct RobertaModelResources;
|
||||||
@ -33,22 +33,34 @@ pub struct RobertaMergesResources;
|
|||||||
|
|
||||||
impl RobertaModelResources {
|
impl RobertaModelResources {
|
||||||
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
|
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
|
||||||
pub const ROBERTA: (&'static str, &'static str) = ("roberta/model.ot", "https://cdn.huggingface.co/roberta-base-rust_model.ot");
|
pub const ROBERTA: (&'static str, &'static str) = (
|
||||||
|
"roberta/model.ot",
|
||||||
|
"https://cdn.huggingface.co/roberta-base-rust_model.ot",
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
impl RobertaConfigResources {
|
impl RobertaConfigResources {
|
||||||
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
|
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
|
||||||
pub const ROBERTA: (&'static str, &'static str) = ("roberta/config.json", "https://cdn.huggingface.co/roberta-base-config.json");
|
pub const ROBERTA: (&'static str, &'static str) = (
|
||||||
|
"roberta/config.json",
|
||||||
|
"https://cdn.huggingface.co/roberta-base-config.json",
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
impl RobertaVocabResources {
|
impl RobertaVocabResources {
|
||||||
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
|
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
|
||||||
pub const ROBERTA: (&'static str, &'static str) = ("roberta/vocab.txt", "https://cdn.huggingface.co/roberta-base-vocab.json");
|
pub const ROBERTA: (&'static str, &'static str) = (
|
||||||
|
"roberta/vocab.txt",
|
||||||
|
"https://cdn.huggingface.co/roberta-base-vocab.json",
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
impl RobertaMergesResources {
|
impl RobertaMergesResources {
|
||||||
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
|
/// Shared under MIT license by the Facebook AI Research Fairseq team at https://github.com/pytorch/fairseq. Modified with conversion to C-array format.
|
||||||
pub const ROBERTA: (&'static str, &'static str) = ("roberta/merges.txt", "https://cdn.huggingface.co/roberta-base-merges.txt");
|
pub const ROBERTA: (&'static str, &'static str) = (
|
||||||
|
"roberta/merges.txt",
|
||||||
|
"https://cdn.huggingface.co/roberta-base-merges.txt",
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct RobertaLMHead {
|
pub struct RobertaLMHead {
|
||||||
@ -60,17 +72,42 @@ pub struct RobertaLMHead {
|
|||||||
|
|
||||||
impl RobertaLMHead {
|
impl RobertaLMHead {
|
||||||
pub fn new(p: &nn::Path, config: &BertConfig) -> RobertaLMHead {
|
pub fn new(p: &nn::Path, config: &BertConfig) -> RobertaLMHead {
|
||||||
let dense = nn::linear(p / "dense", config.hidden_size, config.hidden_size, Default::default());
|
let dense = nn::linear(
|
||||||
let layer_norm_config = nn::LayerNormConfig { eps: 1e-12, ..Default::default() };
|
p / "dense",
|
||||||
let layer_norm = nn::layer_norm(p / "layer_norm", vec![config.hidden_size], layer_norm_config);
|
config.hidden_size,
|
||||||
let decoder = linear_no_bias(&(p / "decoder"), config.hidden_size, config.vocab_size, Default::default());
|
config.hidden_size,
|
||||||
|
Default::default(),
|
||||||
|
);
|
||||||
|
let layer_norm_config = nn::LayerNormConfig {
|
||||||
|
eps: 1e-12,
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
let layer_norm = nn::layer_norm(
|
||||||
|
p / "layer_norm",
|
||||||
|
vec![config.hidden_size],
|
||||||
|
layer_norm_config,
|
||||||
|
);
|
||||||
|
let decoder = linear_no_bias(
|
||||||
|
&(p / "decoder"),
|
||||||
|
config.hidden_size,
|
||||||
|
config.vocab_size,
|
||||||
|
Default::default(),
|
||||||
|
);
|
||||||
let bias = p.var("bias", &[config.vocab_size], Init::KaimingUniform);
|
let bias = p.var("bias", &[config.vocab_size], Init::KaimingUniform);
|
||||||
|
|
||||||
RobertaLMHead { dense, decoder, layer_norm, bias }
|
RobertaLMHead {
|
||||||
|
dense,
|
||||||
|
decoder,
|
||||||
|
layer_norm,
|
||||||
|
bias,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn forward(&self, hidden_states: &Tensor) -> Tensor {
|
pub fn forward(&self, hidden_states: &Tensor) -> Tensor {
|
||||||
(_gelu(&hidden_states.apply(&self.dense))).apply(&self.layer_norm).apply(&self.decoder) + &self.bias
|
(_gelu(&hidden_states.apply(&self.dense)))
|
||||||
|
.apply(&self.layer_norm)
|
||||||
|
.apply(&self.decoder)
|
||||||
|
+ &self.bias
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -96,10 +133,10 @@ impl RobertaForMaskedLM {
|
|||||||
///
|
///
|
||||||
/// ```no_run
|
/// ```no_run
|
||||||
/// use rust_bert::bert::BertConfig;
|
/// use rust_bert::bert::BertConfig;
|
||||||
/// use tch::{nn, Device};
|
/// use rust_bert::roberta::RobertaForMaskedLM;
|
||||||
/// use rust_bert::Config;
|
/// use rust_bert::Config;
|
||||||
/// use std::path::Path;
|
/// use std::path::Path;
|
||||||
/// use rust_bert::roberta::RobertaForMaskedLM;
|
/// use tch::{nn, Device};
|
||||||
///
|
///
|
||||||
/// let config_path = Path::new("path/to/config.json");
|
/// let config_path = Path::new("path/to/config.json");
|
||||||
/// let device = Device::Cpu;
|
/// let device = Device::Cpu;
|
||||||
@ -107,7 +144,6 @@ impl RobertaForMaskedLM {
|
|||||||
/// let config = BertConfig::from_file(config_path);
|
/// let config = BertConfig::from_file(config_path);
|
||||||
/// let roberta = RobertaForMaskedLM::new(&(&p.root() / "roberta"), &config);
|
/// let roberta = RobertaForMaskedLM::new(&(&p.root() / "roberta"), &config);
|
||||||
/// ```
|
/// ```
|
||||||
///
|
|
||||||
pub fn new(p: &nn::Path, config: &BertConfig) -> RobertaForMaskedLM {
|
pub fn new(p: &nn::Path, config: &BertConfig) -> RobertaForMaskedLM {
|
||||||
let roberta = BertModel::<RobertaEmbeddings>::new(&(p / "roberta"), config);
|
let roberta = BertModel::<RobertaEmbeddings>::new(&(p / "roberta"), config);
|
||||||
let lm_head = RobertaLMHead::new(&(p / "lm_head"), config);
|
let lm_head = RobertaLMHead::new(&(p / "lm_head"), config);
|
||||||
@ -137,49 +173,62 @@ impl RobertaForMaskedLM {
|
|||||||
/// # Example
|
/// # Example
|
||||||
///
|
///
|
||||||
/// ```no_run
|
/// ```no_run
|
||||||
///# use rust_bert::bert::BertConfig;
|
/// # use rust_bert::bert::BertConfig;
|
||||||
///# use tch::{nn, Device, Tensor, no_grad};
|
/// # use tch::{nn, Device, Tensor, no_grad};
|
||||||
///# use rust_bert::Config;
|
/// # use rust_bert::Config;
|
||||||
///# use std::path::Path;
|
/// # use std::path::Path;
|
||||||
///# use tch::kind::Kind::Int64;
|
/// # use tch::kind::Kind::Int64;
|
||||||
/// use rust_bert::roberta::RobertaForMaskedLM;
|
/// use rust_bert::roberta::RobertaForMaskedLM;
|
||||||
///# let config_path = Path::new("path/to/config.json");
|
/// # let config_path = Path::new("path/to/config.json");
|
||||||
///# let vocab_path = Path::new("path/to/vocab.txt");
|
/// # let vocab_path = Path::new("path/to/vocab.txt");
|
||||||
///# let device = Device::Cpu;
|
/// # let device = Device::Cpu;
|
||||||
///# let vs = nn::VarStore::new(device);
|
/// # let vs = nn::VarStore::new(device);
|
||||||
///# let config = BertConfig::from_file(config_path);
|
/// # let config = BertConfig::from_file(config_path);
|
||||||
///# let roberta_model = RobertaForMaskedLM::new(&vs.root(), &config);
|
/// # let roberta_model = RobertaForMaskedLM::new(&vs.root(), &config);
|
||||||
/// let (batch_size, sequence_length) = (64, 128);
|
/// let (batch_size, sequence_length) = (64, 128);
|
||||||
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
|
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
|
||||||
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
||||||
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
||||||
/// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true);
|
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
|
||||||
///
|
/// .expand(&[batch_size, sequence_length], true);
|
||||||
/// let (output, all_hidden_states, all_attentions) = no_grad(|| {
|
|
||||||
/// roberta_model
|
|
||||||
/// .forward_t(Some(input_tensor),
|
|
||||||
/// Some(mask),
|
|
||||||
/// Some(token_type_ids),
|
|
||||||
/// Some(position_ids),
|
|
||||||
/// None,
|
|
||||||
/// &None,
|
|
||||||
/// &None,
|
|
||||||
/// false)
|
|
||||||
/// });
|
|
||||||
///
|
///
|
||||||
|
/// let (output, all_hidden_states, all_attentions) = no_grad(|| {
|
||||||
|
/// roberta_model.forward_t(
|
||||||
|
/// Some(input_tensor),
|
||||||
|
/// Some(mask),
|
||||||
|
/// Some(token_type_ids),
|
||||||
|
/// Some(position_ids),
|
||||||
|
/// None,
|
||||||
|
/// &None,
|
||||||
|
/// &None,
|
||||||
|
/// false,
|
||||||
|
/// )
|
||||||
|
/// });
|
||||||
/// ```
|
/// ```
|
||||||
///
|
pub fn forward_t(
|
||||||
pub fn forward_t(&self,
|
&self,
|
||||||
input_ids: Option<Tensor>,
|
input_ids: Option<Tensor>,
|
||||||
mask: Option<Tensor>,
|
mask: Option<Tensor>,
|
||||||
token_type_ids: Option<Tensor>,
|
token_type_ids: Option<Tensor>,
|
||||||
position_ids: Option<Tensor>,
|
position_ids: Option<Tensor>,
|
||||||
input_embeds: Option<Tensor>,
|
input_embeds: Option<Tensor>,
|
||||||
encoder_hidden_states: &Option<Tensor>,
|
encoder_hidden_states: &Option<Tensor>,
|
||||||
encoder_mask: &Option<Tensor>,
|
encoder_mask: &Option<Tensor>,
|
||||||
train: bool) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
|
train: bool,
|
||||||
let (hidden_state, _, all_hidden_states, all_attentions) = self.roberta.forward_t(input_ids, mask, token_type_ids, position_ids,
|
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
|
||||||
input_embeds, encoder_hidden_states, encoder_mask, train).unwrap();
|
let (hidden_state, _, all_hidden_states, all_attentions) = self
|
||||||
|
.roberta
|
||||||
|
.forward_t(
|
||||||
|
input_ids,
|
||||||
|
mask,
|
||||||
|
token_type_ids,
|
||||||
|
position_ids,
|
||||||
|
input_embeds,
|
||||||
|
encoder_hidden_states,
|
||||||
|
encoder_mask,
|
||||||
|
train,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
let prediction_scores = self.lm_head.forward(&hidden_state);
|
let prediction_scores = self.lm_head.forward(&hidden_state);
|
||||||
(prediction_scores, all_hidden_states, all_attentions)
|
(prediction_scores, all_hidden_states, all_attentions)
|
||||||
@ -194,12 +243,30 @@ pub struct RobertaClassificationHead {
|
|||||||
|
|
||||||
impl RobertaClassificationHead {
|
impl RobertaClassificationHead {
|
||||||
pub fn new(p: &nn::Path, config: &BertConfig) -> RobertaClassificationHead {
|
pub fn new(p: &nn::Path, config: &BertConfig) -> RobertaClassificationHead {
|
||||||
let dense = nn::linear(p / "dense", config.hidden_size, config.hidden_size, Default::default());
|
let dense = nn::linear(
|
||||||
let num_labels = config.id2label.as_ref().expect("num_labels not provided in configuration").len() as i64;
|
p / "dense",
|
||||||
let out_proj = nn::linear(p / "out_proj", config.hidden_size, num_labels, Default::default());
|
config.hidden_size,
|
||||||
|
config.hidden_size,
|
||||||
|
Default::default(),
|
||||||
|
);
|
||||||
|
let num_labels = config
|
||||||
|
.id2label
|
||||||
|
.as_ref()
|
||||||
|
.expect("num_labels not provided in configuration")
|
||||||
|
.len() as i64;
|
||||||
|
let out_proj = nn::linear(
|
||||||
|
p / "out_proj",
|
||||||
|
config.hidden_size,
|
||||||
|
num_labels,
|
||||||
|
Default::default(),
|
||||||
|
);
|
||||||
let dropout = Dropout::new(config.hidden_dropout_prob);
|
let dropout = Dropout::new(config.hidden_dropout_prob);
|
||||||
|
|
||||||
RobertaClassificationHead { dense, dropout, out_proj }
|
RobertaClassificationHead {
|
||||||
|
dense,
|
||||||
|
dropout,
|
||||||
|
out_proj,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn forward_t(&self, hidden_states: &Tensor, train: bool) -> Tensor {
|
pub fn forward_t(&self, hidden_states: &Tensor, train: bool) -> Tensor {
|
||||||
@ -235,10 +302,10 @@ impl RobertaForSequenceClassification {
|
|||||||
///
|
///
|
||||||
/// ```no_run
|
/// ```no_run
|
||||||
/// use rust_bert::bert::BertConfig;
|
/// use rust_bert::bert::BertConfig;
|
||||||
/// use tch::{nn, Device};
|
/// use rust_bert::roberta::RobertaForSequenceClassification;
|
||||||
/// use rust_bert::Config;
|
/// use rust_bert::Config;
|
||||||
/// use std::path::Path;
|
/// use std::path::Path;
|
||||||
/// use rust_bert::roberta::RobertaForSequenceClassification;
|
/// use tch::{nn, Device};
|
||||||
///
|
///
|
||||||
/// let config_path = Path::new("path/to/config.json");
|
/// let config_path = Path::new("path/to/config.json");
|
||||||
/// let device = Device::Cpu;
|
/// let device = Device::Cpu;
|
||||||
@ -246,12 +313,14 @@ impl RobertaForSequenceClassification {
|
|||||||
/// let config = BertConfig::from_file(config_path);
|
/// let config = BertConfig::from_file(config_path);
|
||||||
/// let roberta = RobertaForSequenceClassification::new(&(&p.root() / "roberta"), &config);
|
/// let roberta = RobertaForSequenceClassification::new(&(&p.root() / "roberta"), &config);
|
||||||
/// ```
|
/// ```
|
||||||
///
|
|
||||||
pub fn new(p: &nn::Path, config: &BertConfig) -> RobertaForSequenceClassification {
|
pub fn new(p: &nn::Path, config: &BertConfig) -> RobertaForSequenceClassification {
|
||||||
let roberta = BertModel::<RobertaEmbeddings>::new(&(p / "roberta"), config);
|
let roberta = BertModel::<RobertaEmbeddings>::new(&(p / "roberta"), config);
|
||||||
let classifier = RobertaClassificationHead::new(&(p / "classifier"), config);
|
let classifier = RobertaClassificationHead::new(&(p / "classifier"), config);
|
||||||
|
|
||||||
RobertaForSequenceClassification { roberta, classifier }
|
RobertaForSequenceClassification {
|
||||||
|
roberta,
|
||||||
|
classifier,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Forward pass through the model
|
/// Forward pass through the model
|
||||||
@ -274,45 +343,58 @@ impl RobertaForSequenceClassification {
|
|||||||
/// # Example
|
/// # Example
|
||||||
///
|
///
|
||||||
/// ```no_run
|
/// ```no_run
|
||||||
///# use rust_bert::bert::BertConfig;
|
/// # use rust_bert::bert::BertConfig;
|
||||||
///# use tch::{nn, Device, Tensor, no_grad};
|
/// # use tch::{nn, Device, Tensor, no_grad};
|
||||||
///# use rust_bert::Config;
|
/// # use rust_bert::Config;
|
||||||
///# use std::path::Path;
|
/// # use std::path::Path;
|
||||||
///# use tch::kind::Kind::Int64;
|
/// # use tch::kind::Kind::Int64;
|
||||||
/// use rust_bert::roberta::RobertaForSequenceClassification;
|
/// use rust_bert::roberta::RobertaForSequenceClassification;
|
||||||
///# let config_path = Path::new("path/to/config.json");
|
/// # let config_path = Path::new("path/to/config.json");
|
||||||
///# let vocab_path = Path::new("path/to/vocab.txt");
|
/// # let vocab_path = Path::new("path/to/vocab.txt");
|
||||||
///# let device = Device::Cpu;
|
/// # let device = Device::Cpu;
|
||||||
///# let vs = nn::VarStore::new(device);
|
/// # let vs = nn::VarStore::new(device);
|
||||||
///# let config = BertConfig::from_file(config_path);
|
/// # let config = BertConfig::from_file(config_path);
|
||||||
///# let roberta_model = RobertaForSequenceClassification::new(&vs.root(), &config);
|
/// # let roberta_model = RobertaForSequenceClassification::new(&vs.root(), &config);
|
||||||
/// let (batch_size, sequence_length) = (64, 128);
|
/// let (batch_size, sequence_length) = (64, 128);
|
||||||
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
|
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
|
||||||
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
||||||
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
||||||
/// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true);
|
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
|
||||||
///
|
/// .expand(&[batch_size, sequence_length], true);
|
||||||
/// let (labels, all_hidden_states, all_attentions) = no_grad(|| {
|
|
||||||
/// roberta_model
|
|
||||||
/// .forward_t(Some(input_tensor),
|
|
||||||
/// Some(mask),
|
|
||||||
/// Some(token_type_ids),
|
|
||||||
/// Some(position_ids),
|
|
||||||
/// None,
|
|
||||||
/// false)
|
|
||||||
/// });
|
|
||||||
///
|
///
|
||||||
|
/// let (labels, all_hidden_states, all_attentions) = no_grad(|| {
|
||||||
|
/// roberta_model.forward_t(
|
||||||
|
/// Some(input_tensor),
|
||||||
|
/// Some(mask),
|
||||||
|
/// Some(token_type_ids),
|
||||||
|
/// Some(position_ids),
|
||||||
|
/// None,
|
||||||
|
/// false,
|
||||||
|
/// )
|
||||||
|
/// });
|
||||||
/// ```
|
/// ```
|
||||||
///
|
pub fn forward_t(
|
||||||
pub fn forward_t(&self,
|
&self,
|
||||||
input_ids: Option<Tensor>,
|
input_ids: Option<Tensor>,
|
||||||
mask: Option<Tensor>,
|
mask: Option<Tensor>,
|
||||||
token_type_ids: Option<Tensor>,
|
token_type_ids: Option<Tensor>,
|
||||||
position_ids: Option<Tensor>,
|
position_ids: Option<Tensor>,
|
||||||
input_embeds: Option<Tensor>,
|
input_embeds: Option<Tensor>,
|
||||||
train: bool) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
|
train: bool,
|
||||||
let (hidden_state, _, all_hidden_states, all_attentions) = self.roberta.forward_t(input_ids, mask, token_type_ids, position_ids,
|
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
|
||||||
input_embeds, &None, &None, train).unwrap();
|
let (hidden_state, _, all_hidden_states, all_attentions) = self
|
||||||
|
.roberta
|
||||||
|
.forward_t(
|
||||||
|
input_ids,
|
||||||
|
mask,
|
||||||
|
token_type_ids,
|
||||||
|
position_ids,
|
||||||
|
input_embeds,
|
||||||
|
&None,
|
||||||
|
&None,
|
||||||
|
train,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
let output = self.classifier.forward_t(&hidden_state, train);
|
let output = self.classifier.forward_t(&hidden_state, train);
|
||||||
(output, all_hidden_states, all_attentions)
|
(output, all_hidden_states, all_attentions)
|
||||||
@ -344,10 +426,10 @@ impl RobertaForMultipleChoice {
|
|||||||
///
|
///
|
||||||
/// ```no_run
|
/// ```no_run
|
||||||
/// use rust_bert::bert::BertConfig;
|
/// use rust_bert::bert::BertConfig;
|
||||||
/// use tch::{nn, Device};
|
/// use rust_bert::roberta::RobertaForMultipleChoice;
|
||||||
/// use rust_bert::Config;
|
/// use rust_bert::Config;
|
||||||
/// use std::path::Path;
|
/// use std::path::Path;
|
||||||
/// use rust_bert::roberta::RobertaForMultipleChoice;
|
/// use tch::{nn, Device};
|
||||||
///
|
///
|
||||||
/// let config_path = Path::new("path/to/config.json");
|
/// let config_path = Path::new("path/to/config.json");
|
||||||
/// let device = Device::Cpu;
|
/// let device = Device::Cpu;
|
||||||
@ -355,13 +437,16 @@ impl RobertaForMultipleChoice {
|
|||||||
/// let config = BertConfig::from_file(config_path);
|
/// let config = BertConfig::from_file(config_path);
|
||||||
/// let roberta = RobertaForMultipleChoice::new(&(&p.root() / "roberta"), &config);
|
/// let roberta = RobertaForMultipleChoice::new(&(&p.root() / "roberta"), &config);
|
||||||
/// ```
|
/// ```
|
||||||
///
|
|
||||||
pub fn new(p: &nn::Path, config: &BertConfig) -> RobertaForMultipleChoice {
|
pub fn new(p: &nn::Path, config: &BertConfig) -> RobertaForMultipleChoice {
|
||||||
let roberta = BertModel::<RobertaEmbeddings>::new(&(p / "roberta"), config);
|
let roberta = BertModel::<RobertaEmbeddings>::new(&(p / "roberta"), config);
|
||||||
let dropout = Dropout::new(config.hidden_dropout_prob);
|
let dropout = Dropout::new(config.hidden_dropout_prob);
|
||||||
let classifier = nn::linear(p / "classifier", config.hidden_size, 1, Default::default());
|
let classifier = nn::linear(p / "classifier", config.hidden_size, 1, Default::default());
|
||||||
|
|
||||||
RobertaForMultipleChoice { roberta, dropout, classifier }
|
RobertaForMultipleChoice {
|
||||||
|
roberta,
|
||||||
|
dropout,
|
||||||
|
classifier,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Forward pass through the model
|
/// Forward pass through the model
|
||||||
@ -383,61 +468,77 @@ impl RobertaForMultipleChoice {
|
|||||||
/// # Example
|
/// # Example
|
||||||
///
|
///
|
||||||
/// ```no_run
|
/// ```no_run
|
||||||
///# use rust_bert::bert::BertConfig;
|
/// # use rust_bert::bert::BertConfig;
|
||||||
///# use tch::{nn, Device, Tensor, no_grad};
|
/// # use tch::{nn, Device, Tensor, no_grad};
|
||||||
///# use rust_bert::Config;
|
/// # use rust_bert::Config;
|
||||||
///# use std::path::Path;
|
/// # use std::path::Path;
|
||||||
///# use tch::kind::Kind::Int64;
|
/// # use tch::kind::Kind::Int64;
|
||||||
/// use rust_bert::roberta::RobertaForMultipleChoice;
|
/// use rust_bert::roberta::RobertaForMultipleChoice;
|
||||||
///# let config_path = Path::new("path/to/config.json");
|
/// # let config_path = Path::new("path/to/config.json");
|
||||||
///# let vocab_path = Path::new("path/to/vocab.txt");
|
/// # let vocab_path = Path::new("path/to/vocab.txt");
|
||||||
///# let device = Device::Cpu;
|
/// # let device = Device::Cpu;
|
||||||
///# let vs = nn::VarStore::new(device);
|
/// # let vs = nn::VarStore::new(device);
|
||||||
///# let config = BertConfig::from_file(config_path);
|
/// # let config = BertConfig::from_file(config_path);
|
||||||
///# let roberta_model = RobertaForMultipleChoice::new(&vs.root(), &config);
|
/// # let roberta_model = RobertaForMultipleChoice::new(&vs.root(), &config);
|
||||||
/// let (num_choices, sequence_length) = (3, 128);
|
/// let (num_choices, sequence_length) = (3, 128);
|
||||||
/// let input_tensor = Tensor::rand(&[num_choices, sequence_length], (Int64, device));
|
/// let input_tensor = Tensor::rand(&[num_choices, sequence_length], (Int64, device));
|
||||||
/// let mask = Tensor::zeros(&[num_choices, sequence_length], (Int64, device));
|
/// let mask = Tensor::zeros(&[num_choices, sequence_length], (Int64, device));
|
||||||
/// let token_type_ids = Tensor::zeros(&[num_choices, sequence_length], (Int64, device));
|
/// let token_type_ids = Tensor::zeros(&[num_choices, sequence_length], (Int64, device));
|
||||||
/// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[num_choices, sequence_length], true);
|
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
|
||||||
///
|
/// .expand(&[num_choices, sequence_length], true);
|
||||||
/// let (choices, all_hidden_states, all_attentions) = no_grad(|| {
|
|
||||||
/// roberta_model
|
|
||||||
/// .forward_t(input_tensor,
|
|
||||||
/// Some(mask),
|
|
||||||
/// Some(token_type_ids),
|
|
||||||
/// Some(position_ids),
|
|
||||||
/// false)
|
|
||||||
/// });
|
|
||||||
///
|
///
|
||||||
|
/// let (choices, all_hidden_states, all_attentions) = no_grad(|| {
|
||||||
|
/// roberta_model.forward_t(
|
||||||
|
/// input_tensor,
|
||||||
|
/// Some(mask),
|
||||||
|
/// Some(token_type_ids),
|
||||||
|
/// Some(position_ids),
|
||||||
|
/// false,
|
||||||
|
/// )
|
||||||
|
/// });
|
||||||
/// ```
|
/// ```
|
||||||
///
|
pub fn forward_t(
|
||||||
pub fn forward_t(&self,
|
&self,
|
||||||
input_ids: Tensor,
|
input_ids: Tensor,
|
||||||
mask: Option<Tensor>,
|
mask: Option<Tensor>,
|
||||||
token_type_ids: Option<Tensor>,
|
token_type_ids: Option<Tensor>,
|
||||||
position_ids: Option<Tensor>,
|
position_ids: Option<Tensor>,
|
||||||
train: bool) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
|
train: bool,
|
||||||
|
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
|
||||||
let num_choices = input_ids.size()[1];
|
let num_choices = input_ids.size()[1];
|
||||||
|
|
||||||
let flat_input_ids = Some(input_ids.view((-1i64, *input_ids.size().last().unwrap())));
|
let flat_input_ids = Some(input_ids.view((-1i64, *input_ids.size().last().unwrap())));
|
||||||
let flat_position_ids = match position_ids {
|
let flat_position_ids = match position_ids {
|
||||||
Some(value) => Some(value.view((-1i64, *value.size().last().unwrap()))),
|
Some(value) => Some(value.view((-1i64, *value.size().last().unwrap()))),
|
||||||
None => None
|
None => None,
|
||||||
};
|
};
|
||||||
let flat_token_type_ids = match token_type_ids {
|
let flat_token_type_ids = match token_type_ids {
|
||||||
Some(value) => Some(value.view((-1i64, *value.size().last().unwrap()))),
|
Some(value) => Some(value.view((-1i64, *value.size().last().unwrap()))),
|
||||||
None => None
|
None => None,
|
||||||
};
|
};
|
||||||
let flat_mask = match mask {
|
let flat_mask = match mask {
|
||||||
Some(value) => Some(value.view((-1i64, *value.size().last().unwrap()))),
|
Some(value) => Some(value.view((-1i64, *value.size().last().unwrap()))),
|
||||||
None => None
|
None => None,
|
||||||
};
|
};
|
||||||
|
|
||||||
let (_, pooled_output, all_hidden_states, all_attentions) = self.roberta.forward_t(flat_input_ids, flat_mask, flat_token_type_ids, flat_position_ids,
|
let (_, pooled_output, all_hidden_states, all_attentions) = self
|
||||||
None, &None, &None, train).unwrap();
|
.roberta
|
||||||
|
.forward_t(
|
||||||
|
flat_input_ids,
|
||||||
|
flat_mask,
|
||||||
|
flat_token_type_ids,
|
||||||
|
flat_position_ids,
|
||||||
|
None,
|
||||||
|
&None,
|
||||||
|
&None,
|
||||||
|
train,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
let output = pooled_output.apply_t(&self.dropout, train).apply(&self.classifier).view((-1, num_choices));
|
let output = pooled_output
|
||||||
|
.apply_t(&self.dropout, train)
|
||||||
|
.apply(&self.classifier)
|
||||||
|
.view((-1, num_choices));
|
||||||
(output, all_hidden_states, all_attentions)
|
(output, all_hidden_states, all_attentions)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -466,10 +567,10 @@ impl RobertaForTokenClassification {
|
|||||||
///
|
///
|
||||||
/// ```no_run
|
/// ```no_run
|
||||||
/// use rust_bert::bert::BertConfig;
|
/// use rust_bert::bert::BertConfig;
|
||||||
/// use tch::{nn, Device};
|
/// use rust_bert::roberta::RobertaForTokenClassification;
|
||||||
/// use rust_bert::Config;
|
/// use rust_bert::Config;
|
||||||
/// use std::path::Path;
|
/// use std::path::Path;
|
||||||
/// use rust_bert::roberta::RobertaForTokenClassification;
|
/// use tch::{nn, Device};
|
||||||
///
|
///
|
||||||
/// let config_path = Path::new("path/to/config.json");
|
/// let config_path = Path::new("path/to/config.json");
|
||||||
/// let device = Device::Cpu;
|
/// let device = Device::Cpu;
|
||||||
@ -477,14 +578,26 @@ impl RobertaForTokenClassification {
|
|||||||
/// let config = BertConfig::from_file(config_path);
|
/// let config = BertConfig::from_file(config_path);
|
||||||
/// let roberta = RobertaForTokenClassification::new(&(&p.root() / "roberta"), &config);
|
/// let roberta = RobertaForTokenClassification::new(&(&p.root() / "roberta"), &config);
|
||||||
/// ```
|
/// ```
|
||||||
///
|
|
||||||
pub fn new(p: &nn::Path, config: &BertConfig) -> RobertaForTokenClassification {
|
pub fn new(p: &nn::Path, config: &BertConfig) -> RobertaForTokenClassification {
|
||||||
let roberta = BertModel::<RobertaEmbeddings>::new(&(p / "roberta"), config);
|
let roberta = BertModel::<RobertaEmbeddings>::new(&(p / "roberta"), config);
|
||||||
let dropout = Dropout::new(config.hidden_dropout_prob);
|
let dropout = Dropout::new(config.hidden_dropout_prob);
|
||||||
let num_labels = config.id2label.as_ref().expect("num_labels not provided in configuration").len() as i64;
|
let num_labels = config
|
||||||
let classifier = nn::linear(p / "classifier", config.hidden_size, num_labels, Default::default());
|
.id2label
|
||||||
|
.as_ref()
|
||||||
|
.expect("num_labels not provided in configuration")
|
||||||
|
.len() as i64;
|
||||||
|
let classifier = nn::linear(
|
||||||
|
p / "classifier",
|
||||||
|
config.hidden_size,
|
||||||
|
num_labels,
|
||||||
|
Default::default(),
|
||||||
|
);
|
||||||
|
|
||||||
RobertaForTokenClassification { roberta, dropout, classifier }
|
RobertaForTokenClassification {
|
||||||
|
roberta,
|
||||||
|
dropout,
|
||||||
|
classifier,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Forward pass through the model
|
/// Forward pass through the model
|
||||||
@ -507,47 +620,62 @@ impl RobertaForTokenClassification {
|
|||||||
/// # Example
|
/// # Example
|
||||||
///
|
///
|
||||||
/// ```no_run
|
/// ```no_run
|
||||||
///# use rust_bert::bert::BertConfig;
|
/// # use rust_bert::bert::BertConfig;
|
||||||
///# use tch::{nn, Device, Tensor, no_grad};
|
/// # use tch::{nn, Device, Tensor, no_grad};
|
||||||
///# use rust_bert::Config;
|
/// # use rust_bert::Config;
|
||||||
///# use std::path::Path;
|
/// # use std::path::Path;
|
||||||
///# use tch::kind::Kind::Int64;
|
/// # use tch::kind::Kind::Int64;
|
||||||
/// use rust_bert::roberta::RobertaForTokenClassification;
|
/// use rust_bert::roberta::RobertaForTokenClassification;
|
||||||
///# let config_path = Path::new("path/to/config.json");
|
/// # let config_path = Path::new("path/to/config.json");
|
||||||
///# let vocab_path = Path::new("path/to/vocab.txt");
|
/// # let vocab_path = Path::new("path/to/vocab.txt");
|
||||||
///# let device = Device::Cpu;
|
/// # let device = Device::Cpu;
|
||||||
///# let vs = nn::VarStore::new(device);
|
/// # let vs = nn::VarStore::new(device);
|
||||||
///# let config = BertConfig::from_file(config_path);
|
/// # let config = BertConfig::from_file(config_path);
|
||||||
///# let roberta_model = RobertaForTokenClassification::new(&vs.root(), &config);
|
/// # let roberta_model = RobertaForTokenClassification::new(&vs.root(), &config);
|
||||||
/// let (batch_size, sequence_length) = (64, 128);
|
/// let (batch_size, sequence_length) = (64, 128);
|
||||||
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
|
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
|
||||||
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
||||||
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
||||||
/// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true);
|
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
|
||||||
///
|
/// .expand(&[batch_size, sequence_length], true);
|
||||||
/// let (token_labels, all_hidden_states, all_attentions) = no_grad(|| {
|
|
||||||
/// roberta_model
|
|
||||||
/// .forward_t(Some(input_tensor),
|
|
||||||
/// Some(mask),
|
|
||||||
/// Some(token_type_ids),
|
|
||||||
/// Some(position_ids),
|
|
||||||
/// None,
|
|
||||||
/// false)
|
|
||||||
/// });
|
|
||||||
///
|
///
|
||||||
|
/// let (token_labels, all_hidden_states, all_attentions) = no_grad(|| {
|
||||||
|
/// roberta_model.forward_t(
|
||||||
|
/// Some(input_tensor),
|
||||||
|
/// Some(mask),
|
||||||
|
/// Some(token_type_ids),
|
||||||
|
/// Some(position_ids),
|
||||||
|
/// None,
|
||||||
|
/// false,
|
||||||
|
/// )
|
||||||
|
/// });
|
||||||
/// ```
|
/// ```
|
||||||
///
|
pub fn forward_t(
|
||||||
pub fn forward_t(&self,
|
&self,
|
||||||
input_ids: Option<Tensor>,
|
input_ids: Option<Tensor>,
|
||||||
mask: Option<Tensor>,
|
mask: Option<Tensor>,
|
||||||
token_type_ids: Option<Tensor>,
|
token_type_ids: Option<Tensor>,
|
||||||
position_ids: Option<Tensor>,
|
position_ids: Option<Tensor>,
|
||||||
input_embeds: Option<Tensor>,
|
input_embeds: Option<Tensor>,
|
||||||
train: bool) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
|
train: bool,
|
||||||
let (hidden_state, _, all_hidden_states, all_attentions) = self.roberta.forward_t(input_ids, mask, token_type_ids, position_ids,
|
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
|
||||||
input_embeds, &None, &None, train).unwrap();
|
let (hidden_state, _, all_hidden_states, all_attentions) = self
|
||||||
|
.roberta
|
||||||
|
.forward_t(
|
||||||
|
input_ids,
|
||||||
|
mask,
|
||||||
|
token_type_ids,
|
||||||
|
position_ids,
|
||||||
|
input_embeds,
|
||||||
|
&None,
|
||||||
|
&None,
|
||||||
|
train,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
let sequence_output = hidden_state.apply_t(&self.dropout, train).apply(&self.classifier);
|
let sequence_output = hidden_state
|
||||||
|
.apply_t(&self.dropout, train)
|
||||||
|
.apply(&self.classifier);
|
||||||
(sequence_output, all_hidden_states, all_attentions)
|
(sequence_output, all_hidden_states, all_attentions)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -576,10 +704,10 @@ impl RobertaForQuestionAnswering {
|
|||||||
///
|
///
|
||||||
/// ```no_run
|
/// ```no_run
|
||||||
/// use rust_bert::bert::BertConfig;
|
/// use rust_bert::bert::BertConfig;
|
||||||
/// use tch::{nn, Device};
|
/// use rust_bert::roberta::RobertaForQuestionAnswering;
|
||||||
/// use rust_bert::Config;
|
/// use rust_bert::Config;
|
||||||
/// use std::path::Path;
|
/// use std::path::Path;
|
||||||
/// use rust_bert::roberta::RobertaForQuestionAnswering;
|
/// use tch::{nn, Device};
|
||||||
///
|
///
|
||||||
/// let config_path = Path::new("path/to/config.json");
|
/// let config_path = Path::new("path/to/config.json");
|
||||||
/// let device = Device::Cpu;
|
/// let device = Device::Cpu;
|
||||||
@ -587,13 +715,20 @@ impl RobertaForQuestionAnswering {
|
|||||||
/// let config = BertConfig::from_file(config_path);
|
/// let config = BertConfig::from_file(config_path);
|
||||||
/// let roberta = RobertaForQuestionAnswering::new(&(&p.root() / "roberta"), &config);
|
/// let roberta = RobertaForQuestionAnswering::new(&(&p.root() / "roberta"), &config);
|
||||||
/// ```
|
/// ```
|
||||||
///
|
|
||||||
pub fn new(p: &nn::Path, config: &BertConfig) -> RobertaForQuestionAnswering {
|
pub fn new(p: &nn::Path, config: &BertConfig) -> RobertaForQuestionAnswering {
|
||||||
let roberta = BertModel::<RobertaEmbeddings>::new(&(p / "roberta"), config);
|
let roberta = BertModel::<RobertaEmbeddings>::new(&(p / "roberta"), config);
|
||||||
let num_labels = 2;
|
let num_labels = 2;
|
||||||
let qa_outputs = nn::linear(p / "qa_outputs", config.hidden_size, num_labels, Default::default());
|
let qa_outputs = nn::linear(
|
||||||
|
p / "qa_outputs",
|
||||||
|
config.hidden_size,
|
||||||
|
num_labels,
|
||||||
|
Default::default(),
|
||||||
|
);
|
||||||
|
|
||||||
RobertaForQuestionAnswering { roberta, qa_outputs }
|
RobertaForQuestionAnswering {
|
||||||
|
roberta,
|
||||||
|
qa_outputs,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Forward pass through the model
|
/// Forward pass through the model
|
||||||
@ -617,45 +752,58 @@ impl RobertaForQuestionAnswering {
|
|||||||
/// # Example
|
/// # Example
|
||||||
///
|
///
|
||||||
/// ```no_run
|
/// ```no_run
|
||||||
///# use rust_bert::bert::BertConfig;
|
/// # use rust_bert::bert::BertConfig;
|
||||||
///# use tch::{nn, Device, Tensor, no_grad};
|
/// # use tch::{nn, Device, Tensor, no_grad};
|
||||||
///# use rust_bert::Config;
|
/// # use rust_bert::Config;
|
||||||
///# use std::path::Path;
|
/// # use std::path::Path;
|
||||||
///# use tch::kind::Kind::Int64;
|
/// # use tch::kind::Kind::Int64;
|
||||||
/// use rust_bert::roberta::RobertaForQuestionAnswering;
|
/// use rust_bert::roberta::RobertaForQuestionAnswering;
|
||||||
///# let config_path = Path::new("path/to/config.json");
|
/// # let config_path = Path::new("path/to/config.json");
|
||||||
///# let vocab_path = Path::new("path/to/vocab.txt");
|
/// # let vocab_path = Path::new("path/to/vocab.txt");
|
||||||
///# let device = Device::Cpu;
|
/// # let device = Device::Cpu;
|
||||||
///# let vs = nn::VarStore::new(device);
|
/// # let vs = nn::VarStore::new(device);
|
||||||
///# let config = BertConfig::from_file(config_path);
|
/// # let config = BertConfig::from_file(config_path);
|
||||||
///# let roberta_model = RobertaForQuestionAnswering::new(&vs.root(), &config);
|
/// # let roberta_model = RobertaForQuestionAnswering::new(&vs.root(), &config);
|
||||||
/// let (batch_size, sequence_length) = (64, 128);
|
/// let (batch_size, sequence_length) = (64, 128);
|
||||||
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
|
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
|
||||||
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
||||||
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
||||||
/// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true);
|
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
|
||||||
///
|
/// .expand(&[batch_size, sequence_length], true);
|
||||||
/// let (start_scores, end_scores, all_hidden_states, all_attentions) = no_grad(|| {
|
|
||||||
/// roberta_model
|
|
||||||
/// .forward_t(Some(input_tensor),
|
|
||||||
/// Some(mask),
|
|
||||||
/// Some(token_type_ids),
|
|
||||||
/// Some(position_ids),
|
|
||||||
/// None,
|
|
||||||
/// false)
|
|
||||||
/// });
|
|
||||||
///
|
///
|
||||||
|
/// let (start_scores, end_scores, all_hidden_states, all_attentions) = no_grad(|| {
|
||||||
|
/// roberta_model.forward_t(
|
||||||
|
/// Some(input_tensor),
|
||||||
|
/// Some(mask),
|
||||||
|
/// Some(token_type_ids),
|
||||||
|
/// Some(position_ids),
|
||||||
|
/// None,
|
||||||
|
/// false,
|
||||||
|
/// )
|
||||||
|
/// });
|
||||||
/// ```
|
/// ```
|
||||||
///
|
pub fn forward_t(
|
||||||
pub fn forward_t(&self,
|
&self,
|
||||||
input_ids: Option<Tensor>,
|
input_ids: Option<Tensor>,
|
||||||
mask: Option<Tensor>,
|
mask: Option<Tensor>,
|
||||||
token_type_ids: Option<Tensor>,
|
token_type_ids: Option<Tensor>,
|
||||||
position_ids: Option<Tensor>,
|
position_ids: Option<Tensor>,
|
||||||
input_embeds: Option<Tensor>,
|
input_embeds: Option<Tensor>,
|
||||||
train: bool) -> (Tensor, Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
|
train: bool,
|
||||||
let (hidden_state, _, all_hidden_states, all_attentions) = self.roberta.forward_t(input_ids, mask, token_type_ids, position_ids,
|
) -> (Tensor, Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
|
||||||
input_embeds, &None, &None, train).unwrap();
|
let (hidden_state, _, all_hidden_states, all_attentions) = self
|
||||||
|
.roberta
|
||||||
|
.forward_t(
|
||||||
|
input_ids,
|
||||||
|
mask,
|
||||||
|
token_type_ids,
|
||||||
|
position_ids,
|
||||||
|
input_embeds,
|
||||||
|
&None,
|
||||||
|
&None,
|
||||||
|
train,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
let sequence_output = hidden_state.apply(&self.qa_outputs);
|
let sequence_output = hidden_state.apply(&self.qa_outputs);
|
||||||
let logits = sequence_output.split(1, -1);
|
let logits = sequence_output.split(1, -1);
|
||||||
@ -665,4 +813,4 @@ impl RobertaForQuestionAnswering {
|
|||||||
|
|
||||||
(start_logits, end_logits, all_hidden_states, all_attentions)
|
(start_logits, end_logits, all_hidden_states, all_attentions)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
351
tests/albert.rs
351
tests/albert.rs
@ -1,67 +1,77 @@
|
|||||||
extern crate failure;
|
|
||||||
extern crate dirs;
|
extern crate dirs;
|
||||||
|
extern crate failure;
|
||||||
|
|
||||||
use tch::{Device, nn, Tensor, no_grad};
|
use rust_bert::albert::{
|
||||||
use rust_tokenizers::{TruncationStrategy, Tokenizer, Vocab, AlbertTokenizer};
|
AlbertConfig, AlbertConfigResources, AlbertForMaskedLM, AlbertForMultipleChoice,
|
||||||
|
AlbertForQuestionAnswering, AlbertForSequenceClassification, AlbertForTokenClassification,
|
||||||
|
AlbertModelResources, AlbertVocabResources,
|
||||||
|
};
|
||||||
|
use rust_bert::resources::{download_resource, RemoteResource, Resource};
|
||||||
use rust_bert::Config;
|
use rust_bert::Config;
|
||||||
use rust_bert::resources::{Resource, RemoteResource, download_resource};
|
use rust_tokenizers::{AlbertTokenizer, Tokenizer, TruncationStrategy, Vocab};
|
||||||
use rust_bert::albert::{AlbertConfigResources, AlbertVocabResources, AlbertModelResources, AlbertConfig, AlbertForMaskedLM, AlbertForSequenceClassification, AlbertForMultipleChoice, AlbertForTokenClassification, AlbertForQuestionAnswering};
|
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
use tch::{nn, no_grad, Device, Tensor};
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn albert_masked_lm() -> failure::Fallible<()> {
|
fn albert_masked_lm() -> failure::Fallible<()> {
|
||||||
// Resources paths
|
// Resources paths
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(AlbertConfigResources::ALBERT_BASE_V2));
|
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(AlbertVocabResources::ALBERT_BASE_V2));
|
AlbertConfigResources::ALBERT_BASE_V2,
|
||||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(AlbertModelResources::ALBERT_BASE_V2));
|
));
|
||||||
|
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
|
AlbertVocabResources::ALBERT_BASE_V2,
|
||||||
|
));
|
||||||
|
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
|
AlbertModelResources::ALBERT_BASE_V2,
|
||||||
|
));
|
||||||
let config_path = download_resource(&config_resource)?;
|
let config_path = download_resource(&config_resource)?;
|
||||||
let vocab_path = download_resource(&vocab_resource)?;
|
let vocab_path = download_resource(&vocab_resource)?;
|
||||||
let weights_path = download_resource(&weights_resource)?;
|
let weights_path = download_resource(&weights_resource)?;
|
||||||
|
|
||||||
// Set-up masked LM model
|
// Set-up masked LM model
|
||||||
let device = Device::Cpu;
|
let device = Device::Cpu;
|
||||||
let mut vs = nn::VarStore::new(device);
|
let mut vs = nn::VarStore::new(device);
|
||||||
let tokenizer: AlbertTokenizer = AlbertTokenizer::from_file(vocab_path.to_str().unwrap(), true, false);
|
let tokenizer: AlbertTokenizer =
|
||||||
|
AlbertTokenizer::from_file(vocab_path.to_str().unwrap(), true, false);
|
||||||
let config = AlbertConfig::from_file(config_path);
|
let config = AlbertConfig::from_file(config_path);
|
||||||
let albert_model = AlbertForMaskedLM::new(&vs.root(), &config);
|
let albert_model = AlbertForMaskedLM::new(&vs.root(), &config);
|
||||||
vs.load(weights_path)?;
|
vs.load(weights_path)?;
|
||||||
|
|
||||||
// Define input
|
// Define input
|
||||||
let input = ["Looks like one [MASK] is missing", "It\'s like comparing [MASK] to apples"];
|
let input = [
|
||||||
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
"Looks like one [MASK] is missing",
|
||||||
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
|
"It\'s like comparing [MASK] to apples",
|
||||||
let tokenized_input = tokenized_input.
|
];
|
||||||
iter().
|
let tokenized_input =
|
||||||
map(|input| input.token_ids.clone()).
|
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
||||||
map(|mut input| {
|
let max_len = tokenized_input
|
||||||
|
.iter()
|
||||||
|
.map(|input| input.token_ids.len())
|
||||||
|
.max()
|
||||||
|
.unwrap();
|
||||||
|
let tokenized_input = tokenized_input
|
||||||
|
.iter()
|
||||||
|
.map(|input| input.token_ids.clone())
|
||||||
|
.map(|mut input| {
|
||||||
input.extend(vec![0; max_len - input.len()]);
|
input.extend(vec![0; max_len - input.len()]);
|
||||||
input
|
input
|
||||||
}).
|
})
|
||||||
map(|input|
|
.map(|input| Tensor::of_slice(&(input)))
|
||||||
Tensor::of_slice(&(input))).
|
.collect::<Vec<_>>();
|
||||||
collect::<Vec<_>>();
|
|
||||||
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
|
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
|
||||||
|
|
||||||
// Forward pass
|
// Forward pass
|
||||||
let (output, _, _) = no_grad(|| {
|
let (output, _, _) =
|
||||||
albert_model
|
no_grad(|| albert_model.forward_t(Some(input_tensor), None, None, None, None, false));
|
||||||
.forward_t(Some(input_tensor),
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
false)
|
|
||||||
});
|
|
||||||
|
|
||||||
// Print masked tokens
|
// Print masked tokens
|
||||||
let index_1 = output.get(0).get(4).argmax(0, false);
|
let index_1 = output.get(0).get(4).argmax(0, false);
|
||||||
let index_2 = output.get(1).get(6).argmax(0, false);
|
let index_2 = output.get(1).get(6).argmax(0, false);
|
||||||
let word_1 = tokenizer.vocab().id_to_token(&index_1.int64_value(&[]));
|
let word_1 = tokenizer.vocab().id_to_token(&index_1.int64_value(&[]));
|
||||||
let word_2 = tokenizer.vocab().id_to_token(&index_2.int64_value(&[]));
|
let word_2 = tokenizer.vocab().id_to_token(&index_2.int64_value(&[]));
|
||||||
|
|
||||||
assert_eq!("▁them", word_1); // Outputs "_them" : "Looks like one [them] is missing (? this is identical with the original implementation)"
|
assert_eq!("▁them", word_1); // Outputs "_them" : "Looks like one [them] is missing (? this is identical with the original implementation)"
|
||||||
assert_eq!("▁grapes", word_2);// Outputs "grapes" : "It\'s like comparing [grapes] to apples"
|
assert_eq!("▁grapes", word_2); // Outputs "grapes" : "It\'s like comparing [grapes] to apples"
|
||||||
assert!((output.double_value(&[0, 0, 0]) - 4.6143).abs() < 1e-4);
|
assert!((output.double_value(&[0, 0, 0]) - 4.6143).abs() < 1e-4);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@ -69,15 +79,20 @@ fn albert_masked_lm() -> failure::Fallible<()> {
|
|||||||
#[test]
|
#[test]
|
||||||
fn albert_for_sequence_classification() -> failure::Fallible<()> {
|
fn albert_for_sequence_classification() -> failure::Fallible<()> {
|
||||||
// Resources paths
|
// Resources paths
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(AlbertConfigResources::ALBERT_BASE_V2));
|
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(AlbertVocabResources::ALBERT_BASE_V2));
|
AlbertConfigResources::ALBERT_BASE_V2,
|
||||||
|
));
|
||||||
|
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
|
AlbertVocabResources::ALBERT_BASE_V2,
|
||||||
|
));
|
||||||
let config_path = download_resource(&config_resource)?;
|
let config_path = download_resource(&config_resource)?;
|
||||||
let vocab_path = download_resource(&vocab_resource)?;
|
let vocab_path = download_resource(&vocab_resource)?;
|
||||||
|
|
||||||
// Set-up model
|
// Set-up model
|
||||||
let device = Device::Cpu;
|
let device = Device::Cpu;
|
||||||
let vs = nn::VarStore::new(device);
|
let vs = nn::VarStore::new(device);
|
||||||
let tokenizer: AlbertTokenizer = AlbertTokenizer::from_file(vocab_path.to_str().unwrap(), true, false);
|
let tokenizer: AlbertTokenizer =
|
||||||
|
AlbertTokenizer::from_file(vocab_path.to_str().unwrap(), true, false);
|
||||||
let mut config = AlbertConfig::from_file(config_path);
|
let mut config = AlbertConfig::from_file(config_path);
|
||||||
let mut dummy_label_mapping = HashMap::new();
|
let mut dummy_label_mapping = HashMap::new();
|
||||||
dummy_label_mapping.insert(0, String::from("Positive"));
|
dummy_label_mapping.insert(0, String::from("Positive"));
|
||||||
@ -88,37 +103,42 @@ fn albert_for_sequence_classification() -> failure::Fallible<()> {
|
|||||||
config.output_hidden_states = Some(true);
|
config.output_hidden_states = Some(true);
|
||||||
let albert_model = AlbertForSequenceClassification::new(&vs.root(), &config);
|
let albert_model = AlbertForSequenceClassification::new(&vs.root(), &config);
|
||||||
|
|
||||||
|
// Define input
|
||||||
// Define input
|
let input = [
|
||||||
let input = ["Looks like one thing is missing", "It\'s like comparing oranges to apples"];
|
"Looks like one thing is missing",
|
||||||
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
"It\'s like comparing oranges to apples",
|
||||||
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
|
];
|
||||||
let tokenized_input = tokenized_input.
|
let tokenized_input =
|
||||||
iter().
|
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
||||||
map(|input| input.token_ids.clone()).
|
let max_len = tokenized_input
|
||||||
map(|mut input| {
|
.iter()
|
||||||
|
.map(|input| input.token_ids.len())
|
||||||
|
.max()
|
||||||
|
.unwrap();
|
||||||
|
let tokenized_input = tokenized_input
|
||||||
|
.iter()
|
||||||
|
.map(|input| input.token_ids.clone())
|
||||||
|
.map(|mut input| {
|
||||||
input.extend(vec![0; max_len - input.len()]);
|
input.extend(vec![0; max_len - input.len()]);
|
||||||
input
|
input
|
||||||
}).
|
})
|
||||||
map(|input|
|
.map(|input| Tensor::of_slice(&(input)))
|
||||||
Tensor::of_slice(&(input))).
|
.collect::<Vec<_>>();
|
||||||
collect::<Vec<_>>();
|
|
||||||
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
|
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
|
||||||
|
|
||||||
// Forward pass
|
// Forward pass
|
||||||
let (output, all_hidden_states, all_attentions) = no_grad(|| {
|
let (output, all_hidden_states, all_attentions) =
|
||||||
albert_model
|
no_grad(|| albert_model.forward_t(Some(input_tensor), None, None, None, None, false));
|
||||||
.forward_t(Some(input_tensor),
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
false)
|
|
||||||
});
|
|
||||||
|
|
||||||
assert_eq!(output.size(), &[2, 3]);
|
assert_eq!(output.size(), &[2, 3]);
|
||||||
assert_eq!(config.num_hidden_layers as usize, all_hidden_states.unwrap().len());
|
assert_eq!(
|
||||||
assert_eq!(config.num_hidden_layers as usize, all_attentions.unwrap().len());
|
config.num_hidden_layers as usize,
|
||||||
|
all_hidden_states.unwrap().len()
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
config.num_hidden_layers as usize,
|
||||||
|
all_attentions.unwrap().len()
|
||||||
|
);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@ -126,50 +146,66 @@ fn albert_for_sequence_classification() -> failure::Fallible<()> {
|
|||||||
#[test]
|
#[test]
|
||||||
fn albert_for_multiple_choice() -> failure::Fallible<()> {
|
fn albert_for_multiple_choice() -> failure::Fallible<()> {
|
||||||
// Resources paths
|
// Resources paths
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(AlbertConfigResources::ALBERT_BASE_V2));
|
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(AlbertVocabResources::ALBERT_BASE_V2));
|
AlbertConfigResources::ALBERT_BASE_V2,
|
||||||
|
));
|
||||||
|
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
|
AlbertVocabResources::ALBERT_BASE_V2,
|
||||||
|
));
|
||||||
let config_path = download_resource(&config_resource)?;
|
let config_path = download_resource(&config_resource)?;
|
||||||
let vocab_path = download_resource(&vocab_resource)?;
|
let vocab_path = download_resource(&vocab_resource)?;
|
||||||
|
|
||||||
// Set-up model
|
// Set-up model
|
||||||
let device = Device::Cpu;
|
let device = Device::Cpu;
|
||||||
let vs = nn::VarStore::new(device);
|
let vs = nn::VarStore::new(device);
|
||||||
let tokenizer: AlbertTokenizer = AlbertTokenizer::from_file(vocab_path.to_str().unwrap(), true, false);
|
let tokenizer: AlbertTokenizer =
|
||||||
|
AlbertTokenizer::from_file(vocab_path.to_str().unwrap(), true, false);
|
||||||
let mut config = AlbertConfig::from_file(config_path);
|
let mut config = AlbertConfig::from_file(config_path);
|
||||||
config.output_attentions = Some(true);
|
config.output_attentions = Some(true);
|
||||||
config.output_hidden_states = Some(true);
|
config.output_hidden_states = Some(true);
|
||||||
let albert_model = AlbertForMultipleChoice::new(&vs.root(), &config);
|
let albert_model = AlbertForMultipleChoice::new(&vs.root(), &config);
|
||||||
|
|
||||||
// Define input
|
// Define input
|
||||||
let input = ["Looks like one thing is missing", "It\'s like comparing oranges to apples"];
|
let input = [
|
||||||
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
"Looks like one thing is missing",
|
||||||
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
|
"It\'s like comparing oranges to apples",
|
||||||
let tokenized_input = tokenized_input.
|
];
|
||||||
iter().
|
let tokenized_input =
|
||||||
map(|input| input.token_ids.clone()).
|
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
||||||
map(|mut input| {
|
let max_len = tokenized_input
|
||||||
|
.iter()
|
||||||
|
.map(|input| input.token_ids.len())
|
||||||
|
.max()
|
||||||
|
.unwrap();
|
||||||
|
let tokenized_input = tokenized_input
|
||||||
|
.iter()
|
||||||
|
.map(|input| input.token_ids.clone())
|
||||||
|
.map(|mut input| {
|
||||||
input.extend(vec![0; max_len - input.len()]);
|
input.extend(vec![0; max_len - input.len()]);
|
||||||
input
|
input
|
||||||
}).
|
})
|
||||||
map(|input|
|
.map(|input| Tensor::of_slice(&(input)))
|
||||||
Tensor::of_slice(&(input))).
|
.collect::<Vec<_>>();
|
||||||
collect::<Vec<_>>();
|
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0)
|
||||||
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device).unsqueeze(0);
|
.to(device)
|
||||||
|
.unsqueeze(0);
|
||||||
|
|
||||||
// Forward pass
|
// Forward pass
|
||||||
let (output, all_hidden_states, all_attentions) = no_grad(|| {
|
let (output, all_hidden_states, all_attentions) = no_grad(|| {
|
||||||
albert_model
|
albert_model
|
||||||
.forward_t(Some(input_tensor),
|
.forward_t(Some(input_tensor), None, None, None, None, false)
|
||||||
None,
|
.unwrap()
|
||||||
None,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
false).unwrap()
|
|
||||||
});
|
});
|
||||||
|
|
||||||
assert_eq!(output.size(), &[1, 2]);
|
assert_eq!(output.size(), &[1, 2]);
|
||||||
assert_eq!(config.num_hidden_layers as usize, all_hidden_states.unwrap().len());
|
assert_eq!(
|
||||||
assert_eq!(config.num_hidden_layers as usize, all_attentions.unwrap().len());
|
config.num_hidden_layers as usize,
|
||||||
|
all_hidden_states.unwrap().len()
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
config.num_hidden_layers as usize,
|
||||||
|
all_attentions.unwrap().len()
|
||||||
|
);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@ -177,15 +213,20 @@ fn albert_for_multiple_choice() -> failure::Fallible<()> {
|
|||||||
#[test]
|
#[test]
|
||||||
fn albert_for_token_classification() -> failure::Fallible<()> {
|
fn albert_for_token_classification() -> failure::Fallible<()> {
|
||||||
// Resources paths
|
// Resources paths
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(AlbertConfigResources::ALBERT_BASE_V2));
|
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(AlbertVocabResources::ALBERT_BASE_V2));
|
AlbertConfigResources::ALBERT_BASE_V2,
|
||||||
|
));
|
||||||
|
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
|
AlbertVocabResources::ALBERT_BASE_V2,
|
||||||
|
));
|
||||||
let config_path = download_resource(&config_resource)?;
|
let config_path = download_resource(&config_resource)?;
|
||||||
let vocab_path = download_resource(&vocab_resource)?;
|
let vocab_path = download_resource(&vocab_resource)?;
|
||||||
|
|
||||||
// Set-up model
|
// Set-up model
|
||||||
let device = Device::Cpu;
|
let device = Device::Cpu;
|
||||||
let vs = nn::VarStore::new(device);
|
let vs = nn::VarStore::new(device);
|
||||||
let tokenizer: AlbertTokenizer = AlbertTokenizer::from_file(vocab_path.to_str().unwrap(), true, false);
|
let tokenizer: AlbertTokenizer =
|
||||||
|
AlbertTokenizer::from_file(vocab_path.to_str().unwrap(), true, false);
|
||||||
let mut config = AlbertConfig::from_file(config_path);
|
let mut config = AlbertConfig::from_file(config_path);
|
||||||
let mut dummy_label_mapping = HashMap::new();
|
let mut dummy_label_mapping = HashMap::new();
|
||||||
dummy_label_mapping.insert(0, String::from("O"));
|
dummy_label_mapping.insert(0, String::from("O"));
|
||||||
@ -197,37 +238,42 @@ fn albert_for_token_classification() -> failure::Fallible<()> {
|
|||||||
config.output_hidden_states = Some(true);
|
config.output_hidden_states = Some(true);
|
||||||
let bert_model = AlbertForTokenClassification::new(&vs.root(), &config);
|
let bert_model = AlbertForTokenClassification::new(&vs.root(), &config);
|
||||||
|
|
||||||
|
// Define input
|
||||||
// Define input
|
let input = [
|
||||||
let input = ["Looks like one thing is missing", "It\'s like comparing oranges to apples"];
|
"Looks like one thing is missing",
|
||||||
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
"It\'s like comparing oranges to apples",
|
||||||
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
|
];
|
||||||
let tokenized_input = tokenized_input.
|
let tokenized_input =
|
||||||
iter().
|
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
||||||
map(|input| input.token_ids.clone()).
|
let max_len = tokenized_input
|
||||||
map(|mut input| {
|
.iter()
|
||||||
|
.map(|input| input.token_ids.len())
|
||||||
|
.max()
|
||||||
|
.unwrap();
|
||||||
|
let tokenized_input = tokenized_input
|
||||||
|
.iter()
|
||||||
|
.map(|input| input.token_ids.clone())
|
||||||
|
.map(|mut input| {
|
||||||
input.extend(vec![0; max_len - input.len()]);
|
input.extend(vec![0; max_len - input.len()]);
|
||||||
input
|
input
|
||||||
}).
|
})
|
||||||
map(|input|
|
.map(|input| Tensor::of_slice(&(input)))
|
||||||
Tensor::of_slice(&(input))).
|
.collect::<Vec<_>>();
|
||||||
collect::<Vec<_>>();
|
|
||||||
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
|
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
|
||||||
|
|
||||||
// Forward pass
|
// Forward pass
|
||||||
let (output, all_hidden_states, all_attentions) = no_grad(|| {
|
let (output, all_hidden_states, all_attentions) =
|
||||||
bert_model
|
no_grad(|| bert_model.forward_t(Some(input_tensor), None, None, None, None, false));
|
||||||
.forward_t(Some(input_tensor),
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
false)
|
|
||||||
});
|
|
||||||
|
|
||||||
assert_eq!(output.size(), &[2, 12, 4]);
|
assert_eq!(output.size(), &[2, 12, 4]);
|
||||||
assert_eq!(config.num_hidden_layers as usize, all_hidden_states.unwrap().len());
|
assert_eq!(
|
||||||
assert_eq!(config.num_hidden_layers as usize, all_attentions.unwrap().len());
|
config.num_hidden_layers as usize,
|
||||||
|
all_hidden_states.unwrap().len()
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
config.num_hidden_layers as usize,
|
||||||
|
all_attentions.unwrap().len()
|
||||||
|
);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@ -235,51 +281,62 @@ fn albert_for_token_classification() -> failure::Fallible<()> {
|
|||||||
#[test]
|
#[test]
|
||||||
fn albert_for_question_answering() -> failure::Fallible<()> {
|
fn albert_for_question_answering() -> failure::Fallible<()> {
|
||||||
// Resources paths
|
// Resources paths
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(AlbertConfigResources::ALBERT_BASE_V2));
|
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(AlbertVocabResources::ALBERT_BASE_V2));
|
AlbertConfigResources::ALBERT_BASE_V2,
|
||||||
|
));
|
||||||
|
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
|
AlbertVocabResources::ALBERT_BASE_V2,
|
||||||
|
));
|
||||||
let config_path = download_resource(&config_resource)?;
|
let config_path = download_resource(&config_resource)?;
|
||||||
let vocab_path = download_resource(&vocab_resource)?;
|
let vocab_path = download_resource(&vocab_resource)?;
|
||||||
|
|
||||||
// Set-up model
|
// Set-up model
|
||||||
let device = Device::Cpu;
|
let device = Device::Cpu;
|
||||||
let vs = nn::VarStore::new(device);
|
let vs = nn::VarStore::new(device);
|
||||||
let tokenizer: AlbertTokenizer = AlbertTokenizer::from_file(vocab_path.to_str().unwrap(), true, false);
|
let tokenizer: AlbertTokenizer =
|
||||||
|
AlbertTokenizer::from_file(vocab_path.to_str().unwrap(), true, false);
|
||||||
let mut config = AlbertConfig::from_file(config_path);
|
let mut config = AlbertConfig::from_file(config_path);
|
||||||
config.output_attentions = Some(true);
|
config.output_attentions = Some(true);
|
||||||
config.output_hidden_states = Some(true);
|
config.output_hidden_states = Some(true);
|
||||||
let albert_model = AlbertForQuestionAnswering::new(&vs.root(), &config);
|
let albert_model = AlbertForQuestionAnswering::new(&vs.root(), &config);
|
||||||
|
|
||||||
// Define input
|
// Define input
|
||||||
let input = ["Looks like one thing is missing", "It\'s like comparing oranges to apples"];
|
let input = [
|
||||||
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
"Looks like one thing is missing",
|
||||||
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
|
"It\'s like comparing oranges to apples",
|
||||||
let tokenized_input = tokenized_input.
|
];
|
||||||
iter().
|
let tokenized_input =
|
||||||
map(|input| input.token_ids.clone()).
|
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
||||||
map(|mut input| {
|
let max_len = tokenized_input
|
||||||
|
.iter()
|
||||||
|
.map(|input| input.token_ids.len())
|
||||||
|
.max()
|
||||||
|
.unwrap();
|
||||||
|
let tokenized_input = tokenized_input
|
||||||
|
.iter()
|
||||||
|
.map(|input| input.token_ids.clone())
|
||||||
|
.map(|mut input| {
|
||||||
input.extend(vec![0; max_len - input.len()]);
|
input.extend(vec![0; max_len - input.len()]);
|
||||||
input
|
input
|
||||||
}).
|
})
|
||||||
map(|input|
|
.map(|input| Tensor::of_slice(&(input)))
|
||||||
Tensor::of_slice(&(input))).
|
.collect::<Vec<_>>();
|
||||||
collect::<Vec<_>>();
|
|
||||||
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
|
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
|
||||||
|
|
||||||
// Forward pass
|
// Forward pass
|
||||||
let (start_scores, end_scores, all_hidden_states, all_attentions) = no_grad(|| {
|
let (start_scores, end_scores, all_hidden_states, all_attentions) =
|
||||||
albert_model
|
no_grad(|| albert_model.forward_t(Some(input_tensor), None, None, None, None, false));
|
||||||
.forward_t(Some(input_tensor),
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
false)
|
|
||||||
});
|
|
||||||
|
|
||||||
assert_eq!(start_scores.size(), &[2, 12]);
|
assert_eq!(start_scores.size(), &[2, 12]);
|
||||||
assert_eq!(end_scores.size(), &[2, 12]);
|
assert_eq!(end_scores.size(), &[2, 12]);
|
||||||
assert_eq!(config.num_hidden_layers as usize, all_hidden_states.unwrap().len());
|
assert_eq!(
|
||||||
assert_eq!(config.num_hidden_layers as usize, all_attentions.unwrap().len());
|
config.num_hidden_layers as usize,
|
||||||
|
all_hidden_states.unwrap().len()
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
config.num_hidden_layers as usize,
|
||||||
|
all_attentions.unwrap().len()
|
||||||
|
);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -1,56 +1,65 @@
|
|||||||
use tch::{Device, nn, Tensor};
|
use rust_bert::bart::{
|
||||||
use rust_tokenizers::{TruncationStrategy, Tokenizer, RobertaTokenizer};
|
BartConfig, BartConfigResources, BartMergesResources, BartModel, BartModelResources,
|
||||||
use rust_bert::Config;
|
BartVocabResources,
|
||||||
use rust_bert::bart::{BartConfig, BartConfigResources, BartVocabResources, BartModelResources, BartMergesResources, BartModel};
|
};
|
||||||
use rust_bert::pipelines::summarization::{SummarizationConfig, SummarizationModel};
|
use rust_bert::pipelines::summarization::{SummarizationConfig, SummarizationModel};
|
||||||
use rust_bert::resources::{Resource, RemoteResource, download_resource};
|
use rust_bert::resources::{download_resource, RemoteResource, Resource};
|
||||||
|
use rust_bert::Config;
|
||||||
|
use rust_tokenizers::{RobertaTokenizer, Tokenizer, TruncationStrategy};
|
||||||
|
use tch::{nn, Device, Tensor};
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
#[cfg_attr(not(feature = "all-tests"), ignore)]
|
#[cfg_attr(not(feature = "all-tests"), ignore)]
|
||||||
fn bart_lm_model() -> failure::Fallible<()> {
|
fn bart_lm_model() -> failure::Fallible<()> {
|
||||||
// Resources paths
|
// Resources paths
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(BartConfigResources::BART));
|
let config_resource =
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(BartVocabResources::BART));
|
Resource::Remote(RemoteResource::from_pretrained(BartConfigResources::BART));
|
||||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(BartMergesResources::BART));
|
let vocab_resource =
|
||||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(BartModelResources::BART));
|
Resource::Remote(RemoteResource::from_pretrained(BartVocabResources::BART));
|
||||||
|
let merges_resource =
|
||||||
|
Resource::Remote(RemoteResource::from_pretrained(BartMergesResources::BART));
|
||||||
|
let weights_resource =
|
||||||
|
Resource::Remote(RemoteResource::from_pretrained(BartModelResources::BART));
|
||||||
let config_path = download_resource(&config_resource)?;
|
let config_path = download_resource(&config_resource)?;
|
||||||
let vocab_path = download_resource(&vocab_resource)?;
|
let vocab_path = download_resource(&vocab_resource)?;
|
||||||
let merges_path = download_resource(&merges_resource)?;
|
let merges_path = download_resource(&merges_resource)?;
|
||||||
let weights_path = download_resource(&weights_resource)?;
|
let weights_path = download_resource(&weights_resource)?;
|
||||||
|
|
||||||
// Set-up masked LM model
|
// Set-up masked LM model
|
||||||
let device = Device::Cpu;
|
let device = Device::Cpu;
|
||||||
let mut vs = nn::VarStore::new(device);
|
let mut vs = nn::VarStore::new(device);
|
||||||
let tokenizer: RobertaTokenizer = RobertaTokenizer::from_file(vocab_path.to_str().unwrap(), merges_path.to_str().unwrap(), false);
|
let tokenizer: RobertaTokenizer = RobertaTokenizer::from_file(
|
||||||
|
vocab_path.to_str().unwrap(),
|
||||||
|
merges_path.to_str().unwrap(),
|
||||||
|
false,
|
||||||
|
);
|
||||||
let config = BartConfig::from_file(config_path);
|
let config = BartConfig::from_file(config_path);
|
||||||
let bart_model = BartModel::new(&vs.root(), &config, false);
|
let bart_model = BartModel::new(&vs.root(), &config, false);
|
||||||
vs.load(weights_path)?;
|
vs.load(weights_path)?;
|
||||||
|
|
||||||
// Define input
|
// Define input
|
||||||
let input = ["One two three four"];
|
let input = ["One two three four"];
|
||||||
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
let tokenized_input =
|
||||||
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
|
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
||||||
let tokenized_input = tokenized_input.
|
let max_len = tokenized_input
|
||||||
iter().
|
.iter()
|
||||||
map(|input| input.token_ids.clone()).
|
.map(|input| input.token_ids.len())
|
||||||
map(|mut input| {
|
.max()
|
||||||
|
.unwrap();
|
||||||
|
let tokenized_input = tokenized_input
|
||||||
|
.iter()
|
||||||
|
.map(|input| input.token_ids.clone())
|
||||||
|
.map(|mut input| {
|
||||||
input.extend(vec![0; max_len - input.len()]);
|
input.extend(vec![0; max_len - input.len()]);
|
||||||
input
|
input
|
||||||
}).
|
})
|
||||||
map(|input|
|
.map(|input| Tensor::of_slice(&(input)))
|
||||||
Tensor::of_slice(&(input))).
|
.collect::<Vec<_>>();
|
||||||
collect::<Vec<_>>();
|
|
||||||
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
|
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
|
||||||
|
|
||||||
// Forward pass
|
// Forward pass
|
||||||
let (output, encoder_outputs, _, _, _, _, _) = bart_model.forward_t(
|
let (output, encoder_outputs, _, _, _, _, _) =
|
||||||
Some(&input_tensor),
|
bart_model.forward_t(Some(&input_tensor), None, None, None, None, None, false);
|
||||||
None,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
false);
|
|
||||||
|
|
||||||
assert_eq!(output.size(), vec!(1, 6, 1024));
|
assert_eq!(output.size(), vec!(1, 6, 1024));
|
||||||
assert_eq!(encoder_outputs.size(), vec!(1, 6, 1024));
|
assert_eq!(encoder_outputs.size(), vec!(1, 6, 1024));
|
||||||
@ -58,12 +67,10 @@ fn bart_lm_model() -> failure::Fallible<()> {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
#[cfg_attr(not(feature = "all-tests"), ignore)]
|
#[cfg_attr(not(feature = "all-tests"), ignore)]
|
||||||
fn bart_summarization_greedy() -> failure::Fallible<()> {
|
fn bart_summarization_greedy() -> failure::Fallible<()> {
|
||||||
|
// Set-up masked LM model
|
||||||
// Set-up masked LM model
|
|
||||||
let summarization_config = SummarizationConfig {
|
let summarization_config = SummarizationConfig {
|
||||||
num_beams: 1,
|
num_beams: 1,
|
||||||
device: Device::Cpu,
|
device: Device::Cpu,
|
||||||
@ -93,7 +100,7 @@ on K2-18b lasts 33 Earth days. According to The Guardian, astronomers were optim
|
|||||||
telescope — scheduled for launch in 2021 — and the European Space Agency's 2028 ARIEL program, could reveal more \
|
telescope — scheduled for launch in 2021 — and the European Space Agency's 2028 ARIEL program, could reveal more \
|
||||||
about exoplanets like K2-18b."];
|
about exoplanets like K2-18b."];
|
||||||
|
|
||||||
// Credits: WikiNews, CC BY 2.5 license (https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b)
|
// Credits: WikiNews, CC BY 2.5 license (https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b)
|
||||||
let output = model.summarize(&input);
|
let output = model.summarize(&input);
|
||||||
|
|
||||||
assert_eq!(output.len(), 1);
|
assert_eq!(output.len(), 1);
|
||||||
@ -107,8 +114,7 @@ about exoplanets like K2-18b."];
|
|||||||
#[test]
|
#[test]
|
||||||
#[cfg_attr(not(feature = "all-tests"), ignore)]
|
#[cfg_attr(not(feature = "all-tests"), ignore)]
|
||||||
fn bart_summarization_beam_search() -> failure::Fallible<()> {
|
fn bart_summarization_beam_search() -> failure::Fallible<()> {
|
||||||
|
// Set-up masked LM model
|
||||||
// Set-up masked LM model
|
|
||||||
let summarization_config = SummarizationConfig {
|
let summarization_config = SummarizationConfig {
|
||||||
num_beams: 3,
|
num_beams: 3,
|
||||||
device: Device::Cpu,
|
device: Device::Cpu,
|
||||||
@ -138,7 +144,7 @@ on K2-18b lasts 33 Earth days. According to The Guardian, astronomers were optim
|
|||||||
telescope — scheduled for launch in 2021 — and the European Space Agency's 2028 ARIEL program, could reveal more \
|
telescope — scheduled for launch in 2021 — and the European Space Agency's 2028 ARIEL program, could reveal more \
|
||||||
about exoplanets like K2-18b."];
|
about exoplanets like K2-18b."];
|
||||||
|
|
||||||
// Credits: WikiNews, CC BY 2.5 license (https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b)
|
// Credits: WikiNews, CC BY 2.5 license (https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b)
|
||||||
let output = model.summarize(&input);
|
let output = model.summarize(&input);
|
||||||
|
|
||||||
assert_eq!(output.len(), 1);
|
assert_eq!(output.len(), 1);
|
||||||
@ -148,4 +154,4 @@ about exoplanets like K2-18b."];
|
|||||||
star as the planet passed between it and Earth.");
|
star as the planet passed between it and Earth.");
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
356
tests/bert.rs
356
tests/bert.rs
@ -1,27 +1,32 @@
|
|||||||
extern crate failure;
|
|
||||||
extern crate dirs;
|
extern crate dirs;
|
||||||
|
extern crate failure;
|
||||||
|
|
||||||
use tch::{Device, nn, Tensor, no_grad};
|
use rust_bert::bert::{
|
||||||
use rust_tokenizers::{BertTokenizer, TruncationStrategy, Tokenizer, Vocab};
|
BertConfig, BertConfigResources, BertForMaskedLM, BertForMultipleChoice,
|
||||||
use rust_bert::Config;
|
BertForQuestionAnswering, BertForSequenceClassification, BertForTokenClassification,
|
||||||
use rust_bert::bert::{BertConfig, BertForMaskedLM, BertForSequenceClassification, BertForMultipleChoice, BertForTokenClassification, BertForQuestionAnswering,
|
BertModelResources, BertVocabResources,
|
||||||
BertConfigResources, BertVocabResources, BertModelResources};
|
};
|
||||||
use rust_bert::pipelines::ner::NERModel;
|
use rust_bert::pipelines::ner::NERModel;
|
||||||
use rust_bert::resources::{Resource, RemoteResource, download_resource};
|
use rust_bert::resources::{download_resource, RemoteResource, Resource};
|
||||||
|
use rust_bert::Config;
|
||||||
|
use rust_tokenizers::{BertTokenizer, Tokenizer, TruncationStrategy, Vocab};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
use tch::{nn, no_grad, Device, Tensor};
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn bert_masked_lm() -> failure::Fallible<()> {
|
fn bert_masked_lm() -> failure::Fallible<()> {
|
||||||
// Resources paths
|
// Resources paths
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT));
|
let config_resource =
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
|
Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT));
|
||||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT));
|
let vocab_resource =
|
||||||
|
Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
|
||||||
|
let weights_resource =
|
||||||
|
Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT));
|
||||||
let config_path = download_resource(&config_resource)?;
|
let config_path = download_resource(&config_resource)?;
|
||||||
let vocab_path = download_resource(&vocab_resource)?;
|
let vocab_path = download_resource(&vocab_resource)?;
|
||||||
let weights_path = download_resource(&weights_resource)?;
|
let weights_path = download_resource(&weights_resource)?;
|
||||||
|
|
||||||
// Set-up masked LM model
|
// Set-up masked LM model
|
||||||
let device = Device::Cpu;
|
let device = Device::Cpu;
|
||||||
let mut vs = nn::VarStore::new(device);
|
let mut vs = nn::VarStore::new(device);
|
||||||
let tokenizer: BertTokenizer = BertTokenizer::from_file(vocab_path.to_str().unwrap(), true);
|
let tokenizer: BertTokenizer = BertTokenizer::from_file(vocab_path.to_str().unwrap(), true);
|
||||||
@ -29,50 +34,58 @@ fn bert_masked_lm() -> failure::Fallible<()> {
|
|||||||
let bert_model = BertForMaskedLM::new(&vs.root(), &config);
|
let bert_model = BertForMaskedLM::new(&vs.root(), &config);
|
||||||
vs.load(weights_path)?;
|
vs.load(weights_path)?;
|
||||||
|
|
||||||
// Define input
|
// Define input
|
||||||
let input = ["Looks like one thing is missing", "It\'s like comparing oranges to apples"];
|
let input = [
|
||||||
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
"Looks like one thing is missing",
|
||||||
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
|
"It\'s like comparing oranges to apples",
|
||||||
let mut tokenized_input = tokenized_input.
|
];
|
||||||
iter().
|
let tokenized_input =
|
||||||
map(|input| input.token_ids.clone()).
|
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
||||||
map(|mut input| {
|
let max_len = tokenized_input
|
||||||
|
.iter()
|
||||||
|
.map(|input| input.token_ids.len())
|
||||||
|
.max()
|
||||||
|
.unwrap();
|
||||||
|
let mut tokenized_input = tokenized_input
|
||||||
|
.iter()
|
||||||
|
.map(|input| input.token_ids.clone())
|
||||||
|
.map(|mut input| {
|
||||||
input.extend(vec![0; max_len - input.len()]);
|
input.extend(vec![0; max_len - input.len()]);
|
||||||
input
|
input
|
||||||
}).
|
})
|
||||||
collect::<Vec<_>>();
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
// Masking the token [thing] of sentence 1 and [oranges] of sentence 2
|
// Masking the token [thing] of sentence 1 and [oranges] of sentence 2
|
||||||
tokenized_input[0][4] = 103;
|
tokenized_input[0][4] = 103;
|
||||||
tokenized_input[1][6] = 103;
|
tokenized_input[1][6] = 103;
|
||||||
let tokenized_input = tokenized_input.
|
let tokenized_input = tokenized_input
|
||||||
iter().
|
.iter()
|
||||||
map(|input|
|
.map(|input| Tensor::of_slice(&(input)))
|
||||||
Tensor::of_slice(&(input))).
|
.collect::<Vec<_>>();
|
||||||
collect::<Vec<_>>();
|
|
||||||
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
|
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
|
||||||
|
|
||||||
// Forward pass
|
// Forward pass
|
||||||
let (output, _, _) = no_grad(|| {
|
let (output, _, _) = no_grad(|| {
|
||||||
bert_model
|
bert_model.forward_t(
|
||||||
.forward_t(Some(input_tensor),
|
Some(input_tensor),
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
&None,
|
&None,
|
||||||
&None,
|
&None,
|
||||||
false)
|
false,
|
||||||
|
)
|
||||||
});
|
});
|
||||||
|
|
||||||
// Print masked tokens
|
// Print masked tokens
|
||||||
let index_1 = output.get(0).get(4).argmax(0, false);
|
let index_1 = output.get(0).get(4).argmax(0, false);
|
||||||
let index_2 = output.get(1).get(6).argmax(0, false);
|
let index_2 = output.get(1).get(6).argmax(0, false);
|
||||||
let word_1 = tokenizer.vocab().id_to_token(&index_1.int64_value(&[]));
|
let word_1 = tokenizer.vocab().id_to_token(&index_1.int64_value(&[]));
|
||||||
let word_2 = tokenizer.vocab().id_to_token(&index_2.int64_value(&[]));
|
let word_2 = tokenizer.vocab().id_to_token(&index_2.int64_value(&[]));
|
||||||
|
|
||||||
assert_eq!("person", word_1); // Outputs "person" : "Looks like one [person] is missing"
|
assert_eq!("person", word_1); // Outputs "person" : "Looks like one [person] is missing"
|
||||||
assert_eq!("orange", word_2);// Outputs "pear" : "It\'s like comparing [pear] to apples"
|
assert_eq!("orange", word_2); // Outputs "pear" : "It\'s like comparing [pear] to apples"
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@ -80,12 +93,14 @@ fn bert_masked_lm() -> failure::Fallible<()> {
|
|||||||
#[test]
|
#[test]
|
||||||
fn bert_for_sequence_classification() -> failure::Fallible<()> {
|
fn bert_for_sequence_classification() -> failure::Fallible<()> {
|
||||||
// Resources paths
|
// Resources paths
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT));
|
let config_resource =
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
|
Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT));
|
||||||
|
let vocab_resource =
|
||||||
|
Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
|
||||||
let config_path = download_resource(&config_resource)?;
|
let config_path = download_resource(&config_resource)?;
|
||||||
let vocab_path = download_resource(&vocab_resource)?;
|
let vocab_path = download_resource(&vocab_resource)?;
|
||||||
|
|
||||||
// Set-up model
|
// Set-up model
|
||||||
let device = Device::Cpu;
|
let device = Device::Cpu;
|
||||||
let vs = nn::VarStore::new(device);
|
let vs = nn::VarStore::new(device);
|
||||||
let tokenizer: BertTokenizer = BertTokenizer::from_file(vocab_path.to_str().unwrap(), true);
|
let tokenizer: BertTokenizer = BertTokenizer::from_file(vocab_path.to_str().unwrap(), true);
|
||||||
@ -99,37 +114,42 @@ fn bert_for_sequence_classification() -> failure::Fallible<()> {
|
|||||||
config.output_hidden_states = Some(true);
|
config.output_hidden_states = Some(true);
|
||||||
let bert_model = BertForSequenceClassification::new(&vs.root(), &config);
|
let bert_model = BertForSequenceClassification::new(&vs.root(), &config);
|
||||||
|
|
||||||
|
// Define input
|
||||||
// Define input
|
let input = [
|
||||||
let input = ["Looks like one thing is missing", "It\'s like comparing oranges to apples"];
|
"Looks like one thing is missing",
|
||||||
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
"It\'s like comparing oranges to apples",
|
||||||
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
|
];
|
||||||
let tokenized_input = tokenized_input.
|
let tokenized_input =
|
||||||
iter().
|
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
||||||
map(|input| input.token_ids.clone()).
|
let max_len = tokenized_input
|
||||||
map(|mut input| {
|
.iter()
|
||||||
|
.map(|input| input.token_ids.len())
|
||||||
|
.max()
|
||||||
|
.unwrap();
|
||||||
|
let tokenized_input = tokenized_input
|
||||||
|
.iter()
|
||||||
|
.map(|input| input.token_ids.clone())
|
||||||
|
.map(|mut input| {
|
||||||
input.extend(vec![0; max_len - input.len()]);
|
input.extend(vec![0; max_len - input.len()]);
|
||||||
input
|
input
|
||||||
}).
|
})
|
||||||
map(|input|
|
.map(|input| Tensor::of_slice(&(input)))
|
||||||
Tensor::of_slice(&(input))).
|
.collect::<Vec<_>>();
|
||||||
collect::<Vec<_>>();
|
|
||||||
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
|
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
|
||||||
|
|
||||||
// Forward pass
|
// Forward pass
|
||||||
let (output, all_hidden_states, all_attentions) = no_grad(|| {
|
let (output, all_hidden_states, all_attentions) =
|
||||||
bert_model
|
no_grad(|| bert_model.forward_t(Some(input_tensor), None, None, None, None, false));
|
||||||
.forward_t(Some(input_tensor),
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
false)
|
|
||||||
});
|
|
||||||
|
|
||||||
assert_eq!(output.size(), &[2, 3]);
|
assert_eq!(output.size(), &[2, 3]);
|
||||||
assert_eq!(config.num_hidden_layers as usize, all_hidden_states.unwrap().len());
|
assert_eq!(
|
||||||
assert_eq!(config.num_hidden_layers as usize, all_attentions.unwrap().len());
|
config.num_hidden_layers as usize,
|
||||||
|
all_hidden_states.unwrap().len()
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
config.num_hidden_layers as usize,
|
||||||
|
all_attentions.unwrap().len()
|
||||||
|
);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@ -137,12 +157,14 @@ fn bert_for_sequence_classification() -> failure::Fallible<()> {
|
|||||||
#[test]
|
#[test]
|
||||||
fn bert_for_multiple_choice() -> failure::Fallible<()> {
|
fn bert_for_multiple_choice() -> failure::Fallible<()> {
|
||||||
// Resources paths
|
// Resources paths
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT));
|
let config_resource =
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
|
Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT));
|
||||||
|
let vocab_resource =
|
||||||
|
Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
|
||||||
let config_path = download_resource(&config_resource)?;
|
let config_path = download_resource(&config_resource)?;
|
||||||
let vocab_path = download_resource(&vocab_resource)?;
|
let vocab_path = download_resource(&vocab_resource)?;
|
||||||
|
|
||||||
// Set-up model
|
// Set-up model
|
||||||
let device = Device::Cpu;
|
let device = Device::Cpu;
|
||||||
let vs = nn::VarStore::new(device);
|
let vs = nn::VarStore::new(device);
|
||||||
let tokenizer: BertTokenizer = BertTokenizer::from_file(vocab_path.to_str().unwrap(), true);
|
let tokenizer: BertTokenizer = BertTokenizer::from_file(vocab_path.to_str().unwrap(), true);
|
||||||
@ -151,35 +173,44 @@ fn bert_for_multiple_choice() -> failure::Fallible<()> {
|
|||||||
config.output_hidden_states = Some(true);
|
config.output_hidden_states = Some(true);
|
||||||
let bert_model = BertForMultipleChoice::new(&vs.root(), &config);
|
let bert_model = BertForMultipleChoice::new(&vs.root(), &config);
|
||||||
|
|
||||||
// Define input
|
// Define input
|
||||||
let input = ["Looks like one thing is missing", "It\'s like comparing oranges to apples"];
|
let input = [
|
||||||
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
"Looks like one thing is missing",
|
||||||
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
|
"It\'s like comparing oranges to apples",
|
||||||
let tokenized_input = tokenized_input.
|
];
|
||||||
iter().
|
let tokenized_input =
|
||||||
map(|input| input.token_ids.clone()).
|
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
||||||
map(|mut input| {
|
let max_len = tokenized_input
|
||||||
|
.iter()
|
||||||
|
.map(|input| input.token_ids.len())
|
||||||
|
.max()
|
||||||
|
.unwrap();
|
||||||
|
let tokenized_input = tokenized_input
|
||||||
|
.iter()
|
||||||
|
.map(|input| input.token_ids.clone())
|
||||||
|
.map(|mut input| {
|
||||||
input.extend(vec![0; max_len - input.len()]);
|
input.extend(vec![0; max_len - input.len()]);
|
||||||
input
|
input
|
||||||
}).
|
})
|
||||||
map(|input|
|
.map(|input| Tensor::of_slice(&(input)))
|
||||||
Tensor::of_slice(&(input))).
|
.collect::<Vec<_>>();
|
||||||
collect::<Vec<_>>();
|
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0)
|
||||||
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device).unsqueeze(0);
|
.to(device)
|
||||||
|
.unsqueeze(0);
|
||||||
|
|
||||||
// Forward pass
|
// Forward pass
|
||||||
let (output, all_hidden_states, all_attentions) = no_grad(|| {
|
let (output, all_hidden_states, all_attentions) =
|
||||||
bert_model
|
no_grad(|| bert_model.forward_t(input_tensor, None, None, None, false));
|
||||||
.forward_t(input_tensor,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
false)
|
|
||||||
});
|
|
||||||
|
|
||||||
assert_eq!(output.size(), &[1, 2]);
|
assert_eq!(output.size(), &[1, 2]);
|
||||||
assert_eq!(config.num_hidden_layers as usize, all_hidden_states.unwrap().len());
|
assert_eq!(
|
||||||
assert_eq!(config.num_hidden_layers as usize, all_attentions.unwrap().len());
|
config.num_hidden_layers as usize,
|
||||||
|
all_hidden_states.unwrap().len()
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
config.num_hidden_layers as usize,
|
||||||
|
all_attentions.unwrap().len()
|
||||||
|
);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@ -187,12 +218,14 @@ fn bert_for_multiple_choice() -> failure::Fallible<()> {
|
|||||||
#[test]
|
#[test]
|
||||||
fn bert_for_token_classification() -> failure::Fallible<()> {
|
fn bert_for_token_classification() -> failure::Fallible<()> {
|
||||||
// Resources paths
|
// Resources paths
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT));
|
let config_resource =
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
|
Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT));
|
||||||
|
let vocab_resource =
|
||||||
|
Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
|
||||||
let config_path = download_resource(&config_resource)?;
|
let config_path = download_resource(&config_resource)?;
|
||||||
let vocab_path = download_resource(&vocab_resource)?;
|
let vocab_path = download_resource(&vocab_resource)?;
|
||||||
|
|
||||||
// Set-up model
|
// Set-up model
|
||||||
let device = Device::Cpu;
|
let device = Device::Cpu;
|
||||||
let vs = nn::VarStore::new(device);
|
let vs = nn::VarStore::new(device);
|
||||||
let tokenizer: BertTokenizer = BertTokenizer::from_file(vocab_path.to_str().unwrap(), true);
|
let tokenizer: BertTokenizer = BertTokenizer::from_file(vocab_path.to_str().unwrap(), true);
|
||||||
@ -207,37 +240,42 @@ fn bert_for_token_classification() -> failure::Fallible<()> {
|
|||||||
config.output_hidden_states = Some(true);
|
config.output_hidden_states = Some(true);
|
||||||
let bert_model = BertForTokenClassification::new(&vs.root(), &config);
|
let bert_model = BertForTokenClassification::new(&vs.root(), &config);
|
||||||
|
|
||||||
|
// Define input
|
||||||
// Define input
|
let input = [
|
||||||
let input = ["Looks like one thing is missing", "It\'s like comparing oranges to apples"];
|
"Looks like one thing is missing",
|
||||||
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
"It\'s like comparing oranges to apples",
|
||||||
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
|
];
|
||||||
let tokenized_input = tokenized_input.
|
let tokenized_input =
|
||||||
iter().
|
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
||||||
map(|input| input.token_ids.clone()).
|
let max_len = tokenized_input
|
||||||
map(|mut input| {
|
.iter()
|
||||||
|
.map(|input| input.token_ids.len())
|
||||||
|
.max()
|
||||||
|
.unwrap();
|
||||||
|
let tokenized_input = tokenized_input
|
||||||
|
.iter()
|
||||||
|
.map(|input| input.token_ids.clone())
|
||||||
|
.map(|mut input| {
|
||||||
input.extend(vec![0; max_len - input.len()]);
|
input.extend(vec![0; max_len - input.len()]);
|
||||||
input
|
input
|
||||||
}).
|
})
|
||||||
map(|input|
|
.map(|input| Tensor::of_slice(&(input)))
|
||||||
Tensor::of_slice(&(input))).
|
.collect::<Vec<_>>();
|
||||||
collect::<Vec<_>>();
|
|
||||||
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
|
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
|
||||||
|
|
||||||
// Forward pass
|
// Forward pass
|
||||||
let (output, all_hidden_states, all_attentions) = no_grad(|| {
|
let (output, all_hidden_states, all_attentions) =
|
||||||
bert_model
|
no_grad(|| bert_model.forward_t(Some(input_tensor), None, None, None, None, false));
|
||||||
.forward_t(Some(input_tensor),
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
false)
|
|
||||||
});
|
|
||||||
|
|
||||||
assert_eq!(output.size(), &[2, 11, 4]);
|
assert_eq!(output.size(), &[2, 11, 4]);
|
||||||
assert_eq!(config.num_hidden_layers as usize, all_hidden_states.unwrap().len());
|
assert_eq!(
|
||||||
assert_eq!(config.num_hidden_layers as usize, all_attentions.unwrap().len());
|
config.num_hidden_layers as usize,
|
||||||
|
all_hidden_states.unwrap().len()
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
config.num_hidden_layers as usize,
|
||||||
|
all_attentions.unwrap().len()
|
||||||
|
);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@ -245,12 +283,14 @@ fn bert_for_token_classification() -> failure::Fallible<()> {
|
|||||||
#[test]
|
#[test]
|
||||||
fn bert_for_question_answering() -> failure::Fallible<()> {
|
fn bert_for_question_answering() -> failure::Fallible<()> {
|
||||||
// Resources paths
|
// Resources paths
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT));
|
let config_resource =
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
|
Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT));
|
||||||
|
let vocab_resource =
|
||||||
|
Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT));
|
||||||
let config_path = download_resource(&config_resource)?;
|
let config_path = download_resource(&config_resource)?;
|
||||||
let vocab_path = download_resource(&vocab_resource)?;
|
let vocab_path = download_resource(&vocab_resource)?;
|
||||||
|
|
||||||
// Set-up model
|
// Set-up model
|
||||||
let device = Device::Cpu;
|
let device = Device::Cpu;
|
||||||
let vs = nn::VarStore::new(device);
|
let vs = nn::VarStore::new(device);
|
||||||
let tokenizer: BertTokenizer = BertTokenizer::from_file(vocab_path.to_str().unwrap(), true);
|
let tokenizer: BertTokenizer = BertTokenizer::from_file(vocab_path.to_str().unwrap(), true);
|
||||||
@ -259,53 +299,59 @@ fn bert_for_question_answering() -> failure::Fallible<()> {
|
|||||||
config.output_hidden_states = Some(true);
|
config.output_hidden_states = Some(true);
|
||||||
let bert_model = BertForQuestionAnswering::new(&vs.root(), &config);
|
let bert_model = BertForQuestionAnswering::new(&vs.root(), &config);
|
||||||
|
|
||||||
// Define input
|
// Define input
|
||||||
let input = ["Looks like one thing is missing", "It\'s like comparing oranges to apples"];
|
let input = [
|
||||||
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
"Looks like one thing is missing",
|
||||||
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
|
"It\'s like comparing oranges to apples",
|
||||||
let tokenized_input = tokenized_input.
|
];
|
||||||
iter().
|
let tokenized_input =
|
||||||
map(|input| input.token_ids.clone()).
|
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
||||||
map(|mut input| {
|
let max_len = tokenized_input
|
||||||
|
.iter()
|
||||||
|
.map(|input| input.token_ids.len())
|
||||||
|
.max()
|
||||||
|
.unwrap();
|
||||||
|
let tokenized_input = tokenized_input
|
||||||
|
.iter()
|
||||||
|
.map(|input| input.token_ids.clone())
|
||||||
|
.map(|mut input| {
|
||||||
input.extend(vec![0; max_len - input.len()]);
|
input.extend(vec![0; max_len - input.len()]);
|
||||||
input
|
input
|
||||||
}).
|
})
|
||||||
map(|input|
|
.map(|input| Tensor::of_slice(&(input)))
|
||||||
Tensor::of_slice(&(input))).
|
.collect::<Vec<_>>();
|
||||||
collect::<Vec<_>>();
|
|
||||||
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
|
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
|
||||||
|
|
||||||
// Forward pass
|
// Forward pass
|
||||||
let (start_scores, end_scores, all_hidden_states, all_attentions) = no_grad(|| {
|
let (start_scores, end_scores, all_hidden_states, all_attentions) =
|
||||||
bert_model
|
no_grad(|| bert_model.forward_t(Some(input_tensor), None, None, None, None, false));
|
||||||
.forward_t(Some(input_tensor),
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
false)
|
|
||||||
});
|
|
||||||
|
|
||||||
assert_eq!(start_scores.size(), &[2, 11]);
|
assert_eq!(start_scores.size(), &[2, 11]);
|
||||||
assert_eq!(end_scores.size(), &[2, 11]);
|
assert_eq!(end_scores.size(), &[2, 11]);
|
||||||
assert_eq!(config.num_hidden_layers as usize, all_hidden_states.unwrap().len());
|
assert_eq!(
|
||||||
assert_eq!(config.num_hidden_layers as usize, all_attentions.unwrap().len());
|
config.num_hidden_layers as usize,
|
||||||
|
all_hidden_states.unwrap().len()
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
config.num_hidden_layers as usize,
|
||||||
|
all_attentions.unwrap().len()
|
||||||
|
);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn bert_pre_trained_ner() -> failure::Fallible<()> {
|
fn bert_pre_trained_ner() -> failure::Fallible<()> {
|
||||||
// Set-up model
|
// Set-up model
|
||||||
let ner_model = NERModel::new(Default::default())?;
|
let ner_model = NERModel::new(Default::default())?;
|
||||||
|
|
||||||
// Define input
|
// Define input
|
||||||
let input = [
|
let input = [
|
||||||
"My name is Amy. I live in Paris.",
|
"My name is Amy. I live in Paris.",
|
||||||
"Paris is a city in France."
|
"Paris is a city in France.",
|
||||||
];
|
];
|
||||||
|
|
||||||
// Run model
|
// Run model
|
||||||
let output = ner_model.predict(&input);
|
let output = ner_model.predict(&input);
|
||||||
|
|
||||||
assert_eq!(output.len(), 4);
|
assert_eq!(output.len(), 4);
|
||||||
@ -327,4 +373,4 @@ fn bert_pre_trained_ner() -> failure::Fallible<()> {
|
|||||||
assert_eq!(output[3].label, "I-LOC");
|
assert_eq!(output[3].label, "I-LOC");
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -1,22 +1,26 @@
|
|||||||
use tch::{Device, Tensor, nn, no_grad};
|
use rust_bert::distilbert::{
|
||||||
use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::{Tokenizer, TruncationStrategy};
|
DistilBertConfig, DistilBertConfigResources, DistilBertForQuestionAnswering,
|
||||||
use rust_tokenizers::bert_tokenizer::BertTokenizer;
|
DistilBertForTokenClassification, DistilBertModelMaskedLM, DistilBertModelResources,
|
||||||
use rust_tokenizers::preprocessing::vocab::base_vocab::Vocab;
|
DistilBertVocabResources,
|
||||||
use rust_bert::Config;
|
};
|
||||||
use rust_bert::distilbert::{DistilBertConfig, DistilBertModelMaskedLM, DistilBertForQuestionAnswering, DistilBertForTokenClassification, DistilBertModelResources, DistilBertConfigResources, DistilBertVocabResources};
|
use rust_bert::pipelines::question_answering::{QaInput, QuestionAnsweringModel};
|
||||||
use rust_bert::pipelines::sentiment::{SentimentModel, SentimentPolarity};
|
use rust_bert::pipelines::sentiment::{SentimentModel, SentimentPolarity};
|
||||||
use rust_bert::pipelines::question_answering::{QuestionAnsweringModel, QaInput};
|
use rust_bert::resources::{download_resource, RemoteResource, Resource};
|
||||||
use rust_bert::resources::{Resource, RemoteResource, download_resource};
|
use rust_bert::Config;
|
||||||
|
use rust_tokenizers::bert_tokenizer::BertTokenizer;
|
||||||
|
use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::{Tokenizer, TruncationStrategy};
|
||||||
|
use rust_tokenizers::preprocessing::vocab::base_vocab::Vocab;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
use tch::{nn, no_grad, Device, Tensor};
|
||||||
|
|
||||||
extern crate failure;
|
extern crate failure;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn distilbert_sentiment_classifier() -> failure::Fallible<()> {
|
fn distilbert_sentiment_classifier() -> failure::Fallible<()> {
|
||||||
// Set-up classifier
|
// Set-up classifier
|
||||||
let sentiment_classifier = SentimentModel::new(Default::default())?;
|
let sentiment_classifier = SentimentModel::new(Default::default())?;
|
||||||
|
|
||||||
// Get sentiments
|
// Get sentiments
|
||||||
let input = [
|
let input = [
|
||||||
"Probably my all-time favorite movie, a story of selflessness, sacrifice and dedication to a noble cause, but it's not preachy or boring.",
|
"Probably my all-time favorite movie, a story of selflessness, sacrifice and dedication to a noble cause, but it's not preachy or boring.",
|
||||||
"This film tried to be too many things all at once: stinging political satire, Hollywood blockbuster, sappy romantic comedy, family values promo...",
|
"This film tried to be too many things all at once: stinging political satire, Hollywood blockbuster, sappy romantic comedy, family values promo...",
|
||||||
@ -36,18 +40,23 @@ fn distilbert_sentiment_classifier() -> failure::Fallible<()> {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn distilbert_masked_lm() -> failure::Fallible<()> {
|
fn distilbert_masked_lm() -> failure::Fallible<()> {
|
||||||
// Resources paths
|
// Resources paths
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertConfigResources::DISTIL_BERT));
|
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertVocabResources::DISTIL_BERT));
|
DistilBertConfigResources::DISTIL_BERT,
|
||||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertModelResources::DISTIL_BERT));
|
));
|
||||||
|
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
|
DistilBertVocabResources::DISTIL_BERT,
|
||||||
|
));
|
||||||
|
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
|
DistilBertModelResources::DISTIL_BERT,
|
||||||
|
));
|
||||||
let config_path = download_resource(&config_resource)?;
|
let config_path = download_resource(&config_resource)?;
|
||||||
let vocab_path = download_resource(&vocab_resource)?;
|
let vocab_path = download_resource(&vocab_resource)?;
|
||||||
let weights_path = download_resource(&weights_resource)?;
|
let weights_path = download_resource(&weights_resource)?;
|
||||||
|
|
||||||
// Set-up masked LM model
|
// Set-up masked LM model
|
||||||
let device = Device::cuda_if_available();
|
let device = Device::cuda_if_available();
|
||||||
let mut vs = nn::VarStore::new(device);
|
let mut vs = nn::VarStore::new(device);
|
||||||
let tokenizer: BertTokenizer = BertTokenizer::from_file(vocab_path.to_str().unwrap(), true);
|
let tokenizer: BertTokenizer = BertTokenizer::from_file(vocab_path.to_str().unwrap(), true);
|
||||||
@ -55,59 +64,68 @@ fn distilbert_masked_lm() -> failure::Fallible<()> {
|
|||||||
let distil_bert_model = DistilBertModelMaskedLM::new(&vs.root(), &config);
|
let distil_bert_model = DistilBertModelMaskedLM::new(&vs.root(), &config);
|
||||||
vs.load(weights_path)?;
|
vs.load(weights_path)?;
|
||||||
|
|
||||||
// Define input
|
// Define input
|
||||||
let input = ["Looks like one thing is missing", "It\'s like comparing oranges to apples"];
|
let input = [
|
||||||
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
"Looks like one thing is missing",
|
||||||
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
|
"It\'s like comparing oranges to apples",
|
||||||
let mut tokenized_input = tokenized_input.
|
];
|
||||||
iter().
|
let tokenized_input =
|
||||||
map(|input| input.token_ids.clone()).
|
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
||||||
map(|mut input| {
|
let max_len = tokenized_input
|
||||||
|
.iter()
|
||||||
|
.map(|input| input.token_ids.len())
|
||||||
|
.max()
|
||||||
|
.unwrap();
|
||||||
|
let mut tokenized_input = tokenized_input
|
||||||
|
.iter()
|
||||||
|
.map(|input| input.token_ids.clone())
|
||||||
|
.map(|mut input| {
|
||||||
input.extend(vec![0; max_len - input.len()]);
|
input.extend(vec![0; max_len - input.len()]);
|
||||||
input
|
input
|
||||||
}).
|
})
|
||||||
collect::<Vec<_>>();
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
// Masking the token [thing] of sentence 1 and [oranges] of sentence 2
|
// Masking the token [thing] of sentence 1 and [oranges] of sentence 2
|
||||||
tokenized_input[0][4] = 103;
|
tokenized_input[0][4] = 103;
|
||||||
tokenized_input[1][6] = 103;
|
tokenized_input[1][6] = 103;
|
||||||
let tokenized_input = tokenized_input.
|
let tokenized_input = tokenized_input
|
||||||
iter().
|
.iter()
|
||||||
map(|input|
|
.map(|input| Tensor::of_slice(&(input)))
|
||||||
Tensor::of_slice(&(input))).
|
.collect::<Vec<_>>();
|
||||||
collect::<Vec<_>>();
|
|
||||||
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
|
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
|
||||||
|
|
||||||
|
// Forward pass
|
||||||
// Forward pass
|
|
||||||
let (output, _, _) = no_grad(|| {
|
let (output, _, _) = no_grad(|| {
|
||||||
distil_bert_model
|
distil_bert_model
|
||||||
.forward_t(Some(input_tensor), None, None, false)
|
.forward_t(Some(input_tensor), None, None, false)
|
||||||
.unwrap()
|
.unwrap()
|
||||||
});
|
});
|
||||||
|
|
||||||
// Print masked tokens
|
// Print masked tokens
|
||||||
let index_1 = output.get(0).get(4).argmax(0, false);
|
let index_1 = output.get(0).get(4).argmax(0, false);
|
||||||
let index_2 = output.get(1).get(6).argmax(0, false);
|
let index_2 = output.get(1).get(6).argmax(0, false);
|
||||||
let word_1 = tokenizer.vocab().id_to_token(&index_1.int64_value(&[]));
|
let word_1 = tokenizer.vocab().id_to_token(&index_1.int64_value(&[]));
|
||||||
let word_2 = tokenizer.vocab().id_to_token(&index_2.int64_value(&[]));
|
let word_2 = tokenizer.vocab().id_to_token(&index_2.int64_value(&[]));
|
||||||
|
|
||||||
assert_eq!("person", word_1); // Outputs "person" : "Looks like one [person] is missing"
|
assert_eq!("person", word_1); // Outputs "person" : "Looks like one [person] is missing"
|
||||||
assert_eq!("pear", word_2);// Outputs "pear" : "It\'s like comparing [pear] to apples"
|
assert_eq!("pear", word_2); // Outputs "pear" : "It\'s like comparing [pear] to apples"
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn distilbert_for_question_answering() -> failure::Fallible<()> {
|
fn distilbert_for_question_answering() -> failure::Fallible<()> {
|
||||||
|
// Resources paths
|
||||||
// Resources paths
|
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertConfigResources::DISTIL_BERT_SQUAD));
|
DistilBertConfigResources::DISTIL_BERT_SQUAD,
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertVocabResources::DISTIL_BERT_SQUAD));
|
));
|
||||||
|
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
|
DistilBertVocabResources::DISTIL_BERT_SQUAD,
|
||||||
|
));
|
||||||
let config_path = download_resource(&config_resource)?;
|
let config_path = download_resource(&config_resource)?;
|
||||||
let vocab_path = download_resource(&vocab_resource)?;
|
let vocab_path = download_resource(&vocab_resource)?;
|
||||||
|
|
||||||
// Set-up masked LM model
|
// Set-up masked LM model
|
||||||
let device = Device::cuda_if_available();
|
let device = Device::cuda_if_available();
|
||||||
let vs = nn::VarStore::new(device);
|
let vs = nn::VarStore::new(device);
|
||||||
let tokenizer: BertTokenizer = BertTokenizer::from_file(vocab_path.to_str().unwrap(), true);
|
let tokenizer: BertTokenizer = BertTokenizer::from_file(vocab_path.to_str().unwrap(), true);
|
||||||
@ -116,23 +134,30 @@ fn distilbert_for_question_answering() -> failure::Fallible<()> {
|
|||||||
config.output_hidden_states = Some(true);
|
config.output_hidden_states = Some(true);
|
||||||
let distil_bert_model = DistilBertForQuestionAnswering::new(&vs.root(), &config);
|
let distil_bert_model = DistilBertForQuestionAnswering::new(&vs.root(), &config);
|
||||||
|
|
||||||
// Define input
|
// Define input
|
||||||
let input = ["Looks like one thing is missing", "It\'s like comparing oranges to apples"];
|
let input = [
|
||||||
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
"Looks like one thing is missing",
|
||||||
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
|
"It\'s like comparing oranges to apples",
|
||||||
let tokenized_input = tokenized_input.
|
];
|
||||||
iter().
|
let tokenized_input =
|
||||||
map(|input| input.token_ids.clone()).
|
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
||||||
map(|mut input| {
|
let max_len = tokenized_input
|
||||||
|
.iter()
|
||||||
|
.map(|input| input.token_ids.len())
|
||||||
|
.max()
|
||||||
|
.unwrap();
|
||||||
|
let tokenized_input = tokenized_input
|
||||||
|
.iter()
|
||||||
|
.map(|input| input.token_ids.clone())
|
||||||
|
.map(|mut input| {
|
||||||
input.extend(vec![0; max_len - input.len()]);
|
input.extend(vec![0; max_len - input.len()]);
|
||||||
input
|
input
|
||||||
}).
|
})
|
||||||
map(|input|
|
.map(|input| Tensor::of_slice(&(input)))
|
||||||
Tensor::of_slice(&(input))).
|
.collect::<Vec<_>>();
|
||||||
collect::<Vec<_>>();
|
|
||||||
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
|
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
|
||||||
|
|
||||||
// Forward pass
|
// Forward pass
|
||||||
let (start_scores, end_scores, all_hidden_states, all_attentions) = no_grad(|| {
|
let (start_scores, end_scores, all_hidden_states, all_attentions) = no_grad(|| {
|
||||||
distil_bert_model
|
distil_bert_model
|
||||||
.forward_t(Some(input_tensor), None, None, false)
|
.forward_t(Some(input_tensor), None, None, false)
|
||||||
@ -149,14 +174,17 @@ fn distilbert_for_question_answering() -> failure::Fallible<()> {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn distilbert_for_token_classification() -> failure::Fallible<()> {
|
fn distilbert_for_token_classification() -> failure::Fallible<()> {
|
||||||
|
// Resources paths
|
||||||
// Resources paths
|
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertConfigResources::DISTIL_BERT));
|
DistilBertConfigResources::DISTIL_BERT,
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(DistilBertVocabResources::DISTIL_BERT));
|
));
|
||||||
|
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
|
DistilBertVocabResources::DISTIL_BERT,
|
||||||
|
));
|
||||||
let config_path = download_resource(&config_resource)?;
|
let config_path = download_resource(&config_resource)?;
|
||||||
let vocab_path = download_resource(&vocab_resource)?;
|
let vocab_path = download_resource(&vocab_resource)?;
|
||||||
|
|
||||||
// Set-up masked LM model
|
// Set-up masked LM model
|
||||||
let device = Device::cuda_if_available();
|
let device = Device::cuda_if_available();
|
||||||
let vs = nn::VarStore::new(device);
|
let vs = nn::VarStore::new(device);
|
||||||
let tokenizer: BertTokenizer = BertTokenizer::from_file(vocab_path.to_str().unwrap(), true);
|
let tokenizer: BertTokenizer = BertTokenizer::from_file(vocab_path.to_str().unwrap(), true);
|
||||||
@ -171,23 +199,30 @@ fn distilbert_for_token_classification() -> failure::Fallible<()> {
|
|||||||
config.id2label = Some(dummy_label_mapping);
|
config.id2label = Some(dummy_label_mapping);
|
||||||
let distil_bert_model = DistilBertForTokenClassification::new(&vs.root(), &config);
|
let distil_bert_model = DistilBertForTokenClassification::new(&vs.root(), &config);
|
||||||
|
|
||||||
// Define input
|
// Define input
|
||||||
let input = ["Looks like one thing is missing", "It\'s like comparing oranges to apples"];
|
let input = [
|
||||||
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
"Looks like one thing is missing",
|
||||||
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
|
"It\'s like comparing oranges to apples",
|
||||||
let tokenized_input = tokenized_input.
|
];
|
||||||
iter().
|
let tokenized_input =
|
||||||
map(|input| input.token_ids.clone()).
|
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
||||||
map(|mut input| {
|
let max_len = tokenized_input
|
||||||
|
.iter()
|
||||||
|
.map(|input| input.token_ids.len())
|
||||||
|
.max()
|
||||||
|
.unwrap();
|
||||||
|
let tokenized_input = tokenized_input
|
||||||
|
.iter()
|
||||||
|
.map(|input| input.token_ids.clone())
|
||||||
|
.map(|mut input| {
|
||||||
input.extend(vec![0; max_len - input.len()]);
|
input.extend(vec![0; max_len - input.len()]);
|
||||||
input
|
input
|
||||||
}).
|
})
|
||||||
map(|input|
|
.map(|input| Tensor::of_slice(&(input)))
|
||||||
Tensor::of_slice(&(input))).
|
.collect::<Vec<_>>();
|
||||||
collect::<Vec<_>>();
|
|
||||||
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
|
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
|
||||||
|
|
||||||
// Forward pass
|
// Forward pass
|
||||||
let (output, all_hidden_states, all_attentions) = no_grad(|| {
|
let (output, all_hidden_states, all_attentions) = no_grad(|| {
|
||||||
distil_bert_model
|
distil_bert_model
|
||||||
.forward_t(Some(input_tensor), None, None, false)
|
.forward_t(Some(input_tensor), None, None, false)
|
||||||
@ -203,15 +238,15 @@ fn distilbert_for_token_classification() -> failure::Fallible<()> {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn distilbert_question_answering() -> failure::Fallible<()> {
|
fn distilbert_question_answering() -> failure::Fallible<()> {
|
||||||
// Set-up question answering model
|
// Set-up question answering model
|
||||||
let qa_model = QuestionAnsweringModel::new(Default::default())?;
|
let qa_model = QuestionAnsweringModel::new(Default::default())?;
|
||||||
|
|
||||||
// Define input
|
// Define input
|
||||||
let question = String::from("Where does Amy live ?");
|
let question = String::from("Where does Amy live ?");
|
||||||
let context = String::from("Amy lives in Amsterdam");
|
let context = String::from("Amy lives in Amsterdam");
|
||||||
let qa_input = QaInput { question, context };
|
let qa_input = QaInput { question, context };
|
||||||
|
|
||||||
let answers = qa_model.predict(&vec!(qa_input), 1, 32);
|
let answers = qa_model.predict(&vec![qa_input], 1, 32);
|
||||||
|
|
||||||
assert_eq!(answers.len(), 1 as usize);
|
assert_eq!(answers.len(), 1 as usize);
|
||||||
assert_eq!(answers[0].len(), 1 as usize);
|
assert_eq!(answers[0].len(), 1 as usize);
|
||||||
@ -221,4 +256,4 @@ fn distilbert_question_answering() -> failure::Fallible<()> {
|
|||||||
assert_eq!(answers[0][0].answer, "Amsterdam");
|
assert_eq!(answers[0][0].answer, "Amsterdam");
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -1,73 +1,100 @@
|
|||||||
use tch::{Device, nn, Tensor};
|
use rust_bert::gpt2::{
|
||||||
use rust_tokenizers::{Gpt2Tokenizer, TruncationStrategy, Tokenizer};
|
GPT2LMHeadModel, Gpt2Config, Gpt2ConfigResources, Gpt2MergesResources, Gpt2ModelResources,
|
||||||
|
Gpt2VocabResources,
|
||||||
|
};
|
||||||
|
use rust_bert::pipelines::generation::{Cache, LMHeadModel};
|
||||||
|
use rust_bert::resources::{download_resource, RemoteResource, Resource};
|
||||||
use rust_bert::Config;
|
use rust_bert::Config;
|
||||||
use rust_bert::gpt2::{Gpt2Config, GPT2LMHeadModel, Gpt2ConfigResources, Gpt2VocabResources, Gpt2MergesResources, Gpt2ModelResources};
|
use rust_tokenizers::{Gpt2Tokenizer, Tokenizer, TruncationStrategy};
|
||||||
use rust_bert::pipelines::generation::{LMHeadModel, Cache};
|
use tch::{nn, Device, Tensor};
|
||||||
use rust_bert::resources::{Resource, download_resource, RemoteResource};
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn distilgpt2_lm_model() -> failure::Fallible<()> {
|
fn distilgpt2_lm_model() -> failure::Fallible<()> {
|
||||||
// Resources paths
|
// Resources paths
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::DISTIL_GPT2));
|
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::DISTIL_GPT2));
|
Gpt2ConfigResources::DISTIL_GPT2,
|
||||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::DISTIL_GPT2));
|
));
|
||||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::DISTIL_GPT2));
|
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
|
Gpt2VocabResources::DISTIL_GPT2,
|
||||||
|
));
|
||||||
|
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
|
Gpt2MergesResources::DISTIL_GPT2,
|
||||||
|
));
|
||||||
|
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
|
Gpt2ModelResources::DISTIL_GPT2,
|
||||||
|
));
|
||||||
let config_path = download_resource(&config_resource)?;
|
let config_path = download_resource(&config_resource)?;
|
||||||
let vocab_path = download_resource(&vocab_resource)?;
|
let vocab_path = download_resource(&vocab_resource)?;
|
||||||
let merges_path = download_resource(&merges_resource)?;
|
let merges_path = download_resource(&merges_resource)?;
|
||||||
let weights_path = download_resource(&weights_resource)?;
|
let weights_path = download_resource(&weights_resource)?;
|
||||||
|
|
||||||
// Set-up masked LM model
|
// Set-up masked LM model
|
||||||
let device = Device::Cpu;
|
let device = Device::Cpu;
|
||||||
let mut vs = nn::VarStore::new(device);
|
let mut vs = nn::VarStore::new(device);
|
||||||
let tokenizer: Gpt2Tokenizer = Gpt2Tokenizer::from_file(vocab_path.to_str().unwrap(), merges_path.to_str().unwrap(), false);
|
let tokenizer: Gpt2Tokenizer = Gpt2Tokenizer::from_file(
|
||||||
|
vocab_path.to_str().unwrap(),
|
||||||
|
merges_path.to_str().unwrap(),
|
||||||
|
false,
|
||||||
|
);
|
||||||
let config = Gpt2Config::from_file(config_path);
|
let config = Gpt2Config::from_file(config_path);
|
||||||
let gpt2_model = GPT2LMHeadModel::new(&vs.root(), &config);
|
let gpt2_model = GPT2LMHeadModel::new(&vs.root(), &config);
|
||||||
vs.load(weights_path)?;
|
vs.load(weights_path)?;
|
||||||
|
|
||||||
// Define input
|
// Define input
|
||||||
let input = ["One two three four five six seven eight nine ten eleven"];
|
let input = ["One two three four five six seven eight nine ten eleven"];
|
||||||
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
let tokenized_input =
|
||||||
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
|
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
||||||
let tokenized_input = tokenized_input.
|
let max_len = tokenized_input
|
||||||
iter().
|
.iter()
|
||||||
map(|input| input.token_ids.clone()).
|
.map(|input| input.token_ids.len())
|
||||||
map(|mut input| {
|
.max()
|
||||||
|
.unwrap();
|
||||||
|
let tokenized_input = tokenized_input
|
||||||
|
.iter()
|
||||||
|
.map(|input| input.token_ids.clone())
|
||||||
|
.map(|mut input| {
|
||||||
input.extend(vec![0; max_len - input.len()]);
|
input.extend(vec![0; max_len - input.len()]);
|
||||||
input
|
input
|
||||||
}).
|
})
|
||||||
map(|input|
|
.map(|input| Tensor::of_slice(&(input)))
|
||||||
Tensor::of_slice(&(input))).
|
.collect::<Vec<_>>();
|
||||||
collect::<Vec<_>>();
|
|
||||||
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
|
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
|
||||||
|
|
||||||
// Forward pass
|
// Forward pass
|
||||||
let (output, _, past, _, _) = gpt2_model.forward_t(
|
let (output, _, past, _, _) = gpt2_model
|
||||||
&Some(input_tensor),
|
.forward_t(
|
||||||
Cache::None,
|
&Some(input_tensor),
|
||||||
&None,
|
Cache::None,
|
||||||
&None,
|
&None,
|
||||||
&None,
|
&None,
|
||||||
&None,
|
&None,
|
||||||
None,
|
&None,
|
||||||
&None,
|
None,
|
||||||
false).unwrap();
|
&None,
|
||||||
|
false,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
let next_word_id = output.get(0).get(-1).argmax(-1, true).int64_value(&[0]);
|
let next_word_id = output.get(0).get(-1).argmax(-1, true).int64_value(&[0]);
|
||||||
let next_word = tokenizer.decode(vec!(next_word_id), true, true);
|
let next_word = tokenizer.decode(vec![next_word_id], true, true);
|
||||||
|
|
||||||
assert_eq!(output.size(), vec!(1, 11, 50257));
|
assert_eq!(output.size(), vec!(1, 11, 50257));
|
||||||
match past {
|
match past {
|
||||||
Cache::GPT2Cache(past) => {
|
Cache::GPT2Cache(past) => {
|
||||||
assert!(past.is_some());
|
assert!(past.is_some());
|
||||||
assert_eq!(past.as_ref().unwrap().len(), config.n_layer as usize);
|
assert_eq!(past.as_ref().unwrap().len(), config.n_layer as usize);
|
||||||
assert_eq!(past.as_ref().unwrap()[0].size(), vec!(2, 1, config.n_head, 11, 64));
|
assert_eq!(
|
||||||
|
past.as_ref().unwrap()[0].size(),
|
||||||
|
vec!(2, 1, config.n_head, 11, 64)
|
||||||
|
);
|
||||||
}
|
}
|
||||||
_ => panic!("Wrong cache returned for GPT2")
|
_ => panic!("Wrong cache returned for GPT2"),
|
||||||
}
|
}
|
||||||
assert!((output.double_value(&[0, output.size()[1] - 1, next_word_id]) - (-48.7065)).abs() < 1e-4);
|
assert!(
|
||||||
|
(output.double_value(&[0, output.size()[1] - 1, next_word_id]) - (-48.7065)).abs() < 1e-4
|
||||||
|
);
|
||||||
assert_eq!(next_word_id, 14104i64);
|
assert_eq!(next_word_id, 14104i64);
|
||||||
assert_eq!(next_word, String::from(" twelve"));
|
assert_eq!(next_word, String::from(" twelve"));
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
144
tests/electra.rs
144
tests/electra.rs
@ -1,20 +1,29 @@
|
|||||||
use rust_bert::resources::{Resource, download_resource, RemoteResource};
|
use rust_bert::electra::{
|
||||||
use rust_bert::electra::{ElectraConfigResources, ElectraVocabResources, ElectraModelResources, ElectraConfig, ElectraForMaskedLM, ElectraDiscriminator};
|
ElectraConfig, ElectraConfigResources, ElectraDiscriminator, ElectraForMaskedLM,
|
||||||
use tch::{Device, nn, Tensor, no_grad};
|
ElectraModelResources, ElectraVocabResources,
|
||||||
use rust_tokenizers::{BertTokenizer, TruncationStrategy, Tokenizer, Vocab};
|
};
|
||||||
|
use rust_bert::resources::{download_resource, RemoteResource, Resource};
|
||||||
use rust_bert::Config;
|
use rust_bert::Config;
|
||||||
|
use rust_tokenizers::{BertTokenizer, Tokenizer, TruncationStrategy, Vocab};
|
||||||
|
use tch::{nn, no_grad, Device, Tensor};
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn electra_masked_lm() -> failure::Fallible<()> {
|
fn electra_masked_lm() -> failure::Fallible<()> {
|
||||||
// Resources paths
|
// Resources paths
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(ElectraConfigResources::BASE_GENERATOR));
|
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(ElectraVocabResources::BASE_GENERATOR));
|
ElectraConfigResources::BASE_GENERATOR,
|
||||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(ElectraModelResources::BASE_GENERATOR));
|
));
|
||||||
|
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
|
ElectraVocabResources::BASE_GENERATOR,
|
||||||
|
));
|
||||||
|
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
|
ElectraModelResources::BASE_GENERATOR,
|
||||||
|
));
|
||||||
let config_path = download_resource(&config_resource)?;
|
let config_path = download_resource(&config_resource)?;
|
||||||
let vocab_path = download_resource(&vocab_resource)?;
|
let vocab_path = download_resource(&vocab_resource)?;
|
||||||
let weights_path = download_resource(&weights_resource)?;
|
let weights_path = download_resource(&weights_resource)?;
|
||||||
|
|
||||||
// Set-up masked LM model
|
// Set-up masked LM model
|
||||||
let device = Device::Cpu;
|
let device = Device::Cpu;
|
||||||
let mut vs = nn::VarStore::new(device);
|
let mut vs = nn::VarStore::new(device);
|
||||||
let tokenizer: BertTokenizer = BertTokenizer::from_file(vocab_path.to_str().unwrap(), true);
|
let tokenizer: BertTokenizer = BertTokenizer::from_file(vocab_path.to_str().unwrap(), true);
|
||||||
@ -24,60 +33,70 @@ fn electra_masked_lm() -> failure::Fallible<()> {
|
|||||||
let electra_model = ElectraForMaskedLM::new(&vs.root(), &config);
|
let electra_model = ElectraForMaskedLM::new(&vs.root(), &config);
|
||||||
vs.load(weights_path)?;
|
vs.load(weights_path)?;
|
||||||
|
|
||||||
// Define input
|
// Define input
|
||||||
let input = ["Looks like one [MASK] is missing", "It was a very nice and [MASK] day"];
|
let input = [
|
||||||
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
"Looks like one [MASK] is missing",
|
||||||
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
|
"It was a very nice and [MASK] day",
|
||||||
let tokenized_input = tokenized_input.
|
];
|
||||||
iter().
|
let tokenized_input =
|
||||||
map(|input| input.token_ids.clone()).
|
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
||||||
map(|mut input| {
|
let max_len = tokenized_input
|
||||||
|
.iter()
|
||||||
|
.map(|input| input.token_ids.len())
|
||||||
|
.max()
|
||||||
|
.unwrap();
|
||||||
|
let tokenized_input = tokenized_input
|
||||||
|
.iter()
|
||||||
|
.map(|input| input.token_ids.clone())
|
||||||
|
.map(|mut input| {
|
||||||
input.extend(vec![0; max_len - input.len()]);
|
input.extend(vec![0; max_len - input.len()]);
|
||||||
input
|
input
|
||||||
}).
|
})
|
||||||
map(|input|
|
.map(|input| Tensor::of_slice(&(input)))
|
||||||
Tensor::of_slice(&(input))).
|
.collect::<Vec<_>>();
|
||||||
collect::<Vec<_>>();
|
|
||||||
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
|
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
|
||||||
|
|
||||||
// Forward pass
|
// Forward pass
|
||||||
let (output,
|
let (output, all_hidden_states, all_attentions) =
|
||||||
all_hidden_states,
|
no_grad(|| electra_model.forward_t(Some(input_tensor), None, None, None, None, false));
|
||||||
all_attentions) = no_grad(|| {
|
|
||||||
electra_model
|
|
||||||
.forward_t(Some(input_tensor),
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
false)
|
|
||||||
});
|
|
||||||
|
|
||||||
// Decode output
|
// Decode output
|
||||||
let index_1 = output.get(0).get(4).argmax(0, false);
|
let index_1 = output.get(0).get(4).argmax(0, false);
|
||||||
let index_2 = output.get(1).get(7).argmax(0, false);
|
let index_2 = output.get(1).get(7).argmax(0, false);
|
||||||
let word_1 = tokenizer.vocab().id_to_token(&index_1.int64_value(&[]));
|
let word_1 = tokenizer.vocab().id_to_token(&index_1.int64_value(&[]));
|
||||||
let word_2 = tokenizer.vocab().id_to_token(&index_2.int64_value(&[]));
|
let word_2 = tokenizer.vocab().id_to_token(&index_2.int64_value(&[]));
|
||||||
|
|
||||||
assert_eq!(output.size(), &[2, 10, config.vocab_size]);
|
assert_eq!(output.size(), &[2, 10, config.vocab_size]);
|
||||||
assert_eq!(config.num_hidden_layers as usize, all_hidden_states.unwrap().len());
|
assert_eq!(
|
||||||
assert_eq!(config.num_hidden_layers as usize, all_attentions.unwrap().len());
|
config.num_hidden_layers as usize,
|
||||||
|
all_hidden_states.unwrap().len()
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
config.num_hidden_layers as usize,
|
||||||
|
all_attentions.unwrap().len()
|
||||||
|
);
|
||||||
assert_eq!("thing", word_1); // Outputs "person" : "Looks like one [person] is missing"
|
assert_eq!("thing", word_1); // Outputs "person" : "Looks like one [person] is missing"
|
||||||
assert_eq!("sunny", word_2);// Outputs "pear" : "It was a very nice and [sunny] day"
|
assert_eq!("sunny", word_2); // Outputs "pear" : "It was a very nice and [sunny] day"
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn electra_discriminator() -> failure::Fallible<()> {
|
fn electra_discriminator() -> failure::Fallible<()> {
|
||||||
// Resources paths
|
// Resources paths
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(ElectraConfigResources::BASE_DISCRIMINATOR));
|
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(ElectraVocabResources::BASE_DISCRIMINATOR));
|
ElectraConfigResources::BASE_DISCRIMINATOR,
|
||||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(ElectraModelResources::BASE_DISCRIMINATOR));
|
));
|
||||||
|
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
|
ElectraVocabResources::BASE_DISCRIMINATOR,
|
||||||
|
));
|
||||||
|
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
|
ElectraModelResources::BASE_DISCRIMINATOR,
|
||||||
|
));
|
||||||
let config_path = download_resource(&config_resource)?;
|
let config_path = download_resource(&config_resource)?;
|
||||||
let vocab_path = download_resource(&vocab_resource)?;
|
let vocab_path = download_resource(&vocab_resource)?;
|
||||||
let weights_path = download_resource(&weights_resource)?;
|
let weights_path = download_resource(&weights_resource)?;
|
||||||
|
|
||||||
// Set-up masked LM model
|
// Set-up masked LM model
|
||||||
let device = Device::Cpu;
|
let device = Device::Cpu;
|
||||||
let mut vs = nn::VarStore::new(device);
|
let mut vs = nn::VarStore::new(device);
|
||||||
let tokenizer: BertTokenizer = BertTokenizer::from_file(vocab_path.to_str().unwrap(), true);
|
let tokenizer: BertTokenizer = BertTokenizer::from_file(vocab_path.to_str().unwrap(), true);
|
||||||
@ -85,35 +104,34 @@ fn electra_discriminator() -> failure::Fallible<()> {
|
|||||||
let electra_model = ElectraDiscriminator::new(&vs.root(), &config);
|
let electra_model = ElectraDiscriminator::new(&vs.root(), &config);
|
||||||
vs.load(weights_path)?;
|
vs.load(weights_path)?;
|
||||||
|
|
||||||
// Define input
|
// Define input
|
||||||
let input = ["One Two Three Ten Five Six Seven Eight"];
|
let input = ["One Two Three Ten Five Six Seven Eight"];
|
||||||
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
let tokenized_input =
|
||||||
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
|
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
||||||
let encoded_input = tokenized_input.
|
let max_len = tokenized_input
|
||||||
iter().
|
.iter()
|
||||||
map(|input| input.token_ids.clone()).
|
.map(|input| input.token_ids.len())
|
||||||
map(|mut input| {
|
.max()
|
||||||
|
.unwrap();
|
||||||
|
let encoded_input = tokenized_input
|
||||||
|
.iter()
|
||||||
|
.map(|input| input.token_ids.clone())
|
||||||
|
.map(|mut input| {
|
||||||
input.extend(vec![0; max_len - input.len()]);
|
input.extend(vec![0; max_len - input.len()]);
|
||||||
input
|
input
|
||||||
}).
|
})
|
||||||
map(|input|
|
.map(|input| Tensor::of_slice(&(input)))
|
||||||
Tensor::of_slice(&(input))).
|
.collect::<Vec<_>>();
|
||||||
collect::<Vec<_>>();
|
|
||||||
let input_tensor = Tensor::stack(encoded_input.as_slice(), 0).to(device);
|
let input_tensor = Tensor::stack(encoded_input.as_slice(), 0).to(device);
|
||||||
|
|
||||||
// Forward pass
|
// Forward pass
|
||||||
let (output, _, _) = no_grad(|| {
|
let (output, _, _) =
|
||||||
electra_model
|
no_grad(|| electra_model.forward_t(Some(input_tensor), None, None, None, None, false));
|
||||||
.forward_t(Some(input_tensor),
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
false)
|
|
||||||
});
|
|
||||||
|
|
||||||
// Validate model predictions
|
// Validate model predictions
|
||||||
let expected_probabilities = vec!(0.0101, 0.0030, 0.0010, 0.0018, 0.9489, 0.0067, 0.0026, 0.0017, 0.0311, 0.0101);
|
let expected_probabilities = vec![
|
||||||
|
0.0101, 0.0030, 0.0010, 0.0018, 0.9489, 0.0067, 0.0026, 0.0017, 0.0311, 0.0101,
|
||||||
|
];
|
||||||
let probabilities = output.iter::<f64>().unwrap().collect::<Vec<f64>>();
|
let probabilities = output.iter::<f64>().unwrap().collect::<Vec<f64>>();
|
||||||
|
|
||||||
assert_eq!(output.size(), &[10]);
|
assert_eq!(output.size(), &[10]);
|
||||||
@ -122,4 +140,4 @@ fn electra_discriminator() -> failure::Fallible<()> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
243
tests/gpt2.rs
243
tests/gpt2.rs
@ -1,87 +1,115 @@
|
|||||||
use tch::{Device, nn, Tensor};
|
use rust_bert::gpt2::{
|
||||||
use rust_tokenizers::{Gpt2Tokenizer, TruncationStrategy, Tokenizer};
|
GPT2LMHeadModel, Gpt2Config, Gpt2ConfigResources, Gpt2MergesResources, Gpt2ModelResources,
|
||||||
|
Gpt2VocabResources,
|
||||||
|
};
|
||||||
|
use rust_bert::pipelines::generation::{
|
||||||
|
Cache, GPT2Generator, GenerateConfig, LMHeadModel, LanguageGenerator,
|
||||||
|
};
|
||||||
|
use rust_bert::resources::{download_resource, RemoteResource, Resource};
|
||||||
use rust_bert::Config;
|
use rust_bert::Config;
|
||||||
use rust_bert::pipelines::generation::{GPT2Generator, LanguageGenerator, GenerateConfig, LMHeadModel, Cache};
|
use rust_tokenizers::{Gpt2Tokenizer, Tokenizer, TruncationStrategy};
|
||||||
use rust_bert::gpt2::{Gpt2Config, GPT2LMHeadModel, Gpt2ConfigResources, Gpt2MergesResources, Gpt2VocabResources, Gpt2ModelResources};
|
use tch::{nn, Device, Tensor};
|
||||||
use rust_bert::resources::{RemoteResource, Resource, download_resource};
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn gpt2_lm_model() -> failure::Fallible<()> {
|
fn gpt2_lm_model() -> failure::Fallible<()> {
|
||||||
// Resources paths
|
// Resources paths
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
|
let config_resource =
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
|
Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
|
||||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
|
let vocab_resource =
|
||||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
|
Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
|
||||||
|
let merges_resource =
|
||||||
|
Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
|
||||||
|
let weights_resource =
|
||||||
|
Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
|
||||||
let config_path = download_resource(&config_resource)?;
|
let config_path = download_resource(&config_resource)?;
|
||||||
let vocab_path = download_resource(&vocab_resource)?;
|
let vocab_path = download_resource(&vocab_resource)?;
|
||||||
let merges_path = download_resource(&merges_resource)?;
|
let merges_path = download_resource(&merges_resource)?;
|
||||||
let weights_path = download_resource(&weights_resource)?;
|
let weights_path = download_resource(&weights_resource)?;
|
||||||
|
|
||||||
// Set-up masked LM model
|
// Set-up masked LM model
|
||||||
let device = Device::Cpu;
|
let device = Device::Cpu;
|
||||||
let mut vs = nn::VarStore::new(device);
|
let mut vs = nn::VarStore::new(device);
|
||||||
let tokenizer: Gpt2Tokenizer = Gpt2Tokenizer::from_file(vocab_path.to_str().unwrap(), merges_path.to_str().unwrap(), false);
|
let tokenizer: Gpt2Tokenizer = Gpt2Tokenizer::from_file(
|
||||||
|
vocab_path.to_str().unwrap(),
|
||||||
|
merges_path.to_str().unwrap(),
|
||||||
|
false,
|
||||||
|
);
|
||||||
let config = Gpt2Config::from_file(config_path);
|
let config = Gpt2Config::from_file(config_path);
|
||||||
let gpt2_model = GPT2LMHeadModel::new(&vs.root(), &config);
|
let gpt2_model = GPT2LMHeadModel::new(&vs.root(), &config);
|
||||||
vs.load(weights_path)?;
|
vs.load(weights_path)?;
|
||||||
|
|
||||||
// Define input
|
// Define input
|
||||||
let input = ["One two three four"];
|
let input = ["One two three four"];
|
||||||
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
let tokenized_input =
|
||||||
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
|
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
||||||
let tokenized_input = tokenized_input.
|
let max_len = tokenized_input
|
||||||
iter().
|
.iter()
|
||||||
map(|input| input.token_ids.clone()).
|
.map(|input| input.token_ids.len())
|
||||||
map(|mut input| {
|
.max()
|
||||||
|
.unwrap();
|
||||||
|
let tokenized_input = tokenized_input
|
||||||
|
.iter()
|
||||||
|
.map(|input| input.token_ids.clone())
|
||||||
|
.map(|mut input| {
|
||||||
input.extend(vec![0; max_len - input.len()]);
|
input.extend(vec![0; max_len - input.len()]);
|
||||||
input
|
input
|
||||||
}).
|
})
|
||||||
map(|input|
|
.map(|input| Tensor::of_slice(&(input)))
|
||||||
Tensor::of_slice(&(input))).
|
.collect::<Vec<_>>();
|
||||||
collect::<Vec<_>>();
|
|
||||||
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
|
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
|
||||||
|
|
||||||
// Forward pass
|
// Forward pass
|
||||||
let (output, _, past, _, _) = gpt2_model.forward_t(
|
let (output, _, past, _, _) = gpt2_model
|
||||||
&Some(input_tensor),
|
.forward_t(
|
||||||
Cache::None,
|
&Some(input_tensor),
|
||||||
&None,
|
Cache::None,
|
||||||
&None,
|
&None,
|
||||||
&None,
|
&None,
|
||||||
&None,
|
&None,
|
||||||
None,
|
&None,
|
||||||
&None,
|
None,
|
||||||
false).unwrap();
|
&None,
|
||||||
|
false,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
let next_word_id = output.get(0).get(-1).argmax(-1, true).int64_value(&[0]);
|
let next_word_id = output.get(0).get(-1).argmax(-1, true).int64_value(&[0]);
|
||||||
let next_word = tokenizer.decode(vec!(next_word_id), true, true);
|
let next_word = tokenizer.decode(vec![next_word_id], true, true);
|
||||||
|
|
||||||
assert_eq!(output.size(), vec!(1, 4, 50257));
|
assert_eq!(output.size(), vec!(1, 4, 50257));
|
||||||
match past {
|
match past {
|
||||||
Cache::GPT2Cache(past) => {
|
Cache::GPT2Cache(past) => {
|
||||||
assert!(past.is_some());
|
assert!(past.is_some());
|
||||||
assert_eq!(past.as_ref().unwrap().len(), config.n_layer as usize);
|
assert_eq!(past.as_ref().unwrap().len(), config.n_layer as usize);
|
||||||
assert_eq!(past.as_ref().unwrap()[0].size(), vec!(2, 1, config.n_head, 4, 64));
|
assert_eq!(
|
||||||
|
past.as_ref().unwrap()[0].size(),
|
||||||
|
vec!(2, 1, config.n_head, 4, 64)
|
||||||
|
);
|
||||||
}
|
}
|
||||||
_ => panic!("Wrong cache returned for GPT2")
|
_ => panic!("Wrong cache returned for GPT2"),
|
||||||
}
|
}
|
||||||
assert!((output.double_value(&[0, output.size()[1] - 1, next_word_id]) - (-69.4948)).abs() < 1e-4);
|
assert!(
|
||||||
|
(output.double_value(&[0, output.size()[1] - 1, next_word_id]) - (-69.4948)).abs() < 1e-4
|
||||||
|
);
|
||||||
assert_eq!(next_word_id, 1936i64);
|
assert_eq!(next_word_id, 1936i64);
|
||||||
assert_eq!(next_word, String::from(" five"));
|
assert_eq!(next_word, String::from(" five"));
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn gpt2_generation_greedy() -> failure::Fallible<()> {
|
fn gpt2_generation_greedy() -> failure::Fallible<()> {
|
||||||
// Resources definition
|
// Resources definition
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
|
let config_resource =
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
|
Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
|
||||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
|
let vocab_resource =
|
||||||
let model_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
|
Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
|
||||||
|
let merges_resource =
|
||||||
|
Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
|
||||||
|
let model_resource =
|
||||||
|
Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
|
||||||
|
|
||||||
// Set-up masked LM model
|
// Set-up masked LM model
|
||||||
let generate_config = GenerateConfig {
|
let generate_config = GenerateConfig {
|
||||||
model_resource,
|
model_resource,
|
||||||
config_resource,
|
config_resource,
|
||||||
@ -97,7 +125,7 @@ fn gpt2_generation_greedy() -> failure::Fallible<()> {
|
|||||||
let model = GPT2Generator::new(generate_config)?;
|
let model = GPT2Generator::new(generate_config)?;
|
||||||
|
|
||||||
let input_context = "The cat";
|
let input_context = "The cat";
|
||||||
let output = model.generate(Some(vec!(input_context)), None);
|
let output = model.generate(Some(vec![input_context]), None);
|
||||||
|
|
||||||
assert_eq!(output.len(), 1);
|
assert_eq!(output.len(), 1);
|
||||||
assert_eq!(output[0], "The cat was found in a field near the town of Keflavik, about 30 miles (48 kilometers) south-east of Moscow.\n\n\n");
|
assert_eq!(output[0], "The cat was found in a field near the town of Keflavik, about 30 miles (48 kilometers) south-east of Moscow.\n\n\n");
|
||||||
@ -108,12 +136,16 @@ fn gpt2_generation_greedy() -> failure::Fallible<()> {
|
|||||||
#[test]
|
#[test]
|
||||||
fn gpt2_generation_beam_search() -> failure::Fallible<()> {
|
fn gpt2_generation_beam_search() -> failure::Fallible<()> {
|
||||||
// Resources definition
|
// Resources definition
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
|
let config_resource =
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
|
Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
|
||||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
|
let vocab_resource =
|
||||||
let model_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
|
Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
|
||||||
|
let merges_resource =
|
||||||
|
Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
|
||||||
|
let model_resource =
|
||||||
|
Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
|
||||||
|
|
||||||
// Set-up masked LM model
|
// Set-up masked LM model
|
||||||
let generate_config = GenerateConfig {
|
let generate_config = GenerateConfig {
|
||||||
model_resource,
|
model_resource,
|
||||||
config_resource,
|
config_resource,
|
||||||
@ -129,12 +161,21 @@ fn gpt2_generation_beam_search() -> failure::Fallible<()> {
|
|||||||
let model = GPT2Generator::new(generate_config)?;
|
let model = GPT2Generator::new(generate_config)?;
|
||||||
|
|
||||||
let input_context = "The dog";
|
let input_context = "The dog";
|
||||||
let output = model.generate(Some(vec!(input_context)), None);
|
let output = model.generate(Some(vec![input_context]), None);
|
||||||
|
|
||||||
assert_eq!(output.len(), 3);
|
assert_eq!(output.len(), 3);
|
||||||
assert_eq!(output[0], "The dog was found in the backyard of a home in the 6200 block of South Main Street.");
|
assert_eq!(
|
||||||
assert_eq!(output[1], "The dog was found in the backyard of a home in the 6500 block of South Main Street.");
|
output[0],
|
||||||
assert_eq!(output[2], "The dog was found in the backyard of a home in the 6200 block of South Main Street,");
|
"The dog was found in the backyard of a home in the 6200 block of South Main Street."
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
output[1],
|
||||||
|
"The dog was found in the backyard of a home in the 6500 block of South Main Street."
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
output[2],
|
||||||
|
"The dog was found in the backyard of a home in the 6200 block of South Main Street,"
|
||||||
|
);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@ -142,12 +183,16 @@ fn gpt2_generation_beam_search() -> failure::Fallible<()> {
|
|||||||
#[test]
|
#[test]
|
||||||
fn gpt2_generation_beam_search_multiple_prompts_without_padding() -> failure::Fallible<()> {
|
fn gpt2_generation_beam_search_multiple_prompts_without_padding() -> failure::Fallible<()> {
|
||||||
// Resources definition
|
// Resources definition
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
|
let config_resource =
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
|
Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
|
||||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
|
let vocab_resource =
|
||||||
let model_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
|
Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
|
||||||
|
let merges_resource =
|
||||||
|
Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
|
||||||
|
let model_resource =
|
||||||
|
Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
|
||||||
|
|
||||||
// Set-up masked LM model
|
// Set-up masked LM model
|
||||||
let generate_config = GenerateConfig {
|
let generate_config = GenerateConfig {
|
||||||
model_resource,
|
model_resource,
|
||||||
config_resource,
|
config_resource,
|
||||||
@ -164,15 +209,33 @@ fn gpt2_generation_beam_search_multiple_prompts_without_padding() -> failure::Fa
|
|||||||
|
|
||||||
let input_context_1 = "The dog";
|
let input_context_1 = "The dog";
|
||||||
let input_context_2 = "The cat";
|
let input_context_2 = "The cat";
|
||||||
let output = model.generate(Some(vec!(input_context_1, input_context_2)), None);
|
let output = model.generate(Some(vec![input_context_1, input_context_2]), None);
|
||||||
|
|
||||||
assert_eq!(output.len(), 6);
|
assert_eq!(output.len(), 6);
|
||||||
assert_eq!(output[0], "The dog was found in the backyard of a home in the 6200 block of South Main Street.");
|
assert_eq!(
|
||||||
assert_eq!(output[1], "The dog was found in the backyard of a home in the 6500 block of South Main Street.");
|
output[0],
|
||||||
assert_eq!(output[2], "The dog was found in the backyard of a home in the 6200 block of South Main Street,");
|
"The dog was found in the backyard of a home in the 6200 block of South Main Street."
|
||||||
assert_eq!(output[3], "The cat-and-mouse game.\n\n\"I think it\'s going to be interesting to");
|
);
|
||||||
assert_eq!(output[4], "The cat-and-mouse game.\n\n\"I think it\'s going to be a very");
|
assert_eq!(
|
||||||
assert_eq!(output[5], "The cat-and-mouse game.\n\n\"I think it\'s going to be very interesting");
|
output[1],
|
||||||
|
"The dog was found in the backyard of a home in the 6500 block of South Main Street."
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
output[2],
|
||||||
|
"The dog was found in the backyard of a home in the 6200 block of South Main Street,"
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
output[3],
|
||||||
|
"The cat-and-mouse game.\n\n\"I think it\'s going to be interesting to"
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
output[4],
|
||||||
|
"The cat-and-mouse game.\n\n\"I think it\'s going to be a very"
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
output[5],
|
||||||
|
"The cat-and-mouse game.\n\n\"I think it\'s going to be very interesting"
|
||||||
|
);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@ -180,12 +243,16 @@ fn gpt2_generation_beam_search_multiple_prompts_without_padding() -> failure::Fa
|
|||||||
#[test]
|
#[test]
|
||||||
fn gpt2_generation_beam_search_multiple_prompts_with_padding() -> failure::Fallible<()> {
|
fn gpt2_generation_beam_search_multiple_prompts_with_padding() -> failure::Fallible<()> {
|
||||||
// Resources definition
|
// Resources definition
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
|
let config_resource =
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
|
Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2));
|
||||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
|
let vocab_resource =
|
||||||
let model_resource = Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
|
Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2));
|
||||||
|
let merges_resource =
|
||||||
|
Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2));
|
||||||
|
let model_resource =
|
||||||
|
Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2));
|
||||||
|
|
||||||
// Set-up masked LM model
|
// Set-up masked LM model
|
||||||
let generate_config = GenerateConfig {
|
let generate_config = GenerateConfig {
|
||||||
model_resource,
|
model_resource,
|
||||||
config_resource,
|
config_resource,
|
||||||
@ -202,15 +269,33 @@ fn gpt2_generation_beam_search_multiple_prompts_with_padding() -> failure::Falli
|
|||||||
|
|
||||||
let input_context_1 = "The dog";
|
let input_context_1 = "The dog";
|
||||||
let input_context_2 = "The cat was";
|
let input_context_2 = "The cat was";
|
||||||
let output = model.generate(Some(vec!(input_context_1, input_context_2)), None);
|
let output = model.generate(Some(vec![input_context_1, input_context_2]), None);
|
||||||
|
|
||||||
assert_eq!(output.len(), 6);
|
assert_eq!(output.len(), 6);
|
||||||
assert_eq!(output[0], "The dog was found dead on the side of the road in the middle of the night.\n");
|
assert_eq!(
|
||||||
assert_eq!(output[1], "The dog was found dead on the side of the road in the middle of the night on Sunday");
|
output[0],
|
||||||
assert_eq!(output[2], "The dog was found dead on the side of the road in the middle of the night on Saturday");
|
"The dog was found dead on the side of the road in the middle of the night.\n"
|
||||||
assert_eq!(output[3], "The cat was taken to a local hospital, where it was treated and released.\n\nPolice said");
|
);
|
||||||
assert_eq!(output[4], "The cat was taken to a local hospital, where it was treated and released.\n\n\"It");
|
assert_eq!(
|
||||||
assert_eq!(output[5], "The cat was taken to a local hospital, where it was treated and released.\n\n\"We");
|
output[1],
|
||||||
|
"The dog was found dead on the side of the road in the middle of the night on Sunday"
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
output[2],
|
||||||
|
"The dog was found dead on the side of the road in the middle of the night on Saturday"
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
output[3],
|
||||||
|
"The cat was taken to a local hospital, where it was treated and released.\n\nPolice said"
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
output[4],
|
||||||
|
"The cat was taken to a local hospital, where it was treated and released.\n\n\"It"
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
output[5],
|
||||||
|
"The cat was taken to a local hospital, where it was treated and released.\n\n\"We"
|
||||||
|
);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -1,12 +1,11 @@
|
|||||||
use rust_bert::pipelines::translation::{TranslationConfig, Language, TranslationModel};
|
use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
|
||||||
use tch::Device;
|
use tch::Device;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
#[cfg_attr(not(feature = "all-tests"), ignore)]
|
#[cfg_attr(not(feature = "all-tests"), ignore)]
|
||||||
fn test_translation() -> failure::Fallible<()> {
|
fn test_translation() -> failure::Fallible<()> {
|
||||||
|
// Set-up translation model
|
||||||
// Set-up translation model
|
let translation_config = TranslationConfig::new(Language::EnglishToFrench, Device::Cpu);
|
||||||
let translation_config = TranslationConfig::new(Language::EnglishToFrench, Device::Cpu);
|
|
||||||
let model = TranslationModel::new(translation_config)?;
|
let model = TranslationModel::new(translation_config)?;
|
||||||
|
|
||||||
let input_context_1 = "The quick brown fox jumps over the lazy dog";
|
let input_context_1 = "The quick brown fox jumps over the lazy dog";
|
||||||
@ -15,8 +14,11 @@ fn test_translation() -> failure::Fallible<()> {
|
|||||||
let output = model.translate(&[input_context_1, input_context_2]);
|
let output = model.translate(&[input_context_1, input_context_2]);
|
||||||
|
|
||||||
assert_eq!(output.len(), 2);
|
assert_eq!(output.len(), 2);
|
||||||
assert_eq!(output[0], " Le rapide renard brun saute sur le chien paresseux");
|
assert_eq!(
|
||||||
|
output[0],
|
||||||
|
" Le rapide renard brun saute sur le chien paresseux"
|
||||||
|
);
|
||||||
assert_eq!(output[1], " Le chien ne s'est pas réveillé.");
|
assert_eq!(output[1], " Le chien ne s'est pas réveillé.");
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -1,64 +1,90 @@
|
|||||||
use tch::{Device, nn, Tensor};
|
|
||||||
use rust_tokenizers::{TruncationStrategy, Tokenizer, OpenAiGptTokenizer};
|
|
||||||
use rust_bert::Config;
|
|
||||||
use rust_bert::pipelines::generation::{OpenAIGenerator, LanguageGenerator, GenerateConfig, LMHeadModel, Cache};
|
|
||||||
use rust_bert::gpt2::Gpt2Config;
|
use rust_bert::gpt2::Gpt2Config;
|
||||||
use rust_bert::openai_gpt::{OpenAIGPTLMHeadModel, OpenAiGptConfigResources, OpenAiGptVocabResources, OpenAiGptMergesResources, OpenAiGptModelResources};
|
use rust_bert::openai_gpt::{
|
||||||
use rust_bert::resources::{RemoteResource, Resource, download_resource};
|
OpenAIGPTLMHeadModel, OpenAiGptConfigResources, OpenAiGptMergesResources,
|
||||||
|
OpenAiGptModelResources, OpenAiGptVocabResources,
|
||||||
|
};
|
||||||
|
use rust_bert::pipelines::generation::{
|
||||||
|
Cache, GenerateConfig, LMHeadModel, LanguageGenerator, OpenAIGenerator,
|
||||||
|
};
|
||||||
|
use rust_bert::resources::{download_resource, RemoteResource, Resource};
|
||||||
|
use rust_bert::Config;
|
||||||
|
use rust_tokenizers::{OpenAiGptTokenizer, Tokenizer, TruncationStrategy};
|
||||||
|
use tch::{nn, Device, Tensor};
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn openai_gpt_lm_model() -> failure::Fallible<()> {
|
fn openai_gpt_lm_model() -> failure::Fallible<()> {
|
||||||
// Resources paths
|
// Resources paths
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptConfigResources::GPT));
|
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptVocabResources::GPT));
|
OpenAiGptConfigResources::GPT,
|
||||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptMergesResources::GPT));
|
));
|
||||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptModelResources::GPT));
|
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
|
OpenAiGptVocabResources::GPT,
|
||||||
|
));
|
||||||
|
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
|
OpenAiGptMergesResources::GPT,
|
||||||
|
));
|
||||||
|
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
|
OpenAiGptModelResources::GPT,
|
||||||
|
));
|
||||||
let config_path = download_resource(&config_resource)?;
|
let config_path = download_resource(&config_resource)?;
|
||||||
let vocab_path = download_resource(&vocab_resource)?;
|
let vocab_path = download_resource(&vocab_resource)?;
|
||||||
let merges_path = download_resource(&merges_resource)?;
|
let merges_path = download_resource(&merges_resource)?;
|
||||||
let weights_path = download_resource(&weights_resource)?;
|
let weights_path = download_resource(&weights_resource)?;
|
||||||
|
|
||||||
// Set-up masked LM model
|
// Set-up masked LM model
|
||||||
let device = Device::Cpu;
|
let device = Device::Cpu;
|
||||||
let mut vs = nn::VarStore::new(device);
|
let mut vs = nn::VarStore::new(device);
|
||||||
let tokenizer = OpenAiGptTokenizer::from_file(vocab_path.to_str().unwrap(), merges_path.to_str().unwrap(), true);
|
let tokenizer = OpenAiGptTokenizer::from_file(
|
||||||
|
vocab_path.to_str().unwrap(),
|
||||||
|
merges_path.to_str().unwrap(),
|
||||||
|
true,
|
||||||
|
);
|
||||||
let config = Gpt2Config::from_file(config_path);
|
let config = Gpt2Config::from_file(config_path);
|
||||||
let openai_gpt = OpenAIGPTLMHeadModel::new(&vs.root(), &config);
|
let openai_gpt = OpenAIGPTLMHeadModel::new(&vs.root(), &config);
|
||||||
vs.load(weights_path)?;
|
vs.load(weights_path)?;
|
||||||
|
|
||||||
// Define input
|
// Define input
|
||||||
let input = ["Wondering what the next word will"];
|
let input = ["Wondering what the next word will"];
|
||||||
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
let tokenized_input =
|
||||||
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
|
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
||||||
let tokenized_input = tokenized_input.
|
let max_len = tokenized_input
|
||||||
iter().
|
.iter()
|
||||||
map(|input| input.token_ids.clone()).
|
.map(|input| input.token_ids.len())
|
||||||
map(|mut input| {
|
.max()
|
||||||
|
.unwrap();
|
||||||
|
let tokenized_input = tokenized_input
|
||||||
|
.iter()
|
||||||
|
.map(|input| input.token_ids.clone())
|
||||||
|
.map(|mut input| {
|
||||||
input.extend(vec![0; max_len - input.len()]);
|
input.extend(vec![0; max_len - input.len()]);
|
||||||
input
|
input
|
||||||
}).
|
})
|
||||||
map(|input|
|
.map(|input| Tensor::of_slice(&(input)))
|
||||||
Tensor::of_slice(&(input))).
|
.collect::<Vec<_>>();
|
||||||
collect::<Vec<_>>();
|
|
||||||
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
|
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
|
||||||
|
|
||||||
// Forward pass
|
// Forward pass
|
||||||
let (output, _, _, _, _) = openai_gpt.forward_t(
|
let (output, _, _, _, _) = openai_gpt
|
||||||
&Some(input_tensor),
|
.forward_t(
|
||||||
Cache::None,
|
&Some(input_tensor),
|
||||||
&None,
|
Cache::None,
|
||||||
&None,
|
&None,
|
||||||
&None,
|
&None,
|
||||||
&None,
|
&None,
|
||||||
None,
|
&None,
|
||||||
&None,
|
None,
|
||||||
false).unwrap();
|
&None,
|
||||||
|
false,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
let next_word_id = output.get(0).get(-1).argmax(-1, true).int64_value(&[0]);
|
let next_word_id = output.get(0).get(-1).argmax(-1, true).int64_value(&[0]);
|
||||||
let next_word = tokenizer.decode(vec!(next_word_id), true, true);
|
let next_word = tokenizer.decode(vec![next_word_id], true, true);
|
||||||
|
|
||||||
assert_eq!(output.size(), vec!(1, 6, 40478));
|
assert_eq!(output.size(), vec!(1, 6, 40478));
|
||||||
assert!((output.double_value(&[0, output.size()[1] - 1, next_word_id]) - (9.1056)).abs() < 1e-4);
|
assert!(
|
||||||
|
(output.double_value(&[0, output.size()[1] - 1, next_word_id]) - (9.1056)).abs() < 1e-4
|
||||||
|
);
|
||||||
assert_eq!(next_word_id, 580i64);
|
assert_eq!(next_word_id, 580i64);
|
||||||
assert_eq!(next_word, String::from("be"));
|
assert_eq!(next_word, String::from("be"));
|
||||||
|
|
||||||
@ -68,12 +94,20 @@ fn openai_gpt_lm_model() -> failure::Fallible<()> {
|
|||||||
#[test]
|
#[test]
|
||||||
fn openai_gpt_generation_greedy() -> failure::Fallible<()> {
|
fn openai_gpt_generation_greedy() -> failure::Fallible<()> {
|
||||||
// Resources paths
|
// Resources paths
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptConfigResources::GPT));
|
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptVocabResources::GPT));
|
OpenAiGptConfigResources::GPT,
|
||||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptMergesResources::GPT));
|
));
|
||||||
let model_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptModelResources::GPT));
|
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
|
OpenAiGptVocabResources::GPT,
|
||||||
|
));
|
||||||
|
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
|
OpenAiGptMergesResources::GPT,
|
||||||
|
));
|
||||||
|
let model_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
|
OpenAiGptModelResources::GPT,
|
||||||
|
));
|
||||||
|
|
||||||
// Set-up masked LM model
|
// Set-up masked LM model
|
||||||
let generate_config = GenerateConfig {
|
let generate_config = GenerateConfig {
|
||||||
model_resource,
|
model_resource,
|
||||||
config_resource,
|
config_resource,
|
||||||
@ -90,7 +124,7 @@ fn openai_gpt_generation_greedy() -> failure::Fallible<()> {
|
|||||||
let model = OpenAIGenerator::new(generate_config)?;
|
let model = OpenAIGenerator::new(generate_config)?;
|
||||||
|
|
||||||
let input_context = "It was an intense machine dialogue. ";
|
let input_context = "It was an intense machine dialogue. ";
|
||||||
let output = model.generate(Some(vec!(input_context)), None);
|
let output = model.generate(Some(vec![input_context]), None);
|
||||||
|
|
||||||
assert_eq!(output.len(), 1);
|
assert_eq!(output.len(), 1);
|
||||||
assert_eq!(output[0], "it was an intense machine dialogue. \n \" i\'m sorry, but we have to go now! the police are on their way and they\'re going after you - or at least that\'s what my");
|
assert_eq!(output[0], "it was an intense machine dialogue. \n \" i\'m sorry, but we have to go now! the police are on their way and they\'re going after you - or at least that\'s what my");
|
||||||
@ -101,12 +135,20 @@ fn openai_gpt_generation_greedy() -> failure::Fallible<()> {
|
|||||||
#[test]
|
#[test]
|
||||||
fn openai_gpt_generation_beam_search() -> failure::Fallible<()> {
|
fn openai_gpt_generation_beam_search() -> failure::Fallible<()> {
|
||||||
// Resources paths
|
// Resources paths
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptConfigResources::GPT));
|
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptVocabResources::GPT));
|
OpenAiGptConfigResources::GPT,
|
||||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptMergesResources::GPT));
|
));
|
||||||
let model_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptModelResources::GPT));
|
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
|
OpenAiGptVocabResources::GPT,
|
||||||
|
));
|
||||||
|
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
|
OpenAiGptMergesResources::GPT,
|
||||||
|
));
|
||||||
|
let model_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
|
OpenAiGptModelResources::GPT,
|
||||||
|
));
|
||||||
|
|
||||||
// Set-up masked LM model
|
// Set-up masked LM model
|
||||||
let generate_config = GenerateConfig {
|
let generate_config = GenerateConfig {
|
||||||
model_resource,
|
model_resource,
|
||||||
config_resource,
|
config_resource,
|
||||||
@ -122,12 +164,21 @@ fn openai_gpt_generation_beam_search() -> failure::Fallible<()> {
|
|||||||
let model = OpenAIGenerator::new(generate_config)?;
|
let model = OpenAIGenerator::new(generate_config)?;
|
||||||
|
|
||||||
let input_context = "The dog is";
|
let input_context = "The dog is";
|
||||||
let output = model.generate(Some(vec!(input_context)), None);
|
let output = model.generate(Some(vec![input_context]), None);
|
||||||
|
|
||||||
assert_eq!(output.len(), 3);
|
assert_eq!(output.len(), 3);
|
||||||
assert_eq!(output[0], "the dog isn\'t going anywhere. i\'m going to take care of him. i \'ll be right");
|
assert_eq!(
|
||||||
assert_eq!(output[1], "the dog isn\'t going anywhere. i\'m going to take care of him. i \'ll be back");
|
output[0],
|
||||||
assert_eq!(output[2], "the dog isn\'t going anywhere. i\'m going to take care of him. \" \n \" i");
|
"the dog isn\'t going anywhere. i\'m going to take care of him. i \'ll be right"
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
output[1],
|
||||||
|
"the dog isn\'t going anywhere. i\'m going to take care of him. i \'ll be back"
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
output[2],
|
||||||
|
"the dog isn\'t going anywhere. i\'m going to take care of him. \" \n \" i"
|
||||||
|
);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@ -135,12 +186,20 @@ fn openai_gpt_generation_beam_search() -> failure::Fallible<()> {
|
|||||||
#[test]
|
#[test]
|
||||||
fn openai_gpt_generation_beam_search_multiple_prompts_without_padding() -> failure::Fallible<()> {
|
fn openai_gpt_generation_beam_search_multiple_prompts_without_padding() -> failure::Fallible<()> {
|
||||||
// Resources paths
|
// Resources paths
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptConfigResources::GPT));
|
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptVocabResources::GPT));
|
OpenAiGptConfigResources::GPT,
|
||||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptMergesResources::GPT));
|
));
|
||||||
let model_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptModelResources::GPT));
|
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
|
OpenAiGptVocabResources::GPT,
|
||||||
|
));
|
||||||
|
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
|
OpenAiGptMergesResources::GPT,
|
||||||
|
));
|
||||||
|
let model_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
|
OpenAiGptModelResources::GPT,
|
||||||
|
));
|
||||||
|
|
||||||
// Set-up masked LM model
|
// Set-up masked LM model
|
||||||
let generate_config = GenerateConfig {
|
let generate_config = GenerateConfig {
|
||||||
model_resource,
|
model_resource,
|
||||||
config_resource,
|
config_resource,
|
||||||
@ -157,18 +216,36 @@ fn openai_gpt_generation_beam_search_multiple_prompts_without_padding() -> failu
|
|||||||
|
|
||||||
let input_context_1 = "The dog is";
|
let input_context_1 = "The dog is";
|
||||||
let input_context_2 = "The cat";
|
let input_context_2 = "The cat";
|
||||||
let output = model.generate(Some(vec!(input_context_1, input_context_2)), None);
|
let output = model.generate(Some(vec![input_context_1, input_context_2]), None);
|
||||||
|
|
||||||
assert_eq!(output.len(), 6);
|
assert_eq!(output.len(), 6);
|
||||||
|
|
||||||
// Unpadded sequence (generation for `The dog is`) is identical to the
|
// Unpadded sequence (generation for `The dog is`) is identical to the
|
||||||
assert_eq!(output[0], "the dog isn\'t going anywhere. i\'m going to take care of him. i \'ll be right");
|
assert_eq!(
|
||||||
assert_eq!(output[1], "the dog isn\'t going anywhere. i\'m going to take care of him. i \'ll be back");
|
output[0],
|
||||||
assert_eq!(output[2], "the dog isn\'t going anywhere. i\'m going to take care of him. \" \n \" i");
|
"the dog isn\'t going anywhere. i\'m going to take care of him. i \'ll be right"
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
output[1],
|
||||||
|
"the dog isn\'t going anywhere. i\'m going to take care of him. i \'ll be back"
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
output[2],
|
||||||
|
"the dog isn\'t going anywhere. i\'m going to take care of him. \" \n \" i"
|
||||||
|
);
|
||||||
|
|
||||||
assert_eq!(output[3], "the cat. \" \n \" i don\'t know what you\'re talking about. i don\'t");
|
assert_eq!(
|
||||||
assert_eq!(output[4], "the cat. \" \n \" i don\'t know what you\'re talking about. i\'m not");
|
output[3],
|
||||||
assert_eq!(output[5], "the cat. \" \n \" i don\'t know what you\'re talking about. i do know");
|
"the cat. \" \n \" i don\'t know what you\'re talking about. i don\'t"
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
output[4],
|
||||||
|
"the cat. \" \n \" i don\'t know what you\'re talking about. i\'m not"
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
output[5],
|
||||||
|
"the cat. \" \n \" i don\'t know what you\'re talking about. i do know"
|
||||||
|
);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@ -176,12 +253,20 @@ fn openai_gpt_generation_beam_search_multiple_prompts_without_padding() -> failu
|
|||||||
#[test]
|
#[test]
|
||||||
fn openai_gpt_generation_beam_search_multiple_prompts_with_padding() -> failure::Fallible<()> {
|
fn openai_gpt_generation_beam_search_multiple_prompts_with_padding() -> failure::Fallible<()> {
|
||||||
// Resources paths
|
// Resources paths
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptConfigResources::GPT));
|
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptVocabResources::GPT));
|
OpenAiGptConfigResources::GPT,
|
||||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptMergesResources::GPT));
|
));
|
||||||
let model_resource = Resource::Remote(RemoteResource::from_pretrained(OpenAiGptModelResources::GPT));
|
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
|
OpenAiGptVocabResources::GPT,
|
||||||
|
));
|
||||||
|
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
|
OpenAiGptMergesResources::GPT,
|
||||||
|
));
|
||||||
|
let model_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
|
OpenAiGptModelResources::GPT,
|
||||||
|
));
|
||||||
|
|
||||||
// Set-up masked LM model
|
// Set-up masked LM model
|
||||||
let generate_config = GenerateConfig {
|
let generate_config = GenerateConfig {
|
||||||
model_resource,
|
model_resource,
|
||||||
config_resource,
|
config_resource,
|
||||||
@ -198,16 +283,34 @@ fn openai_gpt_generation_beam_search_multiple_prompts_with_padding() -> failure:
|
|||||||
|
|
||||||
let input_context_1 = "The dog is";
|
let input_context_1 = "The dog is";
|
||||||
let input_context_2 = "The cat was in";
|
let input_context_2 = "The cat was in";
|
||||||
let output = model.generate(Some(vec!(input_context_1, input_context_2)), None);
|
let output = model.generate(Some(vec![input_context_1, input_context_2]), None);
|
||||||
|
|
||||||
assert_eq!(output.len(), 6);
|
assert_eq!(output.len(), 6);
|
||||||
// Left padding impacts the generated sentences output
|
// Left padding impacts the generated sentences output
|
||||||
assert_eq!(output[0], "the dog is a dog. \" \n \" i don\'t know what you\'re talking about.");
|
assert_eq!(
|
||||||
assert_eq!(output[1], "the dog is a dog. \" \n \" i don\'t know what you\'re talking about,");
|
output[0],
|
||||||
assert_eq!(output[2], "the dog is a dog. \" \n \" i don\'t know what you\'re talking about!");
|
"the dog is a dog. \" \n \" i don\'t know what you\'re talking about."
|
||||||
assert_eq!(output[3], "the cat was in the room with them. \n \" what\'s going on? \" i asked.");
|
);
|
||||||
assert_eq!(output[4], "the cat was in the room with them. \n \" what\'s going on? \" she asked.");
|
assert_eq!(
|
||||||
assert_eq!(output[5], "the cat was in the room with them. \n \" what\'s going on? why are you all");
|
output[1],
|
||||||
|
"the dog is a dog. \" \n \" i don\'t know what you\'re talking about,"
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
output[2],
|
||||||
|
"the dog is a dog. \" \n \" i don\'t know what you\'re talking about!"
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
output[3],
|
||||||
|
"the cat was in the room with them. \n \" what\'s going on? \" i asked."
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
output[4],
|
||||||
|
"the cat was in the room with them. \n \" what\'s going on? \" she asked."
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
output[5],
|
||||||
|
"the cat was in the room with them. \n \" what\'s going on? why are you all"
|
||||||
|
);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
407
tests/roberta.rs
407
tests/roberta.rs
@ -1,75 +1,99 @@
|
|||||||
use tch::{Device, nn, Tensor, no_grad};
|
|
||||||
use rust_tokenizers::{RobertaTokenizer, TruncationStrategy, Tokenizer, Vocab};
|
|
||||||
use rust_bert::Config;
|
|
||||||
use rust_bert::bert::BertConfig;
|
use rust_bert::bert::BertConfig;
|
||||||
use rust_bert::roberta::{RobertaForMaskedLM, RobertaForSequenceClassification, RobertaForMultipleChoice, RobertaForTokenClassification, RobertaForQuestionAnswering, RobertaConfigResources, RobertaVocabResources, RobertaMergesResources, RobertaModelResources};
|
use rust_bert::resources::{download_resource, RemoteResource, Resource};
|
||||||
|
use rust_bert::roberta::{
|
||||||
|
RobertaConfigResources, RobertaForMaskedLM, RobertaForMultipleChoice,
|
||||||
|
RobertaForQuestionAnswering, RobertaForSequenceClassification, RobertaForTokenClassification,
|
||||||
|
RobertaMergesResources, RobertaModelResources, RobertaVocabResources,
|
||||||
|
};
|
||||||
|
use rust_bert::Config;
|
||||||
|
use rust_tokenizers::{RobertaTokenizer, Tokenizer, TruncationStrategy, Vocab};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use rust_bert::resources::{RemoteResource, Resource, download_resource};
|
use tch::{nn, no_grad, Device, Tensor};
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn roberta_masked_lm() -> failure::Fallible<()> {
|
fn roberta_masked_lm() -> failure::Fallible<()> {
|
||||||
// Resources paths
|
// Resources paths
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaConfigResources::ROBERTA));
|
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaVocabResources::ROBERTA));
|
RobertaConfigResources::ROBERTA,
|
||||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaMergesResources::ROBERTA));
|
));
|
||||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaModelResources::ROBERTA));
|
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
|
RobertaVocabResources::ROBERTA,
|
||||||
|
));
|
||||||
|
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
|
RobertaMergesResources::ROBERTA,
|
||||||
|
));
|
||||||
|
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
|
RobertaModelResources::ROBERTA,
|
||||||
|
));
|
||||||
let config_path = download_resource(&config_resource)?;
|
let config_path = download_resource(&config_resource)?;
|
||||||
let vocab_path = download_resource(&vocab_resource)?;
|
let vocab_path = download_resource(&vocab_resource)?;
|
||||||
let merges_path = download_resource(&merges_resource)?;
|
let merges_path = download_resource(&merges_resource)?;
|
||||||
let weights_path = download_resource(&weights_resource)?;
|
let weights_path = download_resource(&weights_resource)?;
|
||||||
|
|
||||||
// Set-up masked LM model
|
// Set-up masked LM model
|
||||||
let device = Device::Cpu;
|
let device = Device::Cpu;
|
||||||
let mut vs = nn::VarStore::new(device);
|
let mut vs = nn::VarStore::new(device);
|
||||||
let tokenizer: RobertaTokenizer = RobertaTokenizer::from_file(vocab_path.to_str().unwrap(), merges_path.to_str().unwrap(), true);
|
let tokenizer: RobertaTokenizer = RobertaTokenizer::from_file(
|
||||||
|
vocab_path.to_str().unwrap(),
|
||||||
|
merges_path.to_str().unwrap(),
|
||||||
|
true,
|
||||||
|
);
|
||||||
let config = BertConfig::from_file(config_path);
|
let config = BertConfig::from_file(config_path);
|
||||||
let roberta_model = RobertaForMaskedLM::new(&vs.root(), &config);
|
let roberta_model = RobertaForMaskedLM::new(&vs.root(), &config);
|
||||||
vs.load(weights_path)?;
|
vs.load(weights_path)?;
|
||||||
|
|
||||||
// Define input
|
// Define input
|
||||||
let input = ["<pad> Looks like one thing is missing", "It\'s like comparing oranges to apples"];
|
let input = [
|
||||||
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
"<pad> Looks like one thing is missing",
|
||||||
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
|
"It\'s like comparing oranges to apples",
|
||||||
let mut tokenized_input = tokenized_input.
|
];
|
||||||
iter().
|
let tokenized_input =
|
||||||
map(|input| input.token_ids.clone()).
|
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
||||||
map(|mut input| {
|
let max_len = tokenized_input
|
||||||
|
.iter()
|
||||||
|
.map(|input| input.token_ids.len())
|
||||||
|
.max()
|
||||||
|
.unwrap();
|
||||||
|
let mut tokenized_input = tokenized_input
|
||||||
|
.iter()
|
||||||
|
.map(|input| input.token_ids.clone())
|
||||||
|
.map(|mut input| {
|
||||||
input.extend(vec![0; max_len - input.len()]);
|
input.extend(vec![0; max_len - input.len()]);
|
||||||
input
|
input
|
||||||
}).
|
})
|
||||||
collect::<Vec<_>>();
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
// Masking the token [thing] of sentence 1 and [oranges] of sentence 2
|
// Masking the token [thing] of sentence 1 and [oranges] of sentence 2
|
||||||
tokenized_input[0][4] = 103;
|
tokenized_input[0][4] = 103;
|
||||||
tokenized_input[1][5] = 103;
|
tokenized_input[1][5] = 103;
|
||||||
let tokenized_input = tokenized_input.
|
let tokenized_input = tokenized_input
|
||||||
iter().
|
.iter()
|
||||||
map(|input|
|
.map(|input| Tensor::of_slice(&(input)))
|
||||||
Tensor::of_slice(&(input))).
|
.collect::<Vec<_>>();
|
||||||
collect::<Vec<_>>();
|
|
||||||
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
|
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
|
||||||
|
|
||||||
// Forward pass
|
// Forward pass
|
||||||
let (output, _, _) = no_grad(|| {
|
let (output, _, _) = no_grad(|| {
|
||||||
roberta_model
|
roberta_model.forward_t(
|
||||||
.forward_t(Some(input_tensor),
|
Some(input_tensor),
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
&None,
|
&None,
|
||||||
&None,
|
&None,
|
||||||
false)
|
false,
|
||||||
|
)
|
||||||
});
|
});
|
||||||
|
|
||||||
// Print masked tokens
|
// Print masked tokens
|
||||||
let index_1 = output.get(0).get(4).argmax(0, false);
|
let index_1 = output.get(0).get(4).argmax(0, false);
|
||||||
let index_2 = output.get(1).get(5).argmax(0, false);
|
let index_2 = output.get(1).get(5).argmax(0, false);
|
||||||
let word_1 = tokenizer.vocab().id_to_token(&index_1.int64_value(&[]));
|
let word_1 = tokenizer.vocab().id_to_token(&index_1.int64_value(&[]));
|
||||||
let word_2 = tokenizer.vocab().id_to_token(&index_2.int64_value(&[]));
|
let word_2 = tokenizer.vocab().id_to_token(&index_2.int64_value(&[]));
|
||||||
|
|
||||||
assert_eq!("Ġsome", word_1); // Outputs "person" : "Looks like [some] thing is missing"
|
assert_eq!("Ġsome", word_1); // Outputs "person" : "Looks like [some] thing is missing"
|
||||||
assert_eq!("Ġapples", word_2);// Outputs "pear" : "It\'s like comparing [apples] to apples"
|
assert_eq!("Ġapples", word_2); // Outputs "pear" : "It\'s like comparing [apples] to apples"
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@ -77,17 +101,27 @@ fn roberta_masked_lm() -> failure::Fallible<()> {
|
|||||||
#[test]
|
#[test]
|
||||||
fn roberta_for_sequence_classification() -> failure::Fallible<()> {
|
fn roberta_for_sequence_classification() -> failure::Fallible<()> {
|
||||||
// Resources paths
|
// Resources paths
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaConfigResources::ROBERTA));
|
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaVocabResources::ROBERTA));
|
RobertaConfigResources::ROBERTA,
|
||||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaMergesResources::ROBERTA));
|
));
|
||||||
|
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
|
RobertaVocabResources::ROBERTA,
|
||||||
|
));
|
||||||
|
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
|
RobertaMergesResources::ROBERTA,
|
||||||
|
));
|
||||||
let config_path = download_resource(&config_resource)?;
|
let config_path = download_resource(&config_resource)?;
|
||||||
let vocab_path = download_resource(&vocab_resource)?;
|
let vocab_path = download_resource(&vocab_resource)?;
|
||||||
let merges_path = download_resource(&merges_resource)?;
|
let merges_path = download_resource(&merges_resource)?;
|
||||||
|
|
||||||
// Set-up model
|
// Set-up model
|
||||||
let device = Device::Cpu;
|
let device = Device::Cpu;
|
||||||
let vs = nn::VarStore::new(device);
|
let vs = nn::VarStore::new(device);
|
||||||
let tokenizer: RobertaTokenizer = RobertaTokenizer::from_file(vocab_path.to_str().unwrap(), merges_path.to_str().unwrap(), true);
|
let tokenizer: RobertaTokenizer = RobertaTokenizer::from_file(
|
||||||
|
vocab_path.to_str().unwrap(),
|
||||||
|
merges_path.to_str().unwrap(),
|
||||||
|
true,
|
||||||
|
);
|
||||||
let mut config = BertConfig::from_file(config_path);
|
let mut config = BertConfig::from_file(config_path);
|
||||||
let mut dummy_label_mapping = HashMap::new();
|
let mut dummy_label_mapping = HashMap::new();
|
||||||
dummy_label_mapping.insert(0, String::from("Positive"));
|
dummy_label_mapping.insert(0, String::from("Positive"));
|
||||||
@ -98,37 +132,42 @@ fn roberta_for_sequence_classification() -> failure::Fallible<()> {
|
|||||||
config.output_hidden_states = Some(true);
|
config.output_hidden_states = Some(true);
|
||||||
let roberta_model = RobertaForSequenceClassification::new(&vs.root(), &config);
|
let roberta_model = RobertaForSequenceClassification::new(&vs.root(), &config);
|
||||||
|
|
||||||
|
// Define input
|
||||||
// Define input
|
let input = [
|
||||||
let input = ["Looks like one thing is missing", "It\'s like comparing oranges to apples"];
|
"Looks like one thing is missing",
|
||||||
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
"It\'s like comparing oranges to apples",
|
||||||
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
|
];
|
||||||
let tokenized_input = tokenized_input.
|
let tokenized_input =
|
||||||
iter().
|
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
||||||
map(|input| input.token_ids.clone()).
|
let max_len = tokenized_input
|
||||||
map(|mut input| {
|
.iter()
|
||||||
|
.map(|input| input.token_ids.len())
|
||||||
|
.max()
|
||||||
|
.unwrap();
|
||||||
|
let tokenized_input = tokenized_input
|
||||||
|
.iter()
|
||||||
|
.map(|input| input.token_ids.clone())
|
||||||
|
.map(|mut input| {
|
||||||
input.extend(vec![0; max_len - input.len()]);
|
input.extend(vec![0; max_len - input.len()]);
|
||||||
input
|
input
|
||||||
}).
|
})
|
||||||
map(|input|
|
.map(|input| Tensor::of_slice(&(input)))
|
||||||
Tensor::of_slice(&(input))).
|
.collect::<Vec<_>>();
|
||||||
collect::<Vec<_>>();
|
|
||||||
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
|
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
|
||||||
|
|
||||||
// Forward pass
|
// Forward pass
|
||||||
let (output, all_hidden_states, all_attentions) = no_grad(|| {
|
let (output, all_hidden_states, all_attentions) =
|
||||||
roberta_model
|
no_grad(|| roberta_model.forward_t(Some(input_tensor), None, None, None, None, false));
|
||||||
.forward_t(Some(input_tensor),
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
false)
|
|
||||||
});
|
|
||||||
|
|
||||||
assert_eq!(output.size(), &[2, 3]);
|
assert_eq!(output.size(), &[2, 3]);
|
||||||
assert_eq!(config.num_hidden_layers as usize, all_hidden_states.unwrap().len());
|
assert_eq!(
|
||||||
assert_eq!(config.num_hidden_layers as usize, all_attentions.unwrap().len());
|
config.num_hidden_layers as usize,
|
||||||
|
all_hidden_states.unwrap().len()
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
config.num_hidden_layers as usize,
|
||||||
|
all_attentions.unwrap().len()
|
||||||
|
);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@ -136,52 +175,70 @@ fn roberta_for_sequence_classification() -> failure::Fallible<()> {
|
|||||||
#[test]
|
#[test]
|
||||||
fn roberta_for_multiple_choice() -> failure::Fallible<()> {
|
fn roberta_for_multiple_choice() -> failure::Fallible<()> {
|
||||||
// Resources paths
|
// Resources paths
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaConfigResources::ROBERTA));
|
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaVocabResources::ROBERTA));
|
RobertaConfigResources::ROBERTA,
|
||||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaMergesResources::ROBERTA));
|
));
|
||||||
|
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
|
RobertaVocabResources::ROBERTA,
|
||||||
|
));
|
||||||
|
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
|
RobertaMergesResources::ROBERTA,
|
||||||
|
));
|
||||||
let config_path = download_resource(&config_resource)?;
|
let config_path = download_resource(&config_resource)?;
|
||||||
let vocab_path = download_resource(&vocab_resource)?;
|
let vocab_path = download_resource(&vocab_resource)?;
|
||||||
let merges_path = download_resource(&merges_resource)?;
|
let merges_path = download_resource(&merges_resource)?;
|
||||||
|
|
||||||
// Set-up model
|
// Set-up model
|
||||||
let device = Device::Cpu;
|
let device = Device::Cpu;
|
||||||
let vs = nn::VarStore::new(device);
|
let vs = nn::VarStore::new(device);
|
||||||
let tokenizer: RobertaTokenizer = RobertaTokenizer::from_file(vocab_path.to_str().unwrap(), merges_path.to_str().unwrap(), true);
|
let tokenizer: RobertaTokenizer = RobertaTokenizer::from_file(
|
||||||
|
vocab_path.to_str().unwrap(),
|
||||||
|
merges_path.to_str().unwrap(),
|
||||||
|
true,
|
||||||
|
);
|
||||||
let mut config = BertConfig::from_file(config_path);
|
let mut config = BertConfig::from_file(config_path);
|
||||||
config.output_attentions = Some(true);
|
config.output_attentions = Some(true);
|
||||||
config.output_hidden_states = Some(true);
|
config.output_hidden_states = Some(true);
|
||||||
let roberta_model = RobertaForMultipleChoice::new(&vs.root(), &config);
|
let roberta_model = RobertaForMultipleChoice::new(&vs.root(), &config);
|
||||||
|
|
||||||
|
// Define input
|
||||||
// Define input
|
let input = [
|
||||||
let input = ["Looks like one thing is missing", "It\'s like comparing oranges to apples"];
|
"Looks like one thing is missing",
|
||||||
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
"It\'s like comparing oranges to apples",
|
||||||
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
|
];
|
||||||
let tokenized_input = tokenized_input.
|
let tokenized_input =
|
||||||
iter().
|
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
||||||
map(|input| input.token_ids.clone()).
|
let max_len = tokenized_input
|
||||||
map(|mut input| {
|
.iter()
|
||||||
|
.map(|input| input.token_ids.len())
|
||||||
|
.max()
|
||||||
|
.unwrap();
|
||||||
|
let tokenized_input = tokenized_input
|
||||||
|
.iter()
|
||||||
|
.map(|input| input.token_ids.clone())
|
||||||
|
.map(|mut input| {
|
||||||
input.extend(vec![0; max_len - input.len()]);
|
input.extend(vec![0; max_len - input.len()]);
|
||||||
input
|
input
|
||||||
}).
|
})
|
||||||
map(|input|
|
.map(|input| Tensor::of_slice(&(input)))
|
||||||
Tensor::of_slice(&(input))).
|
.collect::<Vec<_>>();
|
||||||
collect::<Vec<_>>();
|
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0)
|
||||||
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device).unsqueeze(0);
|
.to(device)
|
||||||
|
.unsqueeze(0);
|
||||||
|
|
||||||
// Forward pass
|
// Forward pass
|
||||||
let (output, all_hidden_states, all_attentions) = no_grad(|| {
|
let (output, all_hidden_states, all_attentions) =
|
||||||
roberta_model
|
no_grad(|| roberta_model.forward_t(input_tensor, None, None, None, false));
|
||||||
.forward_t(input_tensor,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
false)
|
|
||||||
});
|
|
||||||
|
|
||||||
assert_eq!(output.size(), &[1, 2]);
|
assert_eq!(output.size(), &[1, 2]);
|
||||||
assert_eq!(config.num_hidden_layers as usize, all_hidden_states.unwrap().len());
|
assert_eq!(
|
||||||
assert_eq!(config.num_hidden_layers as usize, all_attentions.unwrap().len());
|
config.num_hidden_layers as usize,
|
||||||
|
all_hidden_states.unwrap().len()
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
config.num_hidden_layers as usize,
|
||||||
|
all_attentions.unwrap().len()
|
||||||
|
);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@ -189,17 +246,27 @@ fn roberta_for_multiple_choice() -> failure::Fallible<()> {
|
|||||||
#[test]
|
#[test]
|
||||||
fn roberta_for_token_classification() -> failure::Fallible<()> {
|
fn roberta_for_token_classification() -> failure::Fallible<()> {
|
||||||
// Resources paths
|
// Resources paths
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaConfigResources::ROBERTA));
|
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaVocabResources::ROBERTA));
|
RobertaConfigResources::ROBERTA,
|
||||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaMergesResources::ROBERTA));
|
));
|
||||||
|
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
|
RobertaVocabResources::ROBERTA,
|
||||||
|
));
|
||||||
|
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
|
RobertaMergesResources::ROBERTA,
|
||||||
|
));
|
||||||
let config_path = download_resource(&config_resource)?;
|
let config_path = download_resource(&config_resource)?;
|
||||||
let vocab_path = download_resource(&vocab_resource)?;
|
let vocab_path = download_resource(&vocab_resource)?;
|
||||||
let merges_path = download_resource(&merges_resource)?;
|
let merges_path = download_resource(&merges_resource)?;
|
||||||
|
|
||||||
// Set-up model
|
// Set-up model
|
||||||
let device = Device::Cpu;
|
let device = Device::Cpu;
|
||||||
let vs = nn::VarStore::new(device);
|
let vs = nn::VarStore::new(device);
|
||||||
let tokenizer: RobertaTokenizer = RobertaTokenizer::from_file(vocab_path.to_str().unwrap(), merges_path.to_str().unwrap(), true);
|
let tokenizer: RobertaTokenizer = RobertaTokenizer::from_file(
|
||||||
|
vocab_path.to_str().unwrap(),
|
||||||
|
merges_path.to_str().unwrap(),
|
||||||
|
true,
|
||||||
|
);
|
||||||
let mut config = BertConfig::from_file(config_path);
|
let mut config = BertConfig::from_file(config_path);
|
||||||
let mut dummy_label_mapping = HashMap::new();
|
let mut dummy_label_mapping = HashMap::new();
|
||||||
dummy_label_mapping.insert(0, String::from("O"));
|
dummy_label_mapping.insert(0, String::from("O"));
|
||||||
@ -211,55 +278,70 @@ fn roberta_for_token_classification() -> failure::Fallible<()> {
|
|||||||
config.output_hidden_states = Some(true);
|
config.output_hidden_states = Some(true);
|
||||||
let roberta_model = RobertaForTokenClassification::new(&vs.root(), &config);
|
let roberta_model = RobertaForTokenClassification::new(&vs.root(), &config);
|
||||||
|
|
||||||
// Define input
|
// Define input
|
||||||
let input = ["Looks like one thing is missing", "It\'s like comparing oranges to apples"];
|
let input = [
|
||||||
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
"Looks like one thing is missing",
|
||||||
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
|
"It\'s like comparing oranges to apples",
|
||||||
let tokenized_input = tokenized_input.
|
];
|
||||||
iter().
|
let tokenized_input =
|
||||||
map(|input| input.token_ids.clone()).
|
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
||||||
map(|mut input| {
|
let max_len = tokenized_input
|
||||||
|
.iter()
|
||||||
|
.map(|input| input.token_ids.len())
|
||||||
|
.max()
|
||||||
|
.unwrap();
|
||||||
|
let tokenized_input = tokenized_input
|
||||||
|
.iter()
|
||||||
|
.map(|input| input.token_ids.clone())
|
||||||
|
.map(|mut input| {
|
||||||
input.extend(vec![0; max_len - input.len()]);
|
input.extend(vec![0; max_len - input.len()]);
|
||||||
input
|
input
|
||||||
}).
|
})
|
||||||
map(|input|
|
.map(|input| Tensor::of_slice(&(input)))
|
||||||
Tensor::of_slice(&(input))).
|
.collect::<Vec<_>>();
|
||||||
collect::<Vec<_>>();
|
|
||||||
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
|
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
|
||||||
|
|
||||||
// Forward pass
|
// Forward pass
|
||||||
let (output, all_hidden_states, all_attentions) = no_grad(|| {
|
let (output, all_hidden_states, all_attentions) =
|
||||||
roberta_model
|
no_grad(|| roberta_model.forward_t(Some(input_tensor), None, None, None, None, false));
|
||||||
.forward_t(Some(input_tensor),
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
false)
|
|
||||||
});
|
|
||||||
|
|
||||||
assert_eq!(output.size(), &[2, 9, 4]);
|
assert_eq!(output.size(), &[2, 9, 4]);
|
||||||
assert_eq!(config.num_hidden_layers as usize, all_hidden_states.unwrap().len());
|
assert_eq!(
|
||||||
assert_eq!(config.num_hidden_layers as usize, all_attentions.unwrap().len());
|
config.num_hidden_layers as usize,
|
||||||
|
all_hidden_states.unwrap().len()
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
config.num_hidden_layers as usize,
|
||||||
|
all_attentions.unwrap().len()
|
||||||
|
);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn roberta_for_question_answering() -> failure::Fallible<()> {
|
fn roberta_for_question_answering() -> failure::Fallible<()> {
|
||||||
// Resources paths
|
// Resources paths
|
||||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaConfigResources::ROBERTA));
|
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaVocabResources::ROBERTA));
|
RobertaConfigResources::ROBERTA,
|
||||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(RobertaMergesResources::ROBERTA));
|
));
|
||||||
|
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
|
RobertaVocabResources::ROBERTA,
|
||||||
|
));
|
||||||
|
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||||
|
RobertaMergesResources::ROBERTA,
|
||||||
|
));
|
||||||
let config_path = download_resource(&config_resource)?;
|
let config_path = download_resource(&config_resource)?;
|
||||||
let vocab_path = download_resource(&vocab_resource)?;
|
let vocab_path = download_resource(&vocab_resource)?;
|
||||||
let merges_path = download_resource(&merges_resource)?;
|
let merges_path = download_resource(&merges_resource)?;
|
||||||
|
|
||||||
// Set-up model
|
// Set-up model
|
||||||
let device = Device::Cpu;
|
let device = Device::Cpu;
|
||||||
let vs = nn::VarStore::new(device);
|
let vs = nn::VarStore::new(device);
|
||||||
let tokenizer: RobertaTokenizer = RobertaTokenizer::from_file(vocab_path.to_str().unwrap(), merges_path.to_str().unwrap(), true);
|
let tokenizer: RobertaTokenizer = RobertaTokenizer::from_file(
|
||||||
|
vocab_path.to_str().unwrap(),
|
||||||
|
merges_path.to_str().unwrap(),
|
||||||
|
true,
|
||||||
|
);
|
||||||
let mut config = BertConfig::from_file(config_path);
|
let mut config = BertConfig::from_file(config_path);
|
||||||
let mut dummy_label_mapping = HashMap::new();
|
let mut dummy_label_mapping = HashMap::new();
|
||||||
dummy_label_mapping.insert(0, String::from("Positive"));
|
dummy_label_mapping.insert(0, String::from("Positive"));
|
||||||
@ -270,38 +352,43 @@ fn roberta_for_question_answering() -> failure::Fallible<()> {
|
|||||||
config.output_hidden_states = Some(true);
|
config.output_hidden_states = Some(true);
|
||||||
let roberta_model = RobertaForQuestionAnswering::new(&vs.root(), &config);
|
let roberta_model = RobertaForQuestionAnswering::new(&vs.root(), &config);
|
||||||
|
|
||||||
|
// Define input
|
||||||
// Define input
|
let input = [
|
||||||
let input = ["Looks like one thing is missing", "It\'s like comparing oranges to apples"];
|
"Looks like one thing is missing",
|
||||||
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
"It\'s like comparing oranges to apples",
|
||||||
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
|
];
|
||||||
let tokenized_input = tokenized_input.
|
let tokenized_input =
|
||||||
iter().
|
tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
|
||||||
map(|input| input.token_ids.clone()).
|
let max_len = tokenized_input
|
||||||
map(|mut input| {
|
.iter()
|
||||||
|
.map(|input| input.token_ids.len())
|
||||||
|
.max()
|
||||||
|
.unwrap();
|
||||||
|
let tokenized_input = tokenized_input
|
||||||
|
.iter()
|
||||||
|
.map(|input| input.token_ids.clone())
|
||||||
|
.map(|mut input| {
|
||||||
input.extend(vec![0; max_len - input.len()]);
|
input.extend(vec![0; max_len - input.len()]);
|
||||||
input
|
input
|
||||||
}).
|
})
|
||||||
map(|input|
|
.map(|input| Tensor::of_slice(&(input)))
|
||||||
Tensor::of_slice(&(input))).
|
.collect::<Vec<_>>();
|
||||||
collect::<Vec<_>>();
|
|
||||||
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
|
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
|
||||||
|
|
||||||
// Forward pass
|
// Forward pass
|
||||||
let (start_scores, end_scores, all_hidden_states, all_attentions) = no_grad(|| {
|
let (start_scores, end_scores, all_hidden_states, all_attentions) =
|
||||||
roberta_model
|
no_grad(|| roberta_model.forward_t(Some(input_tensor), None, None, None, None, false));
|
||||||
.forward_t(Some(input_tensor),
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
false)
|
|
||||||
});
|
|
||||||
|
|
||||||
assert_eq!(start_scores.size(), &[2, 9]);
|
assert_eq!(start_scores.size(), &[2, 9]);
|
||||||
assert_eq!(end_scores.size(), &[2, 9]);
|
assert_eq!(end_scores.size(), &[2, 9]);
|
||||||
assert_eq!(config.num_hidden_layers as usize, all_hidden_states.unwrap().len());
|
assert_eq!(
|
||||||
assert_eq!(config.num_hidden_layers as usize, all_attentions.unwrap().len());
|
config.num_hidden_layers as usize,
|
||||||
|
all_hidden_states.unwrap().len()
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
config.num_hidden_layers as usize,
|
||||||
|
all_attentions.unwrap().len()
|
||||||
|
);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user