Longformer integration tests

Guillaume B 2021-02-16 10:16:09 +01:00
parent 9fab2d9e10
commit 545d52ec9d
9 changed files with 493 additions and 36 deletions


@@ -28,14 +28,10 @@ fn main() -> anyhow::Result<()> {
 BertConfigResources::BERT_QA,
 )),
 Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT_QA)),
-None,  //merges resource only relevant with ModelType::Roberta
-false, //lowercase
+None, //merges resource only relevant with ModelType::Roberta
+false,
 false,
 None,
-384,
-128,
-64,
-15,
 );
 let qa_model = QuestionAnsweringModel::new(config)?;
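
With this change the example relies on the defaults that `new` now applies internally (max_seq_length 384, doc_stride 128, max_query_length 64, max_answer_length 15); callers that need different window sizes can use the `custom_new` constructor introduced further down in this commit.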


@@ -41,10 +41,6 @@ fn main() -> anyhow::Result<()> {
 false,
 None,
 false,
-384,
-128,
-64,
-15,
 );
 let qa_model = QuestionAnsweringModel::new(config)?;


@@ -677,8 +677,8 @@ impl LongformerSelfAttention {
 .compute_global_attention_output_from_hidden(
 &hidden_states,
 max_num_global_attention_indices.unwrap(),
-is_local_index_global_attention_nonzero.as_ref().unwrap(),
 is_index_global_attn_nonzero.as_ref().unwrap(),
+is_local_index_global_attention_nonzero.as_ref().unwrap(),
 is_local_index_no_global_attention_nonzero.as_ref().unwrap(),
 is_index_masked,
 train,
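
The two arguments reordered by this fix have the same type, so the compiler could not catch the mix-up. A hypothetical refactor, not part of this commit, that would turn such swaps into compile errors is to group the index tensors behind named fields:

use tch::Tensor;

// Illustration only: with a named struct, passing the pairs in the wrong
// order at the call site becomes a field-name mismatch instead of a silent
// logic bug.
struct GlobalAttentionIndices<'a> {
    is_index_global_attn_nonzero: &'a (Tensor, Tensor),
    is_local_index_global_attention_nonzero: &'a (Tensor, Tensor),
    is_local_index_no_global_attention_nonzero: &'a (Tensor, Tensor),
}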


@@ -959,7 +959,7 @@ impl LongformerForMultipleChoice {
 false,
 ));
 }
-Some(Tensor::cat(masks.as_slice(), 1))
+Some(Tensor::stack(masks.as_slice(), 1))
 } else {
 return Err(RustBertError::ValueError(
 "Inputs ids must be provided to LongformerQuestionAnsweringOutput if the global_attention_mask is not given".into(),
@@ -969,10 +969,14 @@ impl LongformerForMultipleChoice {
 None
 };
-let flat_input_ids = input_ids.map(|tensor| tensor.view((-1, tensor.size()[1])));
-let flat_attention_mask = attention_mask.map(|tensor| tensor.view((-1, tensor.size()[1])));
-let flat_token_type_ids = token_type_ids.map(|tensor| tensor.view((-1, tensor.size()[1])));
-let flat_position_ids = position_ids.map(|tensor| tensor.view((-1, tensor.size()[1])));
+let flat_input_ids =
+    input_ids.map(|tensor| tensor.view((-1, *tensor.size().last().unwrap())));
+let flat_attention_mask =
+    attention_mask.map(|tensor| tensor.view((-1, *tensor.size().last().unwrap())));
+let flat_token_type_ids =
+    token_type_ids.map(|tensor| tensor.view((-1, *tensor.size().last().unwrap())));
+let flat_position_ids =
+    position_ids.map(|tensor| tensor.view((-1, *tensor.size().last().unwrap())));
 let flat_input_embeds =
 input_embeds.map(|tensor| tensor.view((-1, tensor.size()[1], tensor.size()[2])));
@@ -981,8 +985,8 @@ impl LongformerForMultipleChoice {
 } else {
 calc_global_attention_mask.as_ref()
 };
-let flat_global_attention_mask = global_attention_mask
-    .map(|tensor| tensor.view((-1, tensor.size()[1], tensor.size()[2])));
+let flat_global_attention_mask =
+    global_attention_mask.map(|tensor| tensor.view((-1, *tensor.size().last().unwrap())));
 let base_model_output = self.longformer.forward_t(
 flat_input_ids.as_ref(),
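
The `tensor.size()[1]` being replaced indexed the num_choices dimension of the [batch_size, num_choices, sequence_length] inputs; flattening has to keep the last (token) dimension instead. A small sketch of the corrected reshape, with assumed shapes:

use tch::{Device, Kind, Tensor};

// Multiple-choice input ids: 2 examples, 4 choices, 16 tokens each.
let input_ids = Tensor::zeros(&[2, 4, 16], (Kind::Int64, Device::Cpu));
// One row per (example, choice) pair, keeping the token dimension.
let flat = input_ids.view((-1, *input_ids.size().last().unwrap()));
assert_eq!(flat.size(), &[8, 16]);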


@@ -144,7 +144,7 @@ pub struct QuestionAnsweringConfig {
 /// Maximum length for the query
 pub max_query_length: usize,
 /// Maximum length for the answer
-pub max_answer_len: usize,
+pub max_answer_length: usize,
 }
 impl QuestionAnsweringConfig {
@@ -167,10 +167,51 @@ impl QuestionAnsweringConfig {
 lower_case: bool,
 strip_accents: impl Into<Option<bool>>,
 add_prefix_space: impl Into<Option<bool>>,
+) -> QuestionAnsweringConfig {
+QuestionAnsweringConfig {
+model_type,
+model_resource,
+config_resource,
+vocab_resource,
+merges_resource,
+lower_case,
+strip_accents: strip_accents.into(),
+add_prefix_space: add_prefix_space.into(),
+device: Device::cuda_if_available(),
+max_seq_length: 384,
+doc_stride: 128,
+max_query_length: 64,
+max_answer_length: 15,
+}
+}
+/// Instantiate a new question answering configuration of the supplied type.
+///
+/// # Arguments
+///
+/// * `model_type` - `ModelType` indicating the model type to load (must match the actual data to be loaded!)
+/// * `model_resource` - The `Resource` pointing to the model to load (e.g. model.ot)
+/// * `config_resource` - The `Resource` pointing to the model configuration to load (e.g. config.json)
+/// * `vocab_resource` - The `Resource` pointing to the tokenizer's vocabulary to load (e.g. vocab.txt/vocab.json)
+/// * `merges_resource` - An optional `Resource` (`Option<Resource>`) pointing to the tokenizer's merges file to load (e.g. merges.txt), needed only for Roberta.
+/// * `lower_case` - A `bool` indicating whether the tokenizer should lower-case all input (for lower-cased models)
+/// * `max_seq_length` - Optional maximum sequence token length, used to limit the memory footprint. If the context is too long, it is processed with sliding windows. Defaults to 384.
+/// * `doc_stride` - Optional stride to apply if a sliding window is required to process the input context, i.e. the number of tokens overlapping between consecutive windows. This should be lower than max_seq_length minus max_query_length, otherwise the sliding window cannot progress. Defaults to 128.
+/// * `max_query_length` - Optional maximum question token length. Defaults to 64.
+/// * `max_answer_length` - Optional maximum token length for the extracted answer. Defaults to 15.
+pub fn custom_new(
+model_type: ModelType,
+model_resource: Resource,
+config_resource: Resource,
+vocab_resource: Resource,
+merges_resource: Option<Resource>,
+lower_case: bool,
+strip_accents: impl Into<Option<bool>>,
+add_prefix_space: impl Into<Option<bool>>,
 max_seq_length: impl Into<Option<usize>>,
 doc_stride: impl Into<Option<usize>>,
 max_query_length: impl Into<Option<usize>>,
-max_answer_len: impl Into<Option<usize>>,
+max_answer_length: impl Into<Option<usize>>,
 ) -> QuestionAnsweringConfig {
 QuestionAnsweringConfig {
 model_type,
@@ -185,7 +226,7 @@ impl QuestionAnsweringConfig {
 max_seq_length: max_seq_length.into().unwrap_or(384),
 doc_stride: doc_stride.into().unwrap_or(128),
 max_query_length: max_query_length.into().unwrap_or(64),
-max_answer_len: max_answer_len.into().unwrap_or(15),
+max_answer_length: max_answer_length.into().unwrap_or(15),
 }
 }
 }
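
A usage sketch for `custom_new` (assuming the imports already used by the question answering examples in this commit; the overridden lengths are arbitrary):

let config = QuestionAnsweringConfig::custom_new(
    ModelType::Bert,
    Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT_QA)),
    Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT_QA)),
    Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT_QA)),
    None,  // merges resource, only relevant with ModelType::Roberta
    false, // lower_case
    None,  // strip_accents
    None,  // add_prefix_space
    512,   // max_seq_length (default 384)
    128,   // doc_stride (default 128)
    64,    // max_query_length (default 64)
    30,    // max_answer_length (default 15)
);
let qa_model = QuestionAnsweringModel::new(config)?;

Any of the four trailing values can also be passed as `None` to fall back to the documented default.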
@@ -211,7 +252,7 @@ impl Default for QuestionAnsweringConfig {
 max_seq_length: 384,
 doc_stride: 128,
 max_query_length: 64,
-max_answer_len: 15,
+max_answer_length: 15,
 }
 }
 }
@@ -549,7 +590,7 @@ impl QuestionAnsweringModel {
 max_seq_len: question_answering_config.max_seq_length,
 doc_stride: question_answering_config.doc_stride,
 max_query_length: question_answering_config.max_query_length,
-max_answer_len: question_answering_config.max_answer_len,
+max_answer_len: question_answering_config.max_answer_length,
 qa_model,
 var_store,
 })


@@ -398,8 +398,8 @@ fn bert_question_answering() -> anyhow::Result<()> {
 )),
 Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT_QA)),
 None, //merges resource only relevant with ModelType::Roberta
-true,
-true,
+false,
+false,
 None,
 );
@@ -415,8 +415,8 @@ fn bert_question_answering() -> anyhow::Result<()> {
 assert_eq!(answers.len(), 1usize);
 assert_eq!(answers[0].len(), 1usize);
 assert_eq!(answers[0][0].start, 13);
-assert_eq!(answers[0][0].end, 21);
-assert!((answers[0][0].score - 0.8111).abs() < 1e-4);
+assert_eq!(answers[0][0].end, 22);
+assert!((answers[0][0].score - 0.9806).abs() < 1e-4);
 assert_eq!(answers[0][0].answer, "Amsterdam");
 Ok(())


@@ -270,8 +270,8 @@ fn distilbert_question_answering() -> anyhow::Result<()> {
 assert_eq!(answers.len(), 1usize);
 assert_eq!(answers[0].len(), 1usize);
 assert_eq!(answers[0][0].start, 13);
-assert_eq!(answers[0][0].end, 21);
-assert!((answers[0][0].score - 0.9977).abs() < 1e-4);
+assert_eq!(answers[0][0].end, 22);
+assert!((answers[0][0].score - 0.9978).abs() < 1e-4);
 assert_eq!(answers[0][0].answer, "Amsterdam");
 Ok(())

tests/longformer.rs (new file, 420 additions)

@@ -0,0 +1,420 @@
extern crate anyhow;
extern crate dirs;
use rust_bert::longformer::{
LongformerConfig, LongformerConfigResources, LongformerForMaskedLM,
LongformerForMultipleChoice, LongformerForSequenceClassification,
LongformerForTokenClassification, LongformerMergesResources, LongformerModelResources,
LongformerVocabResources,
};
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::question_answering::{
QaInput, QuestionAnsweringConfig, QuestionAnsweringModel,
};
use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::Config;
use rust_tokenizers::tokenizer::{MultiThreadedTokenizer, RobertaTokenizer, TruncationStrategy};
use rust_tokenizers::vocab::{RobertaVocab, Vocab};
use std::collections::HashMap;
use tch::{nn, no_grad, Device, Tensor};
#[test]
fn longformer_masked_lm() -> anyhow::Result<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
LongformerConfigResources::LONGFORMER_BASE_4096,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
LongformerVocabResources::LONGFORMER_BASE_4096,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
LongformerMergesResources::LONGFORMER_BASE_4096,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
LongformerModelResources::LONGFORMER_BASE_4096,
));
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
let merges_path = merges_resource.get_local_path()?;
let weights_path = weights_resource.get_local_path()?;
// Set-up masked LM model
let device = Device::cuda_if_available();
let mut vs = nn::VarStore::new(device);
let tokenizer = RobertaTokenizer::from_file(
vocab_path.to_str().unwrap(),
merges_path.to_str().unwrap(),
false,
false,
)?;
let mut config = LongformerConfig::from_file(config_path);
config.output_attentions = Some(true);
config.output_hidden_states = Some(true);
let model = LongformerForMaskedLM::new(vs.root(), &config);
vs.load(weights_path)?;
// Define input
let input = [
"Looks like one <mask> is missing",
"It was a very nice and <mask> day",
];
let tokenized_input = tokenizer.encode_list(&input, 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let tokenized_input = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![
tokenizer.vocab().token_to_id(RobertaVocab::pad_value());
max_len - input.len()
]);
input
})
.map(|input| Tensor::of_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let model_output =
no_grad(|| model.forward_t(Some(&input_tensor), None, None, None, None, None, false))?;
// Print masked tokens
let index_1 = model_output
.prediction_scores
.get(0)
.get(4)
.argmax(0, false);
let index_2 = model_output
.prediction_scores
.get(1)
.get(7)
.argmax(0, false);
let word_1 = tokenizer.vocab().id_to_token(&index_1.int64_value(&[]));
let word_2 = tokenizer.vocab().id_to_token(&index_2.int64_value(&[]));
let score_1 = model_output
.prediction_scores
.get(0)
.get(4)
.double_value(&[i64::from(&index_1)]);
let score_2 = model_output
.prediction_scores
.get(1)
.get(7)
.double_value(&[i64::from(&index_2)]);
assert_eq!("Ġeye", word_1); // Outputs "person" : "Looks like one [eye] is missing"
assert_eq!("Ġsunny", word_2); // Outputs "pear" : "It was a nice and [sunny] day"
assert!((score_1 - 11.7605).abs() < 1e-4);
assert!((score_2 - 17.0088).abs() < 1e-4);
assert_eq!(
model_output.prediction_scores.size(),
vec!(2, 10, config.vocab_size)
);
assert!(model_output.all_attentions.is_some());
assert!(model_output.all_hidden_states.is_some());
assert_eq!(
config.num_hidden_layers as usize + 1,
model_output.all_hidden_states.as_ref().unwrap().len()
);
assert_eq!(
config.num_hidden_layers as usize,
model_output.all_attentions.as_ref().unwrap().len()
);
assert_eq!(
model_output.all_attentions.as_ref().unwrap()[0].size(),
vec!(
2,
config.num_hidden_layers,
*config.attention_window.iter().max().unwrap(),
*config.attention_window.iter().max().unwrap() + 1
)
);
assert_eq!(
model_output.all_hidden_states.as_ref().unwrap()[0].size(),
vec!(
*config.attention_window.iter().max().unwrap(),
2,
config.hidden_size
)
);
Ok(())
}
#[test]
fn longformer_for_sequence_classification() -> anyhow::Result<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
LongformerConfigResources::LONGFORMER_BASE_4096,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
LongformerVocabResources::LONGFORMER_BASE_4096,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
LongformerMergesResources::LONGFORMER_BASE_4096,
));
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
let merges_path = merges_resource.get_local_path()?;
// Set-up model
let device = Device::cuda_if_available();
let vs = nn::VarStore::new(device);
let tokenizer = RobertaTokenizer::from_file(
vocab_path.to_str().unwrap(),
merges_path.to_str().unwrap(),
false,
false,
)?;
let mut config = LongformerConfig::from_file(config_path);
let mut dummy_label_mapping = HashMap::new();
dummy_label_mapping.insert(0, String::from("Positive"));
dummy_label_mapping.insert(1, String::from("Negative"));
dummy_label_mapping.insert(3, String::from("Neutral"));
config.id2label = Some(dummy_label_mapping);
let model = LongformerForSequenceClassification::new(&vs.root(), &config);
// Define input
let input = ["Very positive sentence", "Second sentence input"];
let tokenized_input = tokenizer.encode_list(&input, 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let tokenized_input = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
})
.map(|input| Tensor::of_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let model_output = no_grad(|| {
model.forward_t(
Some(input_tensor.as_ref()),
None,
None,
None,
None,
None,
false,
)
})?;
assert_eq!(model_output.logits.size(), &[2, 3]);
Ok(())
}
#[test]
fn longformer_for_multiple_choice() -> anyhow::Result<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
LongformerConfigResources::LONGFORMER_BASE_4096,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
LongformerVocabResources::LONGFORMER_BASE_4096,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
LongformerMergesResources::LONGFORMER_BASE_4096,
));
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
let merges_path = merges_resource.get_local_path()?;
// Set-up model
let device = Device::Cpu;
let vs = nn::VarStore::new(device);
let tokenizer = RobertaTokenizer::from_file(
vocab_path.to_str().unwrap(),
merges_path.to_str().unwrap(),
false,
false,
)?;
let config = LongformerConfig::from_file(config_path);
let model = LongformerForMultipleChoice::new(&vs.root(), &config);
// Define input
let prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced.";
let inputs = ["Very positive sentence", "Second sentence input"];
let tokenized_input = tokenizer.encode_pair_list(
inputs
.iter()
.map(|&inp| (prompt, inp))
.collect::<Vec<(&str, &str)>>(),
128,
&TruncationStrategy::LongestFirst,
0,
);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let tokenized_input = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
})
.map(|input| Tensor::of_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0)
.to(device)
.unsqueeze(0);
// Forward pass
let model_output = no_grad(|| {
model.forward_t(
Some(input_tensor.as_ref()),
None,
None,
None,
None,
None,
false,
)
})?;
assert_eq!(model_output.logits.size(), &[1, 2]);
Ok(())
}
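This test exercises the multiple-choice fixes above: after `unsqueeze(0)` the input tensor has shape [1, 2, sequence_length] (one example, two choices), the corrected `view` flattens it to [2, sequence_length] for the forward pass, and the logits come back as [1, 2].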
#[test]
fn longformer_for_token_classification() -> anyhow::Result<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
LongformerConfigResources::LONGFORMER_BASE_4096,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
LongformerVocabResources::LONGFORMER_BASE_4096,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
LongformerMergesResources::LONGFORMER_BASE_4096,
));
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
let merges_path = merges_resource.get_local_path()?;
// Set-up model
let device = Device::cuda_if_available();
let vs = nn::VarStore::new(device);
let tokenizer = RobertaTokenizer::from_file(
vocab_path.to_str().unwrap(),
merges_path.to_str().unwrap(),
false,
false,
)?;
let mut config = LongformerConfig::from_file(config_path);
let mut dummy_label_mapping = HashMap::new();
dummy_label_mapping.insert(0, String::from("O"));
dummy_label_mapping.insert(1, String::from("LOC"));
dummy_label_mapping.insert(2, String::from("PER"));
dummy_label_mapping.insert(3, String::from("ORG"));
config.id2label = Some(dummy_label_mapping);
let model = LongformerForTokenClassification::new(&vs.root(), &config);
// Define input
let inputs = ["Where's Paris?", "In Kentucky, United States"];
let tokenized_input = tokenizer.encode_list(&inputs, 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input
.iter()
.map(|input| input.token_ids.len())
.max()
.unwrap();
let tokenized_input = tokenized_input
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
})
.map(|input| Tensor::of_slice(&(input)))
.collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let model_output = no_grad(|| {
model.forward_t(
Some(input_tensor.as_ref()),
None,
None,
None,
None,
None,
false,
)
})?;
assert_eq!(model_output.logits.size(), &[2, 7, 4]);
Ok(())
}
#[test]
fn longformer_for_question_answering() -> anyhow::Result<()> {
// Set-up Question Answering model
let config = QuestionAnsweringConfig::new(
ModelType::Longformer,
Resource::Remote(RemoteResource::from_pretrained(
LongformerModelResources::LONGFORMER_BASE_SQUAD1,
)),
Resource::Remote(RemoteResource::from_pretrained(
LongformerConfigResources::LONGFORMER_BASE_SQUAD1,
)),
Resource::Remote(RemoteResource::from_pretrained(
LongformerVocabResources::LONGFORMER_BASE_SQUAD1,
)),
Some(Resource::Remote(RemoteResource::from_pretrained(
LongformerMergesResources::LONGFORMER_BASE_SQUAD1,
))),
false,
None,
false,
);
let qa_model = QuestionAnsweringModel::new(config)?;
// Define input
let question_1 = String::from("Where does Amy live ?");
let context_1 = String::from("Amy lives in Amsterdam");
let question_2 = String::from("Where does Eric live");
let context_2 = String::from("While Amy lives in Amsterdam, Eric is in The Hague.");
let qa_input_1 = QaInput {
question: question_1,
context: context_1,
};
let qa_input_2 = QaInput {
question: question_2,
context: context_2,
};
// Get answer
let answers = qa_model.predict(&[qa_input_1, qa_input_2], 1, 32);
assert_eq!(answers.len(), 2usize);
assert_eq!(answers[0].len(), 1usize);
assert_eq!(answers[0][0].start, 12);
assert_eq!(answers[0][0].end, 22);
assert!((answers[0][0].score - 0.8060).abs() < 1e-4);
assert_eq!(answers[0][0].answer, " Amsterdam");
assert_eq!(answers[1].len(), 1usize);
assert_eq!(answers[1][0].start, 40);
assert_eq!(answers[1][0].end, 50);
assert!((answers[1][0].score - 0.7503).abs() < 1e-4);
assert_eq!(answers[1][0].answer, " The Hague");
Ok(())
}
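All of the tests above fetch their pretrained weights through `RemoteResource` on first run; an individual file can be exercised with the standard cargo invocation, e.g. `cargo test --test longformer`.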


@@ -348,10 +348,10 @@ fn roberta_question_answering() -> anyhow::Result<()> {
 )),
 Some(Resource::Remote(RemoteResource::from_pretrained(
 RobertaMergesResources::ROBERTA_QA,
-))), //merges resource only relevant with ModelType::Roberta
-true, //lowercase
+))),
+false,
 None,
 true,
 false,
 );
 let qa_model = QuestionAnsweringModel::new(config)?;
@@ -365,10 +365,10 @@ fn roberta_question_answering() -> anyhow::Result<()> {
 assert_eq!(answers.len(), 1usize);
 assert_eq!(answers[0].len(), 1usize);
-assert_eq!(answers[0][0].start, 13);
-assert_eq!(answers[0][0].end, 21);
-assert!((answers[0][0].score - 0.7354).abs() < 1e-4);
-assert_eq!(answers[0][0].answer, "Amsterdam");
+assert_eq!(answers[0][0].start, 12);
+assert_eq!(answers[0][0].end, 22);
+assert!((answers[0][0].score - 0.9997).abs() < 1e-4);
+assert_eq!(answers[0][0].answer, " Amsterdam");
 Ok(())
 }