mirror of https://github.com/guillaume-be/rust-bert.git
synced 2024-09-20 09:08:24 +03:00

Addition of DistilBERT models

This commit is contained in:
parent 8322d1877b
commit 5ff8eeb97f
@@ -45,7 +45,7 @@ pub struct DistilBertConfig {
     pub output_attentions: bool,
     pub output_hidden_states: bool,
     pub output_past: Option<bool>,
-    pub qa_dropout: f32,
+    pub qa_dropout: f64,
     pub seq_classif_dropout: f64,
     pub sinusoidal_pos_embds: bool,
     pub tie_weights_: bool,
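A note on the f32 -> f64 change above: it brings qa_dropout in line with seq_classif_dropout, and both tch-rs (whose Tensor::dropout takes p: f64) and serde_json (which reads JSON numbers as f64) favour f64 here. A minimal sketch of the deserialization side, assuming serde and serde_json as dependencies; the struct below is an illustrative subset, not the crate's actual config type:

use serde::Deserialize;

// Illustrative subset of a DistilBERT-style config: serde_json parses
// JSON numbers into f64, so f64 fields deserialize without narrowing.
#[derive(Debug, Deserialize)]
struct DropoutSettings {
    qa_dropout: f64,
    seq_classif_dropout: f64,
}

fn main() {
    let raw = r#"{"qa_dropout": 0.1, "seq_classif_dropout": 0.2}"#;
    let settings: DropoutSettings = serde_json::from_str(raw).unwrap();
    assert_eq!(settings.qa_dropout, 0.1);
}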
@@ -121,10 +121,10 @@ impl DistilBertModelClassifier {

         let output = output
             .select(1, 0)
-            .apply_t(&self.pre_classifier, train)
+            .apply(&self.pre_classifier)
             .relu()
             .apply_t(&self.dropout, train)
-            .apply_t(&self.classifier, train);
+            .apply(&self.classifier);

         Ok((output, all_hidden_states, all_attentions))
     }
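For context on the apply_t -> apply changes in this hunk: in tch-rs, apply runs a plain Module, while apply_t runs a ModuleT and forwards the train flag. A linear layer behaves identically in training and evaluation, so only train-dependent layers such as dropout need the flag. A minimal standalone sketch (layer and variable names are illustrative):

use tch::{nn, Device, Kind, Tensor};

fn main() {
    let vs = nn::VarStore::new(Device::Cpu);
    let root = vs.root();
    // A Linear layer has no train-specific behaviour: plain apply suffices.
    let linear = nn::linear(&root / "demo", 4, 4, Default::default());
    let x = Tensor::randn(&[2, 4], (Kind::Float, Device::Cpu));
    let h = x.apply(&linear);
    // Dropout is only active during training, so it needs the train flag.
    let out = h.dropout(0.1, /* train = */ true);
    assert_eq!(out.size(), &[2, 4]);
}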
@@ -156,11 +156,83 @@ impl DistilBertModelMaskedLM {
         };

         let output = output
-            .apply_t(&self.vocab_transform, train)
+            .apply(&self.vocab_transform)
             .gelu()
-            .apply_t(&self.vocab_layer_norm, train)
-            .apply_t(&self.vocab_projector, train);
+            .apply(&self.vocab_layer_norm)
+            .apply(&self.vocab_projector);

         Ok((output, all_hidden_states, all_attentions))
     }
 }
+
+
+pub struct DistilBertForQuestionAnswering {
+    distil_bert_model: DistilBertModel,
+    qa_outputs: nn::Linear,
+    dropout: Dropout,
+}
+
+impl DistilBertForQuestionAnswering {
+    pub fn new(p: &nn::Path, config: &DistilBertConfig) -> DistilBertForQuestionAnswering {
+        let distil_bert_model = DistilBertModel::new(&p, config);
+        let qa_outputs = nn::linear(&(p / "qa_output"), config.dim, config.num_labels, Default::default());
+        assert_eq!(config.num_labels, 2, "num_labels should be set to 2 in the configuration provided");
+        let dropout = Dropout::new(config.qa_dropout);
+
+        DistilBertForQuestionAnswering { distil_bert_model, qa_outputs, dropout }
+    }
+
+    pub fn forward_t(&self,
+                     input: Option<Tensor>,
+                     mask: Option<Tensor>,
+                     input_embeds: Option<Tensor>,
+                     train: bool)
+                     -> Result<(Tensor, Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
+        let (output, all_hidden_states, all_attentions) = match self.distil_bert_model.forward_t(input, mask, input_embeds, train) {
+            Ok(value) => value,
+            Err(err) => return Err(err)
+        };
+
+        let output = output
+            .apply_t(&self.dropout, train)
+            .apply(&self.qa_outputs);
+
+        let logits = output.split(1, -1);
+        let (start_logits, end_logits) = (&logits[0], &logits[1]);
+        let start_logits = start_logits.squeeze1(-1);
+        let end_logits = end_logits.squeeze1(-1);
+
+        Ok((start_logits, end_logits, all_hidden_states, all_attentions))
+    }
+}
+
+pub struct DistilBertForTokenClassification {
+    distil_bert_model: DistilBertModel,
+    classifier: nn::Linear,
+    dropout: Dropout,
+}
+
+impl DistilBertForTokenClassification {
+    pub fn new(p: &nn::Path, config: &DistilBertConfig) -> DistilBertForTokenClassification {
+        let distil_bert_model = DistilBertModel::new(&p, config);
+        let classifier = nn::linear(&(p / "classifier"), config.dim, config.num_labels, Default::default());
+        let dropout = Dropout::new(config.seq_classif_dropout);
+
+        DistilBertForTokenClassification { distil_bert_model, classifier, dropout }
+    }
+
+    pub fn forward_t(&self, input: Option<Tensor>, mask: Option<Tensor>, input_embeds: Option<Tensor>, train: bool)
+                     -> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
+        let (output, all_hidden_states, all_attentions) = match self.distil_bert_model.forward_t(input, mask, input_embeds, train) {
+            Ok(value) => value,
+            Err(err) => return Err(err)
+        };
+
+        let output = output
+            .apply_t(&self.dropout, train)
+            .apply(&self.classifier);
+
+        Ok((output, all_hidden_states, all_attentions))
+    }
+}
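The new DistilBertForQuestionAnswering head projects each token to two columns and splits them into per-token start and end logits. A minimal sketch of how a caller might decode those logits into an answer span, assuming tensors shaped [batch, seq_len] as returned by forward_t; the greedy independent argmax below is illustrative (real systems typically score top-k start/end pairs with start <= end):

use tch::Tensor;

fn main() {
    // Stand-ins for forward_t output, shape [batch=1, seq_len=6].
    let start_logits = Tensor::of_slice(&[0.1f32, 2.5, 0.3, 0.2, 0.1, 0.0]).unsqueeze(0);
    let end_logits = Tensor::of_slice(&[0.0f32, 0.2, 0.4, 3.1, 0.2, 0.1]).unsqueeze(0);

    // Greedy decoding: independent argmax over the sequence dimension.
    let start = start_logits.argmax(-1, false).int64_value(&[0]);
    let end = end_logits.argmax(-1, false).int64_value(&[0]);
    assert!(start <= end, "a well-formed answer span starts before it ends");
    println!("answer span: tokens {}..={}", start, end);
}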
@@ -5,4 +5,5 @@ pub mod common;
 pub use distilbert::distilbert::{DistilBertConfig, DistilBertModel, DistilBertModelClassifier, DistilBertModelMaskedLM};
 pub use distilbert::sentiment::{Sentiment, SentimentPolarity, SentimentClassifier};

 pub use bert::bert::BertConfig;
+pub use bert::bert::{BertModel, BertForSequenceClassification, BertForMaskedLM, BertForQuestionAnswering, BertForTokenClassification, BertForMultipleChoice};
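A usage note on the hunk above: previously only BertConfig was re-exported from the crate root; the added line exposes the BERT model types the same way. A sketch of the shortened import this enables in a downstream crate:

// Downstream code can now import the BERT types from the crate root
// instead of the nested module path:
use rust_bert::{BertConfig, BertModel, BertForQuestionAnswering};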
@@ -1,6 +1,6 @@
 use std::path::PathBuf;
 use tch::{Device, Tensor, nn, no_grad};
-use rust_bert::distilbert::distilbert::{DistilBertModelMaskedLM, DistilBertConfig};
+use rust_bert::distilbert::distilbert::{DistilBertModelMaskedLM, DistilBertConfig, DistilBertForQuestionAnswering, DistilBertForTokenClassification};
 use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::{Tokenizer, TruncationStrategy};
 use rust_tokenizers::bert_tokenizer::BertTokenizer;
 use rust_tokenizers::preprocessing::vocab::base_vocab::Vocab;
@@ -109,3 +109,102 @@ fn distilbert_masked_lm() -> failure::Fallible<()> {

     Ok(())
 }
+
+#[test]
+fn distilbert_for_question_answering() -> failure::Fallible<()> {
+
+    // Resources paths
+    let mut home: PathBuf = dirs::home_dir().unwrap();
+    home.push("rustbert");
+    home.push("distilbert");
+    let config_path = &home.as_path().join("config.json");
+    let vocab_path = &home.as_path().join("vocab.txt");
+
+    // Set-up question answering model
+    let device = Device::cuda_if_available();
+    let vs = nn::VarStore::new(device);
+    let tokenizer: BertTokenizer = BertTokenizer::from_file(vocab_path.to_str().unwrap());
+    let mut config = DistilBertConfig::from_file(config_path);
+    config.output_attentions = true;
+    config.output_hidden_states = true;
+    let distil_bert_model = DistilBertForQuestionAnswering::new(&vs.root(), &config);
+
+    // Define input
+    let input = ["Looks like one thing is missing", "It\'s like comparing oranges to apples"];
+    let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
+    let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
+    let tokenized_input = tokenized_input
+        .iter()
+        .map(|input| input.token_ids.clone())
+        .map(|mut input| {
+            input.extend(vec![0; max_len - input.len()]);
+            input
+        })
+        .map(|input| Tensor::of_slice(&input))
+        .collect::<Vec<_>>();
+    let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
+
+    // Forward pass
+    let (start_scores, end_scores, all_hidden_states, all_attentions) = no_grad(|| {
+        distil_bert_model
+            .forward_t(Some(input_tensor), None, None, false)
+            .unwrap()
+    });
+
+    assert_eq!(start_scores.size(), &[2, 11]);
+    assert_eq!(end_scores.size(), &[2, 11]);
+    assert_eq!(config.n_layers as usize, all_hidden_states.unwrap().len());
+    assert_eq!(config.n_layers as usize, all_attentions.unwrap().len());
+
+    Ok(())
+}
+
+#[test]
+fn distilbert_for_token_classification() -> failure::Fallible<()> {
+
+    // Resources paths
+    let mut home: PathBuf = dirs::home_dir().unwrap();
+    home.push("rustbert");
+    home.push("distilbert");
+    let config_path = &home.as_path().join("config.json");
+    let vocab_path = &home.as_path().join("vocab.txt");
+
+    // Set-up token classification model
+    let device = Device::cuda_if_available();
+    let vs = nn::VarStore::new(device);
+    let tokenizer: BertTokenizer = BertTokenizer::from_file(vocab_path.to_str().unwrap());
+    let mut config = DistilBertConfig::from_file(config_path);
+    config.output_attentions = true;
+    config.output_hidden_states = true;
+    let distil_bert_model = DistilBertForTokenClassification::new(&vs.root(), &config);
+
+    // Define input
+    let input = ["Looks like one thing is missing", "It\'s like comparing oranges to apples"];
+    let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
+    let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
+    let tokenized_input = tokenized_input
+        .iter()
+        .map(|input| input.token_ids.clone())
+        .map(|mut input| {
+            input.extend(vec![0; max_len - input.len()]);
+            input
+        })
+        .map(|input| Tensor::of_slice(&input))
+        .collect::<Vec<_>>();
+    let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
+
+    // Forward pass
+    let (output, all_hidden_states, all_attentions) = no_grad(|| {
+        distil_bert_model
+            .forward_t(Some(input_tensor), None, None, false)
+            .unwrap()
+    });
+
+    assert_eq!(output.size(), &[2, 11, config.num_labels]);
+    assert_eq!(config.n_layers as usize, all_hidden_states.unwrap().len());
+    assert_eq!(config.n_layers as usize, all_attentions.unwrap().len());
+
+    Ok(())
+}
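As the token-classification test above shows, forward_t returns raw logits shaped [batch, seq_len, num_labels]. A minimal sketch of turning them into per-token label indices, with illustrative values (real usage would map the indices through a label vocabulary):

use tch::Tensor;

fn main() {
    // Stand-in for forward_t output: [batch=1, seq_len=3, num_labels=2].
    let logits = Tensor::of_slice(&[0.2f32, 1.4, 2.0, 0.1, 0.3, 0.9])
        .reshape(&[1, 3, 2]);

    // Argmax over the label dimension yields one label index per token:
    // token 0 -> 1 (1.4 > 0.2), token 1 -> 0 (2.0 > 0.1), token 2 -> 1.
    let labels = logits.argmax(-1, false);
    assert_eq!(labels.size(), &[1, 3]);
    assert_eq!(labels.int64_value(&[0, 0]), 1);
}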