Addition of DistilBERT models

Guillaume B 2020-02-18 19:13:46 +01:00
parent 8322d1877b
commit 5ff8eeb97f
3 changed files with 180 additions and 8 deletions

View File

@@ -45,7 +45,7 @@ pub struct DistilBertConfig {
    pub output_attentions: bool,
    pub output_hidden_states: bool,
    pub output_past: Option<bool>,
    pub qa_dropout: f32,
    pub qa_dropout: f64,
    pub seq_classif_dropout: f64,
    pub sinusoidal_pos_embds: bool,
    pub tie_weights_: bool,
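(The qa_dropout probability switches from f32 to f64, presumably so it can be passed straight to Dropout::new in the new question answering head further down in this diff, matching the existing f64 seq_classif_dropout field.)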
@@ -121,10 +121,10 @@ impl DistilBertModelClassifier {
        let output = output
            .select(1, 0)
            .apply_t(&self.pre_classifier, train)
            .apply(&self.pre_classifier)
            .relu()
            .apply_t(&self.dropout, train)
            .apply_t(&self.classifier, train);
            .apply(&self.classifier);
        Ok((output, all_hidden_states, all_attentions))
    }
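The change from apply_t to apply on the linear layers reflects the split between tch's Module and ModuleT traits: nn::Linear is deterministic and only needs Module::forward, while the train flag only matters for layers such as Dropout (apply_t still compiles on a Linear because every Module is also a ModuleT, but the flag is ignored). A minimal sketch of that distinction, using an illustrative variable path and dimensions rather than anything from this crate:

use tch::{nn, Device, Kind, Tensor};

fn main() {
    let vs = nn::VarStore::new(Device::Cpu);
    // Toy linear layer under a hypothetical "demo" variable path.
    let linear = nn::linear(vs.root() / "demo", 4, 2, Default::default());
    let xs = Tensor::rand(&[3, 4], (Kind::Float, Device::Cpu));
    // Linear is deterministic, so Module::apply is enough.
    let _out = xs.apply(&linear);
    // apply_t also works via the blanket ModuleT impl, but the train flag has no effect here.
    let _out_t = xs.apply_t(&linear, true);
}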
@@ -156,11 +156,83 @@ impl DistilBertModelMaskedLM {
        };

        let output = output
            .apply_t(&self.vocab_transform, train)
            .apply(&self.vocab_transform)
            .gelu()
            .apply_t(&self.vocab_layer_norm, train)
            .apply_t(&self.vocab_projector, train);
            .apply(&self.vocab_layer_norm)
            .apply(&self.vocab_projector);

        Ok((output, all_hidden_states, all_attentions))
    }
}

pub struct DistilBertForQuestionAnswering {
    distil_bert_model: DistilBertModel,
    qa_outputs: nn::Linear,
    dropout: Dropout,
}

impl DistilBertForQuestionAnswering {
    pub fn new(p: &nn::Path, config: &DistilBertConfig) -> DistilBertForQuestionAnswering {
        let distil_bert_model = DistilBertModel::new(&p, config);
        let qa_outputs = nn::linear(&(p / "qa_output"), config.dim, config.num_labels, Default::default());
        assert_eq!(config.num_labels, 2, "num_labels should be set to 2 in the configuration provided");
        let dropout = Dropout::new(config.qa_dropout);

        DistilBertForQuestionAnswering { distil_bert_model, qa_outputs, dropout }
    }

    pub fn forward_t(&self,
                     input: Option<Tensor>,
                     mask: Option<Tensor>,
                     input_embeds: Option<Tensor>,
                     train: bool)
                     -> Result<(Tensor, Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
        let (output, all_hidden_states, all_attentions) = match self.distil_bert_model.forward_t(input, mask, input_embeds, train) {
            Ok(value) => value,
            Err(err) => return Err(err)
        };

        let output = output
            .apply_t(&self.dropout, train)
            .apply(&self.qa_outputs);

        let logits = output.split(1, -1);
        let (start_logits, end_logits) = (&logits[0], &logits[1]);
        let start_logits = start_logits.squeeze1(-1);
        let end_logits = end_logits.squeeze1(-1);

        Ok((start_logits, end_logits, all_hidden_states, all_attentions))
    }
}
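The question answering head emits one start logit and one end logit per token. A hedged sketch, not part of this commit, of how a caller might reduce the start_scores/end_scores tensors returned by forward_t (shape [batch, sequence_length], as asserted in the new test below) to a predicted answer span for the first sequence; the helper name is hypothetical:

use tch::Tensor;

// Hypothetical helper: index of the highest-scoring start and end token
// for the first sequence of the batch.
fn best_span(start_scores: &Tensor, end_scores: &Tensor) -> (i64, i64) {
    let start = start_scores.get(0).argmax(0, false).int64_value(&[]);
    let end = end_scores.get(0).argmax(0, false).int64_value(&[]);
    (start, end)
}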
pub struct DistilBertForTokenClassification {
    distil_bert_model: DistilBertModel,
    classifier: nn::Linear,
    dropout: Dropout,
}

impl DistilBertForTokenClassification {
    pub fn new(p: &nn::Path, config: &DistilBertConfig) -> DistilBertForTokenClassification {
        let distil_bert_model = DistilBertModel::new(&p, config);
        let classifier = nn::linear(&(p / "classifier"), config.dim, config.num_labels, Default::default());
        let dropout = Dropout::new(config.seq_classif_dropout);

        DistilBertForTokenClassification { distil_bert_model, classifier, dropout }
    }

    pub fn forward_t(&self, input: Option<Tensor>, mask: Option<Tensor>, input_embeds: Option<Tensor>, train: bool)
                     -> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
        let (output, all_hidden_states, all_attentions) = match self.distil_bert_model.forward_t(input, mask, input_embeds, train) {
            Ok(value) => value,
            Err(err) => return Err(err)
        };

        let output = output
            .apply_t(&self.dropout, train)
            .apply(&self.classifier);

        Ok((output, all_hidden_states, all_attentions))
    }
}
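The token classification head returns raw logits of shape [batch, sequence_length, num_labels] (see the assertion in the new test below). A small sketch, with a hypothetical helper name, of how per-token label ids could be read off that output:

use tch::Tensor;

// Hypothetical helper: highest-scoring label id for every token,
// giving a tensor of shape [batch, sequence_length].
fn predicted_label_ids(token_logits: &Tensor) -> Tensor {
    token_logits.argmax(-1, false)
}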

View File

@@ -5,4 +5,5 @@ pub mod common;
pub use distilbert::distilbert::{DistilBertConfig, DistilBertModel, DistilBertModelClassifier, DistilBertModelMaskedLM};
pub use distilbert::sentiment::{Sentiment, SentimentPolarity, SentimentClassifier};
pub use bert::bert::BertConfig;
pub use bert::bert::{BertModel, BertForSequenceClassification, BertForMaskedLM, BertForQuestionAnswering, BertForTokenClassification, BertForMultipleChoice};
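Note that the new DistilBERT heads are not re-exported from the crate root in this commit; as the updated test file below shows, they remain reachable through the full module path:

use rust_bert::distilbert::distilbert::{DistilBertForQuestionAnswering, DistilBertForTokenClassification};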

View File

@@ -1,6 +1,6 @@
use std::path::PathBuf;
use tch::{Device, Tensor, nn, no_grad};
use rust_bert::distilbert::distilbert::{DistilBertModelMaskedLM, DistilBertConfig};
use rust_bert::distilbert::distilbert::{DistilBertModelMaskedLM, DistilBertConfig, DistilBertForQuestionAnswering, DistilBertForTokenClassification};
use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::{Tokenizer, TruncationStrategy};
use rust_tokenizers::bert_tokenizer::BertTokenizer;
use rust_tokenizers::preprocessing::vocab::base_vocab::Vocab;
@@ -109,3 +109,102 @@ fn distilbert_masked_lm() -> failure::Fallible<()> {
    Ok(())
}
#[test]
fn distilbert_for_question_answering() -> failure::Fallible<()> {
    // Resources paths
    let mut home: PathBuf = dirs::home_dir().unwrap();
    home.push("rustbert");
    home.push("distilbert");
    let config_path = &home.as_path().join("config.json");
    let vocab_path = &home.as_path().join("vocab.txt");

    // Set-up question answering model
    let device = Device::cuda_if_available();
    let vs = nn::VarStore::new(device);
    let tokenizer: BertTokenizer = BertTokenizer::from_file(vocab_path.to_str().unwrap());
    let mut config = DistilBertConfig::from_file(config_path);
    config.output_attentions = true;
    config.output_hidden_states = true;
    let distil_bert_model = DistilBertForQuestionAnswering::new(&vs.root(), &config);

    // Define input
    let input = ["Looks like one thing is missing", "It\'s like comparing oranges to apples"];
    let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
    let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
    let tokenized_input = tokenized_input.
        iter().
        map(|input| input.token_ids.clone()).
        map(|mut input| {
            input.extend(vec![0; max_len - input.len()]);
            input
        }).
        map(|input|
            Tensor::of_slice(&(input))).
        collect::<Vec<_>>();
    let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);

    // Forward pass
    let (start_scores, end_scores, all_hidden_states, all_attentions) = no_grad(|| {
        distil_bert_model
            .forward_t(Some(input_tensor), None, None, false)
            .unwrap()
    });

    assert_eq!(start_scores.size(), &[2, 11]);
    assert_eq!(end_scores.size(), &[2, 11]);
    assert_eq!(config.n_layers as usize, all_hidden_states.unwrap().len());
    assert_eq!(config.n_layers as usize, all_attentions.unwrap().len());

    Ok(())
}
#[test]
fn distilbert_for_token_classification() -> failure::Fallible<()> {
    // Resources paths
    let mut home: PathBuf = dirs::home_dir().unwrap();
    home.push("rustbert");
    home.push("distilbert");
    let config_path = &home.as_path().join("config.json");
    let vocab_path = &home.as_path().join("vocab.txt");

    // Set-up token classification model
    let device = Device::cuda_if_available();
    let vs = nn::VarStore::new(device);
    let tokenizer: BertTokenizer = BertTokenizer::from_file(vocab_path.to_str().unwrap());
    let mut config = DistilBertConfig::from_file(config_path);
    config.output_attentions = true;
    config.output_hidden_states = true;
    let distil_bert_model = DistilBertForTokenClassification::new(&vs.root(), &config);

    // Define input
    let input = ["Looks like one thing is missing", "It\'s like comparing oranges to apples"];
    let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
    let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
    let tokenized_input = tokenized_input.
        iter().
        map(|input| input.token_ids.clone()).
        map(|mut input| {
            input.extend(vec![0; max_len - input.len()]);
            input
        }).
        map(|input|
            Tensor::of_slice(&(input))).
        collect::<Vec<_>>();
    let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);

    // Forward pass
    let (output, all_hidden_states, all_attentions) = no_grad(|| {
        distil_bert_model
            .forward_t(Some(input_tensor), None, None, false)
            .unwrap()
    });

    assert_eq!(output.size(), &[2, 11, config.num_labels]);
    assert_eq!(config.n_layers as usize, all_hidden_states.unwrap().len());
    assert_eq!(config.n_layers as usize, all_attentions.unwrap().len());

    Ok(())
}