mirror of
https://github.com/guillaume-be/rust-bert.git
synced 2024-09-20 09:08:24 +03:00
Addition of RoBERTa for sequence classification and multiple choices
This commit is contained in:
parent
78cacdaf2e
commit
3607aa505f
@ -17,6 +17,7 @@ use crate::common::linear::{linear_no_bias, LinearNoBias};
|
||||
use tch::nn::Init;
|
||||
use crate::common::activations::_gelu;
|
||||
use crate::roberta::embeddings::RobertaEmbeddings;
|
||||
use crate::common::dropout::Dropout;
|
||||
|
||||
pub struct RobertaLMHead {
|
||||
dense: nn::Linear,
|
||||
@ -69,4 +70,104 @@ impl RobertaForMaskedLM {
|
||||
let prediction_scores = self.lm_head.forward(&hidden_state);
|
||||
(prediction_scores, all_hidden_states, all_attentions)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct RobertaClassificationHead {
|
||||
dense: nn::Linear,
|
||||
dropout: Dropout,
|
||||
out_proj: nn::Linear,
|
||||
}
|
||||
|
||||
impl RobertaClassificationHead {
|
||||
pub fn new(p: &nn::Path, config: &BertConfig) -> RobertaClassificationHead {
|
||||
let dense = nn::linear(p / "dense", config.hidden_size, config.hidden_size, Default::default());
|
||||
let num_labels = config.num_labels.expect("num_labels not provided in configuration");
|
||||
let out_proj = nn::linear(p / "out_proj", config.hidden_size, num_labels, Default::default());
|
||||
let dropout = Dropout::new(config.hidden_dropout_prob);
|
||||
|
||||
RobertaClassificationHead { dense, dropout, out_proj }
|
||||
}
|
||||
|
||||
pub fn forward_t(&self, hidden_states: &Tensor, train: bool) -> Tensor {
|
||||
hidden_states
|
||||
.select(1, 0)
|
||||
.apply_t(&self.dropout, train)
|
||||
.apply(&self.dense)
|
||||
.tanh()
|
||||
.apply_t(&self.dropout, train)
|
||||
.apply(&self.out_proj)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct RobertaForSequenceClassification {
|
||||
roberta: BertModel<RobertaEmbeddings>,
|
||||
classifier: RobertaClassificationHead,
|
||||
}
|
||||
|
||||
impl RobertaForSequenceClassification {
|
||||
pub fn new(p: &nn::Path, config: &BertConfig) -> RobertaForSequenceClassification {
|
||||
let roberta = BertModel::<RobertaEmbeddings>::new(&(p / "roberta"), config);
|
||||
let classifier = RobertaClassificationHead::new(&(p / "classifier"), config);
|
||||
|
||||
RobertaForSequenceClassification { roberta, classifier }
|
||||
}
|
||||
|
||||
pub fn forward_t(&self,
|
||||
input_ids: Option<Tensor>,
|
||||
mask: Option<Tensor>,
|
||||
token_type_ids: Option<Tensor>,
|
||||
position_ids: Option<Tensor>,
|
||||
input_embeds: Option<Tensor>,
|
||||
train: bool) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
|
||||
let (hidden_state, _, all_hidden_states, all_attentions) = self.roberta.forward_t(input_ids, mask, token_type_ids, position_ids,
|
||||
input_embeds, &None, &None, train).unwrap();
|
||||
|
||||
let output = self.classifier.forward_t(&hidden_state, train);
|
||||
(output, all_hidden_states, all_attentions)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct RobertaForMultipleChoice {
|
||||
roberta: BertModel<RobertaEmbeddings>,
|
||||
dropout: Dropout,
|
||||
classifier: nn::Linear,
|
||||
}
|
||||
|
||||
impl RobertaForMultipleChoice {
|
||||
pub fn new(p: &nn::Path, config: &BertConfig) -> RobertaForMultipleChoice {
|
||||
let roberta = BertModel::<RobertaEmbeddings>::new(&(p / "roberta"), config);
|
||||
let dropout = Dropout::new(config.hidden_dropout_prob);
|
||||
let classifier = nn::linear(p / "classifier", config.hidden_size, 1, Default::default());
|
||||
|
||||
RobertaForMultipleChoice { roberta, dropout, classifier }
|
||||
}
|
||||
|
||||
pub fn forward_t(&self,
|
||||
input_ids: Tensor,
|
||||
mask: Option<Tensor>,
|
||||
token_type_ids: Option<Tensor>,
|
||||
position_ids: Option<Tensor>,
|
||||
train: bool) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
|
||||
let num_choices = input_ids.size()[1];
|
||||
|
||||
let flat_input_ids = Some(input_ids.view((-1i64, *input_ids.size().last().unwrap())));
|
||||
let flat_position_ids = match position_ids {
|
||||
Some(value) => Some(value.view((-1i64, *value.size().last().unwrap()))),
|
||||
None => None
|
||||
};
|
||||
let flat_token_type_ids = match token_type_ids {
|
||||
Some(value) => Some(value.view((-1i64, *value.size().last().unwrap()))),
|
||||
None => None
|
||||
};
|
||||
let flat_mask = match mask {
|
||||
Some(value) => Some(value.view((-1i64, *value.size().last().unwrap()))),
|
||||
None => None
|
||||
};
|
||||
|
||||
let (_, pooled_output, all_hidden_states, all_attentions) = self.roberta.forward_t(flat_input_ids, flat_mask, flat_token_type_ids, flat_position_ids,
|
||||
None, &None, &None, train).unwrap();
|
||||
|
||||
let output = pooled_output.apply_t(&self.dropout, train).apply(&self.classifier).view((-1, num_choices));
|
||||
(output, all_hidden_states, all_attentions)
|
||||
}
|
||||
}
|
116
tests/roberta.rs
116
tests/roberta.rs
@ -2,7 +2,7 @@ use std::path::PathBuf;
|
||||
use tch::{Device, nn, Tensor, no_grad};
|
||||
use rust_tokenizers::{RobertaTokenizer, TruncationStrategy, Tokenizer, Vocab};
|
||||
use rust_bert::BertConfig;
|
||||
use rust_bert::roberta::roberta::RobertaForMaskedLM;
|
||||
use rust_bert::roberta::roberta::{RobertaForMaskedLM, RobertaForSequenceClassification, RobertaForMultipleChoice};
|
||||
use rust_bert::common::config::Config;
|
||||
|
||||
#[test]
|
||||
@ -21,7 +21,7 @@ fn bert_masked_lm() -> failure::Fallible<()> {
|
||||
let mut vs = nn::VarStore::new(device);
|
||||
let tokenizer: RobertaTokenizer = RobertaTokenizer::from_file(vocab_path.to_str().unwrap(), merges_path.to_str().unwrap());
|
||||
let config = BertConfig::from_file(config_path);
|
||||
let bert_model = RobertaForMaskedLM::new(&vs.root(), &config);
|
||||
let roberta_model = RobertaForMaskedLM::new(&vs.root(), &config);
|
||||
vs.load(weights_path)?;
|
||||
|
||||
// Define input
|
||||
@ -49,7 +49,7 @@ fn bert_masked_lm() -> failure::Fallible<()> {
|
||||
|
||||
// Forward pass
|
||||
let (output, _, _) = no_grad(|| {
|
||||
bert_model
|
||||
roberta_model
|
||||
.forward_t(Some(input_tensor),
|
||||
None,
|
||||
None,
|
||||
@ -69,5 +69,115 @@ fn bert_masked_lm() -> failure::Fallible<()> {
|
||||
assert_eq!("Ġsome", word_1); // Outputs "person" : "Looks like [some] thing is missing"
|
||||
assert_eq!("Ġapples", word_2);// Outputs "pear" : "It\'s like comparing [apples] to apples"
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Shape check for `RobertaForSequenceClassification`.
///
/// The model is built from the configuration file only (no weights are loaded),
/// so the test verifies output dimensions and the number of returned hidden
/// states / attentions, not prediction values.
#[test]
fn roberta_for_sequence_classification() -> failure::Fallible<()> {
    // Resources paths
    let model_dir: PathBuf = dirs::home_dir().unwrap().join("rustbert").join("roberta");
    let config_path = model_dir.join("config.json");
    let vocab_path = model_dir.join("vocab.txt");
    let merges_path = model_dir.join("merges.txt");

    // Set-up model
    let device = Device::Cpu;
    let vs = nn::VarStore::new(device);
    let tokenizer: RobertaTokenizer = RobertaTokenizer::from_file(
        vocab_path.to_str().unwrap(),
        merges_path.to_str().unwrap(),
    );
    let mut config = BertConfig::from_file(&config_path);
    config.num_labels = Some(42);
    config.output_attentions = Some(true);
    config.output_hidden_states = Some(true);
    let roberta_model = RobertaForSequenceClassification::new(&vs.root(), &config);

    // Define input
    let input = ["Looks like one thing is missing", "It\'s like comparing oranges to apples"];
    let encodings = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
    let max_len = encodings
        .iter()
        .map(|encoding| encoding.token_ids.len())
        .max()
        .unwrap();

    // Right-pad every sequence with 0 up to the longest one, then build one tensor per row.
    // NOTE(review): 0 is used as the padding id; confirm it is the intended pad token for this vocabulary.
    let padded_rows = encodings
        .iter()
        .map(|encoding| {
            let mut ids = encoding.token_ids.clone();
            ids.resize(max_len, 0);
            Tensor::of_slice(&ids)
        })
        .collect::<Vec<_>>();
    let input_tensor = Tensor::stack(padded_rows.as_slice(), 0).to(device);

    // Forward pass
    let (output, all_hidden_states, all_attentions) = no_grad(|| {
        roberta_model.forward_t(Some(input_tensor), None, None, None, None, false)
    });

    // Two input sentences, 42 labels.
    assert_eq!(output.size(), &[2, 42]);
    assert_eq!(config.num_hidden_layers as usize, all_hidden_states.unwrap().len());
    assert_eq!(config.num_hidden_layers as usize, all_attentions.unwrap().len());

    Ok(())
}
|
||||
|
||||
/// Shape check for `RobertaForMultipleChoice`.
///
/// The model is built from the configuration file only (no weights are loaded),
/// so the test verifies output dimensions and the number of returned hidden
/// states / attentions, not prediction values.
#[test]
fn roberta_for_multiple_choice() -> failure::Fallible<()> {
    // Resources paths
    let model_dir: PathBuf = dirs::home_dir().unwrap().join("rustbert").join("roberta");
    let config_path = model_dir.join("config.json");
    let vocab_path = model_dir.join("vocab.txt");
    let merges_path = model_dir.join("merges.txt");

    // Set-up model
    let device = Device::Cpu;
    let vs = nn::VarStore::new(device);
    let tokenizer: RobertaTokenizer = RobertaTokenizer::from_file(
        vocab_path.to_str().unwrap(),
        merges_path.to_str().unwrap(),
    );
    let mut config = BertConfig::from_file(&config_path);
    config.output_attentions = Some(true);
    config.output_hidden_states = Some(true);
    let roberta_model = RobertaForMultipleChoice::new(&vs.root(), &config);

    // Define input
    let input = ["Looks like one thing is missing", "It\'s like comparing oranges to apples"];
    let encodings = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
    let max_len = encodings
        .iter()
        .map(|encoding| encoding.token_ids.len())
        .max()
        .unwrap();

    // Right-pad every sequence with 0 up to the longest one, then build one tensor per row.
    // NOTE(review): 0 is used as the padding id; confirm it is the intended pad token for this vocabulary.
    let padded_rows = encodings
        .iter()
        .map(|encoding| {
            let mut ids = encoding.token_ids.clone();
            ids.resize(max_len, 0);
            Tensor::of_slice(&ids)
        })
        .collect::<Vec<_>>();

    // The extra leading dimension turns the two sentences into two choices of a single example.
    let input_tensor = Tensor::stack(padded_rows.as_slice(), 0)
        .to(device)
        .unsqueeze(0);

    // Forward pass
    let (output, all_hidden_states, all_attentions) = no_grad(|| {
        roberta_model.forward_t(input_tensor, None, None, None, false)
    });

    // One example, two choices.
    assert_eq!(output.size(), &[1, 2]);
    assert_eq!(config.num_hidden_layers as usize, all_hidden_states.unwrap().len());
    assert_eq!(config.num_hidden_layers as usize, all_attentions.unwrap().len());

    Ok(())
}
|
Loading…
Reference in New Issue
Block a user