Merge pull request #11 from guillaume-be/qa_batch_input

Question Answering batched input
guillaume-be 2020-03-07 18:12:02 +01:00 committed by GitHub
commit e9ac97b705
6 changed files with 202 additions and 70 deletions


@ -1,6 +1,6 @@
[package]
name = "rust-bert"
version = "0.4.4"
version = "0.4.5"
authors = ["Guillaume Becquin <guillaume.becquin@gmail.com>"]
edition = "2018"
default-run = "rust-bert"


@ -20,7 +20,7 @@ Next token prediction| | | |✅|✅|
## Ready-to-use pipelines
Leveraging Huggingface's pipelines, ready to use end-to-end NLP pipelines are available as part of this crate. The following capabilities are currently available:
Based on Huggingface's pipelines, ready-to-use end-to-end NLP pipelines are available as part of this crate. The following capabilities are currently available:
#### 1. Question Answering
Extractive question answering from a given question and context. DistilBERT model finetuned on SQuAD (Stanford Question Answering Dataset)
@ -30,10 +30,10 @@ Extractive question answering from a given question and context. DistilBERT mode
config_path,
weights_path, device)?;
let question = "Where does Amy live ?";
let context = "Amy lives in Amsterdam";
let question = String::from("Where does Amy live ?");
let context = String::from("Amy lives in Amsterdam");
let answers = qa_model.predict(question, context, 1);
let answers = qa_model.predict(QaInput { question, context }, 1, 32);
```
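Putting the added lines together, the updated snippet builds a `QaInput` and calls the batched `predict`. A minimal sketch of the new usage (assuming the model set-up shown above, and passing the inputs as a slice to match the `predict` signature introduced in this PR):
```rust
// Sketch of the batched call introduced by this PR; assumes `qa_model`
// was created as in the set-up above.
let question = String::from("Where does Amy live ?");
let context = String::from("Amy lives in Amsterdam");
// `predict` takes a slice of inputs, a top_k value and a batch size,
// and returns one Vec<Answer> per QaInput.
let answers = qa_model.predict(&vec!(QaInput { question, context }), 1, 32);
```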
Output:


@ -14,7 +14,7 @@ extern crate failure;
extern crate dirs;
use std::path::PathBuf;
use rust_bert::pipelines::question_answering::QuestionAnsweringModel;
use rust_bert::pipelines::question_answering::{QuestionAnsweringModel, QaInput};
use tch::Device;
@ -34,11 +34,15 @@ fn main() -> failure::Fallible<()> {
weights_path, device)?;
// Define input
let question = "Where does Amy live ?";
let context = "Amy lives in Amsterdam";
let question_1 = String::from("Where does Amy live ?");
let context_1 = String::from("Amy lives in Amsterdam");
let question_2 = String::from("Where does Eric live");
let context_2 = String::from("While Amy lives in Amsterdam, Eric is in The Hague.");
let qa_input_1 = QaInput { question: question_1, context: context_1 };
let qa_input_2 = QaInput { question: question_2, context: context_2 };
// Get answer
let answers = qa_model.predict(question, context, 1);
let answers = qa_model.predict(&vec!(qa_input_1, qa_input_2), 1, 32);
println!("{:?}", answers);
Ok(())
}
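Since the batched `predict` returns one `Vec<Answer>` per `QaInput`, the final `println!` in the example above could instead walk the nested results. A minimal fragment (field names taken from the `Answer` struct used elsewhere in this PR):
```rust
// Sketch: iterate over the per-input answers returned by the batched predict.
// answers[i] holds the top_k answers for the i-th QaInput.
for (input_idx, input_answers) in answers.iter().enumerate() {
    for answer in input_answers {
        println!("Input {}: {:?} (score: {:.4})", input_idx, answer.answer, answer.score);
    }
}
```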

examples/squad.rs (new file, 47 lines)

@ -0,0 +1,47 @@
// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
// Copyright 2019 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
extern crate failure;
extern crate dirs;
use std::path::PathBuf;
use rust_bert::pipelines::question_answering::{QuestionAnsweringModel, squad_processor};
use tch::Device;
use std::env;
fn main() -> failure::Fallible<()> {
// Resources paths
let mut home: PathBuf = dirs::home_dir().unwrap();
home.push("rustbert");
home.push("distilbert-qa");
let config_path = &home.as_path().join("config.json");
let vocab_path = &home.as_path().join("vocab.txt");
let weights_path = &home.as_path().join("model.ot");
// Set-up Question Answering model
let device = Device::cuda_if_available();
let qa_model = QuestionAnsweringModel::new(vocab_path,
config_path,
weights_path, device)?;
// Define input
let mut squad_path = PathBuf::from(env::var("squad_dataset")
.expect("Please set the \"squad_dataset\" environment variable pointing to the SQuAD dataset folder"));
squad_path.push("dev-v2.0.json");
let qa_inputs = squad_processor(squad_path);
// Get answer
let answers = qa_model.predict(&qa_inputs, 1, 64);
println!("Sample answer: {:?}", answers.first().unwrap());
println!("{}", answers.len());
Ok(())
}


@ -13,7 +13,7 @@
use rust_tokenizers::{BertTokenizer, Tokenizer, TruncationStrategy, TokenizedInput};
use tch::{Device, Tensor, no_grad};
use std::path::Path;
use std::path::{Path, PathBuf};
use rust_tokenizers::tokenization_utils::truncate_sequences;
use std::collections::HashMap;
use std::cmp::min;
@ -21,6 +21,12 @@ use crate::{DistilBertForQuestionAnswering, DistilBertConfig};
use tch::nn::VarStore;
use crate::common::config::Config;
use tch::kind::Kind::Float;
use std::fs;
pub struct QaInput {
pub question: String,
pub context: String,
}
#[derive(Debug)]
pub struct QaExample {
@ -36,6 +42,7 @@ pub struct QaFeature {
pub attention_mask: Vec<i64>,
pub token_to_orig_map: HashMap<i64, i64>,
pub p_mask: Vec<i8>,
pub example_index: i64,
}
@ -130,63 +137,111 @@ impl QuestionAnsweringModel {
})
}
fn prepare_for_model(&self, question: &str, context: &str) -> (QaExample, Vec<QaFeature>, Tensor, Tensor) {
let qa_example = QaExample::new(question, context);
let mut input_ids: Vec<Tensor> = vec!();
let mut attention_masks: Vec<Tensor> = vec!();
let features = self.generate_features(&qa_example, self.max_seq_len, self.doc_stride, self.max_query_length);
for feature in &features {
input_ids.push(Tensor::of_slice(&feature.input_ids).to(self.var_store.device()));
attention_masks.push(Tensor::of_slice(&feature.attention_mask).to(self.var_store.device()));
fn generate_batch_indices(&self, features: &Vec<QaFeature>, batch_size: usize) -> Vec<(usize, usize)> {
let mut example_features_length: HashMap<i64, usize> = HashMap::new();
for feature in features {
let count = example_features_length.entry(feature.example_index).or_insert(0);
*count += 1;
}
let input_ids = Tensor::stack(&input_ids, 0);
let attention_masks = Tensor::stack(&attention_masks, 0);
(qa_example, features, input_ids, attention_masks)
}
pub fn predict(&self, question: &str, context: &str, top_k: i64) -> Vec<Answer> {
let (qa_example, qa_features, input_ids, attention_mask) = self.prepare_for_model(question, context);
let mut batch_indices: Vec<(usize, usize)> = Vec::with_capacity(features.len());
let (start_logits, end_logits, _, _) = no_grad(|| {
self.distilbert_qa.forward_t(Some(input_ids), Some(attention_mask), None, false).unwrap()
});
let start_logits = start_logits.to(Device::Cpu);
let end_logits = end_logits.to(Device::Cpu);
let mut batch_length = 0usize;
let mut start = 0usize;
let mut end = 0usize;
let mut answers: Vec<Answer> = vec!();
for (feature_idx, feature) in (0..start_logits.size()[0]).zip(qa_features) {
let start = start_logits.get(feature_idx);
let end = end_logits.get(feature_idx);
let p_mask = (Tensor::of_slice(&feature.p_mask) - 1).abs();
let start: Tensor = start.exp() / start.exp().sum(Float) * &p_mask;
let end: Tensor = end.exp() / end.exp().sum(Float) * &p_mask;
let (starts, ends, scores) = self.decode(&start, &end, top_k);
for idx in 0..starts.len() {
let start_pos = feature.token_to_orig_map[&starts[idx]] as usize;
let end_pos = feature.token_to_orig_map[&ends[idx]] as usize;
let answer = qa_example.doc_tokens[start_pos..end_pos + 1].join(" ");
let start = qa_example.char_to_word_offset
.iter()
.position(|&v| v as usize == start_pos)
.unwrap();
let end = qa_example.char_to_word_offset
.iter()
.rposition(|&v| v as usize == end_pos)
.unwrap();
answers.push(
Answer { score: scores[idx], start, end, answer })
for &feature_length in example_features_length.values() {
if batch_length + feature_length <= batch_size {
end += feature_length;
batch_length += feature_length;
} else {
batch_indices.push((start, end));
start = end;
end += feature_length;
batch_length = 1usize;
}
}
answers.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
batch_indices.push((start, end));
batch_indices
}
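A small worked illustration with hypothetical counts: if two examples produce 2 and 3 features respectively and `batch_size` is 4, the first pass sets `end = 2`; adding the second example's features would exceed the limit, so `(0, 2)` is pushed and the remaining features become the final span `(2, 5)`. Splits are only made between examples, so the batching aims to keep all spans generated for a single input in the same forward pass.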
answers[..(top_k as usize)].to_vec()
pub fn predict(&self, qa_inputs: &[QaInput], top_k: i64, batch_size: usize) -> Vec<Vec<Answer>> {
let examples: Vec<QaExample> = qa_inputs
.iter()
.map(|qa_input| QaExample::new(&qa_input.question, &qa_input.context))
.collect();
let features: Vec<QaFeature> = examples
.iter()
.enumerate()
.map(|(example_index, qa_example)| self.generate_features(&qa_example, self.max_seq_len, self.doc_stride, self.max_query_length, example_index as i64))
.flatten()
.collect();
let batch_indices = self.generate_batch_indices(&features, batch_size);
let mut all_answers = vec!();
for (start, end) in batch_indices {
let batch_features = &features[start..end];
let mut input_ids: Vec<Tensor> = vec!();
let mut attention_masks: Vec<Tensor> = vec!();
for feature in batch_features {
input_ids.push(Tensor::of_slice(&feature.input_ids));
attention_masks.push(Tensor::of_slice(&feature.attention_mask));
}
let input_ids = Tensor::stack(&input_ids, 0).to(self.var_store.device());
let attention_masks = Tensor::stack(&attention_masks, 0).to(self.var_store.device());
let (start_logits, end_logits, _, _) = no_grad(|| {
self.distilbert_qa.forward_t(Some(input_ids), Some(attention_masks), None, false).unwrap()
});
let start_logits = start_logits.to(Device::Cpu);
let end_logits = end_logits.to(Device::Cpu);
let example_index_to_feature_end_position: Vec<(usize, i64)> = batch_features
.iter()
.enumerate()
.map(|(feature_index, feature)| (feature.example_index as usize, feature_index as i64 + 1))
.collect();
let mut feature_id_start = 0;
for (example_id, max_feature_id) in example_index_to_feature_end_position {
let mut answers: Vec<Answer> = vec!();
let example = &examples[example_id];
for feature_idx in feature_id_start..max_feature_id {
let feature = &batch_features[feature_idx as usize];
let start = start_logits.get(feature_idx);
let end = end_logits.get(feature_idx);
let p_mask = (Tensor::of_slice(&feature.p_mask) - 1).abs();
let start: Tensor = start.exp() / start.exp().sum(Float) * &p_mask;
let end: Tensor = end.exp() / end.exp().sum(Float) * &p_mask;
let (starts, ends, scores) = self.decode(&start, &end, top_k);
for idx in 0..starts.len() {
let start_pos = feature.token_to_orig_map[&starts[idx]] as usize;
let end_pos = feature.token_to_orig_map[&ends[idx]] as usize;
let answer = example.doc_tokens[start_pos..end_pos + 1].join(" ");
let start = example.char_to_word_offset
.iter()
.position(|&v| v as usize == start_pos)
.unwrap();
let end = example.char_to_word_offset
.iter()
.rposition(|&v| v as usize == end_pos)
.unwrap();
answers.push(Answer { score: scores[idx], start, end, answer });
}
}
feature_id_start = max_feature_id;
all_answers.push(answers[..(top_k as usize)].to_vec());
}
}
all_answers
}
fn decode(&self, start: &Tensor, end: &Tensor, top_k: i64) -> (Vec<i64>, Vec<i64>, Vec<f64>) {
@ -216,7 +271,7 @@ impl QuestionAnsweringModel {
}
fn generate_features(&self, qa_example: &QaExample, max_seq_length: usize, doc_stride: usize, max_query_length: usize) -> Vec<QaFeature> {
fn generate_features(&self, qa_example: &QaExample, max_seq_length: usize, doc_stride: usize, max_query_length: usize, example_index: i64) -> Vec<QaFeature> {
let mut tok_to_orig_index: Vec<i64> = vec!();
let mut all_doc_tokens: Vec<String> = vec!();
@ -251,7 +306,7 @@ impl QuestionAnsweringModel {
let p_mask = self.get_mask(&encoded_span);
let qa_feature = QaFeature { input_ids: encoded_span.token_ids, attention_mask, token_to_orig_map, p_mask };
let qa_feature = QaFeature { input_ids: encoded_span.token_ids, attention_mask, token_to_orig_map, p_mask, example_index };
spans.push(qa_feature);
if encoded_span.num_truncated_tokens == 0 {
@ -320,3 +375,27 @@ impl QuestionAnsweringModel {
p_mask
}
}
pub fn squad_processor(file_path: PathBuf) -> Vec<QaInput> {
let file = fs::File::open(file_path).expect("unable to open file");
let json: serde_json::Value = serde_json::from_reader(file).expect("JSON not properly formatted");
let data = json
.get("data").expect("SQuAD file does not contain data field")
.as_array().expect("Data array not properly formatted");
let mut qa_inputs: Vec<QaInput> = Vec::with_capacity(data.len());
for qa_input in data.iter() {
let qa_input = qa_input.as_object().unwrap();
let paragraphs = qa_input.get("paragraphs").unwrap().as_array().unwrap();
for paragraph in paragraphs.iter() {
let paragraph = paragraph.as_object().unwrap();
let context = paragraph.get("context").unwrap().as_str().unwrap();
let qas = paragraph.get("qas").unwrap().as_array().unwrap();
for qa in qas.iter() {
let question = qa.as_object().unwrap().get("question").unwrap().as_str().unwrap();
qa_inputs.push(QaInput { question: question.to_owned(), context: context.to_owned() });
}
}
}
qa_inputs
}


@ -6,7 +6,7 @@ use rust_tokenizers::bert_tokenizer::BertTokenizer;
use rust_tokenizers::preprocessing::vocab::base_vocab::Vocab;
use rust_bert::{SentimentClassifier, SentimentPolarity};
use rust_bert::common::config::Config;
use rust_bert::pipelines::question_answering::QuestionAnsweringModel;
use rust_bert::pipelines::question_answering::{QuestionAnsweringModel, QaInput};
extern crate failure;
extern crate dirs;
@ -226,16 +226,18 @@ fn distilbert_question_answering() -> failure::Fallible<()> {
let qa_model = QuestionAnsweringModel::new(vocab_path, config_path, weights_path, device)?;
// Define input
let question = "Where does Amy live ?";
let context = "Amy lives in Amsterdam";
let question = String::from("Where does Amy live ?");
let context = String::from("Amy lives in Amsterdam");
let qa_input = QaInput { question, context };
let answers = qa_model.predict(question, context, 1);
let answers = qa_model.predict(&vec!(qa_input), 1, 32);
assert_eq!(answers.len(), 1 as usize);
assert_eq!(answers[0].start, 13);
assert_eq!(answers[0].end, 21);
assert!((answers[0].score - 0.9977).abs() < 1e-4);
assert_eq!(answers[0].answer, "Amsterdam");
assert_eq!(answers[0].len(), 1 as usize);
assert_eq!(answers[0][0].start, 13);
assert_eq!(answers[0][0].end, 21);
assert!((answers[0][0].score - 0.9977).abs() < 1e-4);
assert_eq!(answers[0][0].answer, "Amsterdam");
Ok(())
}