Addition of integration tests for BERT and RoBERTa QA

This commit is contained in:
Guillaume B 2020-07-12 09:49:22 +02:00
parent b5335824bc
commit c9262a310a
3 changed files with 118 additions and 0 deletions

View File

@@ -312,6 +312,42 @@ fn download_t5_small() -> failure::Fallible<()> {
Ok(())
}
/// Fetches the RoBERTa extractive-QA checkpoint (config, vocab, merges and
/// weights) into the local cache so later pipeline runs need no network access.
fn download_roberta_qa() -> failure::Fallible<()> {
    // Shared under Apache 2.0 license by [deepset](https://deepset.ai) at https://huggingface.co/deepset/roberta-base-squad2.
    let resources = [
        Resource::Remote(RemoteResource::from_pretrained(
            RobertaConfigResources::ROBERTA_QA,
        )),
        Resource::Remote(RemoteResource::from_pretrained(
            RobertaVocabResources::ROBERTA_QA,
        )),
        Resource::Remote(RemoteResource::from_pretrained(
            RobertaMergesResources::ROBERTA_QA,
        )),
        Resource::Remote(RemoteResource::from_pretrained(
            RobertaModelResources::ROBERTA_QA,
        )),
    ];
    // Download each resource in turn; the returned local paths are not needed here.
    for resource in &resources {
        let _ = download_resource(resource)?;
    }
    Ok(())
}
/// Fetches the BERT extractive-QA checkpoint (config, vocab and weights)
/// into the local cache so later pipeline runs need no network access.
fn download_bert_qa() -> failure::Fallible<()> {
    // Shared under Apache 2.0 license by [deepset](https://deepset.ai) at https://huggingface.co/deepset/bert-base-cased-squad2.
    // (The attribution previously pointed at the RoBERTa checkpoint by copy-paste;
    // this function downloads the BERT QA resources.)
    let config_resource = Resource::Remote(RemoteResource::from_pretrained(
        BertConfigResources::BERT_QA,
    ));
    let vocab_resource =
        Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT_QA));
    let weights_resource =
        Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT_QA));
    // Download (or reuse the cached copy of) each resource; the local paths are discarded.
    let _ = download_resource(&config_resource)?;
    let _ = download_resource(&vocab_resource)?;
    let _ = download_resource(&weights_resource)?;
    Ok(())
}
fn main() -> failure::Fallible<()> {
let _ = download_distil_gpt2();
let _ = download_distilbert_sst2();
@@ -328,6 +364,8 @@ fn main() -> failure::Fallible<()> {
let _ = download_electra_discriminator();
let _ = download_albert_base_v2();
let _ = download_t5_small();
let _ = download_roberta_qa();
let _ = download_bert_qa();
Ok(())
}

View File

@@ -6,7 +6,11 @@ use rust_bert::bert::{
BertForQuestionAnswering, BertForSequenceClassification, BertForTokenClassification,
BertModelResources, BertVocabResources,
};
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::ner::NERModel;
use rust_bert::pipelines::question_answering::{
QaInput, QuestionAnsweringConfig, QuestionAnsweringModel,
};
use rust_bert::resources::{download_resource, RemoteResource, Resource};
use rust_bert::Config;
use rust_tokenizers::{BertTokenizer, Tokenizer, TruncationStrategy, Vocab};
@@ -374,3 +378,36 @@ fn bert_pre_trained_ner() -> failure::Fallible<()> {
Ok(())
}
#[test]
fn bert_question_answering() -> failure::Fallible<()> {
    // Set-up question answering model: a BERT checkpoint fine-tuned for
    // extractive QA, loaded through the generic pipeline config.
    let config = QuestionAnsweringConfig::new(
        ModelType::Bert,
        Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT_QA)),
        Resource::Remote(RemoteResource::from_pretrained(
            BertConfigResources::BERT_QA,
        )),
        Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT_QA)),
        None, //merges resource only relevant with ModelType::Roberta
        true, //lowercase
    );
    let qa_model = QuestionAnsweringModel::new(config)?;

    // Define input: one question/context pair.
    let question = String::from("Where does Amy live ?");
    let context = String::from("Amy lives in Amsterdam");
    let qa_input = QaInput { question, context };

    // NOTE(review): the trailing arguments appear to be top-k and batch size —
    // confirm against QuestionAnsweringModel::predict's signature.
    let answers = qa_model.predict(&vec![qa_input], 1, 32);

    // One input yields one answer list holding the single best span.
    // (Was `1 as usize`: the cast is redundant, the literal infers usize from len().)
    assert_eq!(answers.len(), 1usize);
    assert_eq!(answers[0].len(), 1usize);
    // start/end locate the answer span within the context string.
    assert_eq!(answers[0][0].start, 13);
    assert_eq!(answers[0][0].end, 21);
    assert!((answers[0][0].score - 0.8111).abs() < 1e-4);
    assert_eq!(answers[0][0].answer, "Amsterdam");
    Ok(())
}

View File

@@ -1,4 +1,8 @@ use rust_bert::bert::BertConfig;
use rust_bert::bert::BertConfig;
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::question_answering::{
QaInput, QuestionAnsweringConfig, QuestionAnsweringModel,
};
use rust_bert::resources::{download_resource, RemoteResource, Resource};
use rust_bert::roberta::{
RobertaConfigResources, RobertaForMaskedLM, RobertaForMultipleChoice,
@@ -392,3 +396,42 @@ fn roberta_for_question_answering() -> failure::Fallible<()> {
Ok(())
}
#[test]
fn roberta_question_answering() -> failure::Fallible<()> {
    // Set-up question answering model: a RoBERTa checkpoint fine-tuned for
    // extractive QA; unlike BERT it additionally needs a BPE merges resource.
    let config = QuestionAnsweringConfig::new(
        ModelType::Roberta,
        Resource::Remote(RemoteResource::from_pretrained(
            RobertaModelResources::ROBERTA_QA,
        )),
        Resource::Remote(RemoteResource::from_pretrained(
            RobertaConfigResources::ROBERTA_QA,
        )),
        Resource::Remote(RemoteResource::from_pretrained(
            RobertaVocabResources::ROBERTA_QA,
        )),
        Some(Resource::Remote(RemoteResource::from_pretrained(
            RobertaMergesResources::ROBERTA_QA,
        ))), //merges resource only relevant with ModelType::Roberta
        true, //lowercase
    );
    let qa_model = QuestionAnsweringModel::new(config)?;

    // Define input: one question/context pair.
    let question = String::from("Where does Amy live ?");
    let context = String::from("Amy lives in Amsterdam");
    let qa_input = QaInput { question, context };

    // NOTE(review): the trailing arguments appear to be top-k and batch size —
    // confirm against QuestionAnsweringModel::predict's signature.
    let answers = qa_model.predict(&vec![qa_input], 1, 32);

    // One input yields one answer list holding the single best span.
    // (Was `1 as usize`: the cast is redundant, the literal infers usize from len().)
    assert_eq!(answers.len(), 1usize);
    assert_eq!(answers[0].len(), 1usize);
    // start/end locate the answer span within the context string.
    assert_eq!(answers[0][0].start, 13);
    assert_eq!(answers[0][0].end, 21);
    assert!((answers[0][0].score - 0.7354).abs() < 1e-4);
    assert_eq!(answers[0][0].answer, "Amsterdam");
    Ok(())
}