initial commit for batch processing of QA inputs

This commit is contained in:
Guillaume B 2020-03-06 21:40:32 +01:00
parent 4935437891
commit 999f6e52aa
3 changed files with 63 additions and 44 deletions

View File

@ -14,7 +14,7 @@ extern crate failure;
extern crate dirs;
use std::path::PathBuf;
use rust_bert::pipelines::question_answering::QuestionAnsweringModel;
use rust_bert::pipelines::question_answering::{QuestionAnsweringModel, QaInput};
use tch::Device;
@ -34,11 +34,15 @@ fn main() -> failure::Fallible<()> {
weights_path, device)?;
// Define input
let question = "Where does Amy live ?";
let context = "Amy lives in Amsterdam";
let question_1 = String::from("Where does Amy live ?");
let context_1 = String::from("Amy lives in Amsterdam");
let question_2 = String::from("Where does Eric live");
let context_2 = String::from("While Amy lives in Amsterdam, Eric is in The Hague.");
let qa_input_1 = QaInput { question: question_1, context: context_1 };
let qa_input_2 = QaInput { question: question_2, context: context_2 };
// Get answer
let answers = qa_model.predict(question, context, 1);
let answers = qa_model.predict(&vec!(qa_input_1, qa_input_2), 1);
println!("{:?}", answers);
Ok(())
}

View File

@ -22,6 +22,12 @@ use tch::nn::VarStore;
use crate::common::config::Config;
use tch::kind::Kind::Float;
/// A single question-answering query: a question paired with the context
/// passage the answer should be extracted from.
#[derive(Debug, Clone)] // Debug for consistency with `QaExample`; Clone so inputs can be reused across predict calls
pub struct QaInput {
    /// Question text, e.g. "Where does Amy live ?".
    pub question: String,
    /// Context passage searched for the answer span.
    pub context: String,
}
#[derive(Debug)]
pub struct QaExample {
pub question: String,
@ -145,48 +151,55 @@ impl QuestionAnsweringModel {
(qa_example, features, input_ids, attention_masks)
}
pub fn predict(&self, question: &str, context: &str, top_k: i64) -> Vec<Answer> {
let (qa_example, qa_features, input_ids, attention_mask) = self.prepare_for_model(question, context);
pub fn predict(&self, qa_inputs: &Vec<QaInput>, top_k: i64) -> Vec<Vec<Answer>> {
let inputs: Vec<(QaExample, Vec<QaFeature>, Tensor, Tensor)> = qa_inputs
.iter()
.map(|qa_input| self.prepare_for_model(&qa_input.question, &qa_input.context))
.collect();
let (start_logits, end_logits, _, _) = no_grad(|| {
self.distilbert_qa.forward_t(Some(input_ids), Some(attention_mask), None, false).unwrap()
});
let start_logits = start_logits.to(Device::Cpu);
let end_logits = end_logits.to(Device::Cpu);
let mut all_answers = vec!();
let mut answers: Vec<Answer> = vec!();
for (feature_idx, feature) in (0..start_logits.size()[0]).zip(qa_features) {
let start = start_logits.get(feature_idx);
let end = end_logits.get(feature_idx);
let p_mask = (Tensor::of_slice(&feature.p_mask) - 1).abs();
for (qa_example, qa_features, input_ids, attention_mask) in inputs {
let (start_logits, end_logits, _, _) = no_grad(|| {
self.distilbert_qa.forward_t(Some(input_ids), Some(attention_mask), None, false).unwrap()
});
let start_logits = start_logits.to(Device::Cpu);
let end_logits = end_logits.to(Device::Cpu);
let start: Tensor = start.exp() / start.exp().sum(Float) * &p_mask;
let end: Tensor = end.exp() / end.exp().sum(Float) * &p_mask;
let mut answers: Vec<Answer> = vec!();
for (feature_idx, feature) in (0..start_logits.size()[0]).zip(qa_features) {
let start = start_logits.get(feature_idx);
let end = end_logits.get(feature_idx);
let p_mask = (Tensor::of_slice(&feature.p_mask) - 1).abs();
let (starts, ends, scores) = self.decode(&start, &end, top_k);
let start: Tensor = start.exp() / start.exp().sum(Float) * &p_mask;
let end: Tensor = end.exp() / end.exp().sum(Float) * &p_mask;
for idx in 0..starts.len() {
let start_pos = feature.token_to_orig_map[&starts[idx]] as usize;
let end_pos = feature.token_to_orig_map[&ends[idx]] as usize;
let answer = qa_example.doc_tokens[start_pos..end_pos + 1].join(" ");
let (starts, ends, scores) = self.decode(&start, &end, top_k);
let start = qa_example.char_to_word_offset
.iter()
.position(|&v| v as usize == start_pos)
.unwrap();
for idx in 0..starts.len() {
let start_pos = feature.token_to_orig_map[&starts[idx]] as usize;
let end_pos = feature.token_to_orig_map[&ends[idx]] as usize;
let answer = qa_example.doc_tokens[start_pos..end_pos + 1].join(" ");
let end = qa_example.char_to_word_offset
.iter()
.rposition(|&v| v as usize == end_pos)
.unwrap();
let start = qa_example.char_to_word_offset
.iter()
.position(|&v| v as usize == start_pos)
.unwrap();
answers.push(
Answer { score: scores[idx], start, end, answer })
let end = qa_example.char_to_word_offset
.iter()
.rposition(|&v| v as usize == end_pos)
.unwrap();
answers.push(
Answer { score: scores[idx], start, end, answer })
}
}
answers.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
all_answers.push(answers[..(top_k as usize)].to_vec());
}
answers.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
answers[..(top_k as usize)].to_vec()
all_answers
}
fn decode(&self, start: &Tensor, end: &Tensor, top_k: i64) -> (Vec<i64>, Vec<i64>, Vec<f64>) {

View File

@ -6,7 +6,7 @@ use rust_tokenizers::bert_tokenizer::BertTokenizer;
use rust_tokenizers::preprocessing::vocab::base_vocab::Vocab;
use rust_bert::{SentimentClassifier, SentimentPolarity};
use rust_bert::common::config::Config;
use rust_bert::pipelines::question_answering::QuestionAnsweringModel;
use rust_bert::pipelines::question_answering::{QuestionAnsweringModel, QaInput};
extern crate failure;
extern crate dirs;
@ -226,16 +226,18 @@ fn distilbert_question_answering() -> failure::Fallible<()> {
let qa_model = QuestionAnsweringModel::new(vocab_path, config_path, weights_path, device)?;
// Define input
let question = "Where does Amy live ?";
let context = "Amy lives in Amsterdam";
let question = String::from("Where does Amy live ?");
let context = String::from("Amy lives in Amsterdam");
let qa_input = QaInput { question, context };
let answers = qa_model.predict(question, context, 1);
let answers = qa_model.predict(&vec!(qa_input), 1);
assert_eq!(answers.len(), 1 as usize);
assert_eq!(answers[0].start, 13);
assert_eq!(answers[0].end, 21);
assert!((answers[0].score - 0.9977).abs() < 1e-4);
assert_eq!(answers[0].answer, "Amsterdam");
assert_eq!(answers[0].len(), 1 as usize);
assert_eq!(answers[0][0].start, 13);
assert_eq!(answers[0][0].end, 21);
assert!((answers[0][0].score - 0.9977).abs() < 1e-4);
assert_eq!(answers[0][0].answer, "Amsterdam");
Ok(())
}