mirror of https://github.com/guillaume-be/rust-bert.git (synced 2024-11-09 17:05:51 +03:00)

Merge pull request #11 from guillaume-be/qa_batch_input

Question Answering batched input

Commit e9ac97b705
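In short: `QuestionAnsweringModel::predict` previously took a single `(question, context)` pair and returned a `Vec<Answer>`; it now takes a slice of `QaInput` plus a `batch_size` and returns one answer list per input. A minimal sketch of the call-site change (model construction elided; both signatures are taken from the diff below):

```rust
use rust_bert::pipelines::question_answering::{QuestionAnsweringModel, QaInput};

// Before this PR: one pair per call.
// let answers: Vec<Answer> = qa_model.predict("Where does Amy live ?", "Amy lives in Amsterdam", 1);

// After this PR: any number of inputs, grouped into forward passes of up to `batch_size` features.
let inputs = vec![QaInput {
    question: String::from("Where does Amy live ?"),
    context: String::from("Amy lives in Amsterdam"),
}];
let answers = qa_model.predict(&inputs, 1, 32); // Vec<Vec<Answer>>: one answer list per QaInput
```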
@@ -1,6 +1,6 @@
 [package]
 name = "rust-bert"
-version = "0.4.4"
+version = "0.4.5"
 authors = ["Guillaume Becquin <guillaume.becquin@gmail.com>"]
 edition = "2018"
 default-run = "rust-bert"
@@ -20,7 +20,7 @@ Next token prediction| | | |✅|✅|

 ## Ready-to-use pipelines

-Leveraging Huggingface's pipelines, ready to use end-to-end NLP pipelines are available as part of this crate. The following capabilities are currently available:
+Based on Huggingface's pipelines, ready to use end-to-end NLP pipelines are available as part of this crate. The following capabilities are currently available:
 #### 1. Question Answering
 Extractive question answering from a given question and context. DistilBERT model finetuned on SQuAD (Stanford Question Answering Dataset)

@@ -30,10 +30,10 @@ Extractive question answering from a given question and context. DistilBERT mode
         config_path,
         weights_path, device)?;

-    let question = "Where does Amy live ?";
-    let context = "Amy lives in Amsterdam";
+    let question = String::from("Where does Amy live ?");
+    let context = String::from("Amy lives in Amsterdam");

-    let answers = qa_model.predict(question, context, 1);
+    let answers = qa_model.predict(QaInput { question, context }, 1, 32);
 ```

 Output:
@@ -14,7 +14,7 @@ extern crate failure;
 extern crate dirs;

 use std::path::PathBuf;
-use rust_bert::pipelines::question_answering::QuestionAnsweringModel;
+use rust_bert::pipelines::question_answering::{QuestionAnsweringModel, QaInput};
 use tch::Device;

@@ -34,11 +34,15 @@ fn main() -> failure::Fallible<()> {
         weights_path, device)?;

     // Define input
-    let question = "Where does Amy live ?";
-    let context = "Amy lives in Amsterdam";
+    let question_1 = String::from("Where does Amy live ?");
+    let context_1 = String::from("Amy lives in Amsterdam");
+    let question_2 = String::from("Where does Eric live");
+    let context_2 = String::from("While Amy lives in Amsterdam, Eric is in The Hague.");
+    let qa_input_1 = QaInput { question: question_1, context: context_1 };
+    let qa_input_2 = QaInput { question: question_2, context: context_2 };

     // Get answer
-    let answers = qa_model.predict(question, context, 1);
+    let answers = qa_model.predict(&vec!(qa_input_1, qa_input_2), 1, 32);
     println!("{:?}", answers);
     Ok(())
 }
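Worth noting (and visible in the test assertions further down): the batched `predict` groups its results per input, so unpacking the two inputs above would look like this sketch:

```rust
// `answers` is a Vec<Vec<Answer>>: answers[0] belongs to qa_input_1, answers[1] to qa_input_2.
for (i, input_answers) in answers.iter().enumerate() {
    println!("input {}: {:?}", i, input_answers);
}
```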
examples/squad.rs (new file, 47 lines)

@@ -0,0 +1,47 @@
+// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
+// Copyright 2019 Guillaume Becquin
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+// http://www.apache.org/licenses/LICENSE-2.0
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+extern crate failure;
+extern crate dirs;
+
+use std::path::PathBuf;
+use rust_bert::pipelines::question_answering::{QuestionAnsweringModel, squad_processor};
+use tch::Device;
+use std::env;
+
+fn main() -> failure::Fallible<()> {
+    // Resources paths
+    let mut home: PathBuf = dirs::home_dir().unwrap();
+    home.push("rustbert");
+    home.push("distilbert-qa");
+    let config_path = &home.as_path().join("config.json");
+    let vocab_path = &home.as_path().join("vocab.txt");
+    let weights_path = &home.as_path().join("model.ot");
+
+    // Set-up Question Answering model
+    let device = Device::cuda_if_available();
+    let qa_model = QuestionAnsweringModel::new(vocab_path,
+                                               config_path,
+                                               weights_path, device)?;
+
+    // Define input
+    let mut squad_path = PathBuf::from(env::var("squad_dataset")
+        .expect("Please set the \"squad_dataset\" environment variable pointing to the SQuAD dataset folder"));
+    squad_path.push("dev-v2.0.json");
+    let qa_inputs = squad_processor(squad_path);
+
+    // Get answer
+    let answers = qa_model.predict(&qa_inputs, 1, 64);
+    println!("Sample answer: {:?}", answers.first().unwrap());
+    println!("{}", answers.len());
+    Ok(())
+}
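A usage note on the new example: it resolves the dataset folder from the `squad_dataset` environment variable, so a hypothetical run could set that variable first (the path below is illustrative, not part of the commit):

```rust
// Illustrative harness only: point `squad_dataset` at a folder containing
// dev-v2.0.json before the example's main() reads it, e.g. from a shell:
//   squad_dataset=/data/squad cargo run --example squad
std::env::set_var("squad_dataset", "/data/squad");
```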
@@ -13,7 +13,7 @@

 use rust_tokenizers::{BertTokenizer, Tokenizer, TruncationStrategy, TokenizedInput};
 use tch::{Device, Tensor, no_grad};
-use std::path::Path;
+use std::path::{Path, PathBuf};
 use rust_tokenizers::tokenization_utils::truncate_sequences;
 use std::collections::HashMap;
 use std::cmp::min;
@@ -21,6 +21,12 @@ use crate::{DistilBertForQuestionAnswering, DistilBertConfig};
 use tch::nn::VarStore;
 use crate::common::config::Config;
 use tch::kind::Kind::Float;
+use std::fs;

+pub struct QaInput {
+    pub question: String,
+    pub context: String,
+}
+
 #[derive(Debug)]
 pub struct QaExample {

@@ -36,6 +42,7 @@ pub struct QaFeature {
     pub attention_mask: Vec<i64>,
     pub token_to_orig_map: HashMap<i64, i64>,
     pub p_mask: Vec<i8>,
+    pub example_index: i64,

 }
@@ -130,63 +137,111 @@ impl QuestionAnsweringModel {
         })
     }

-    fn prepare_for_model(&self, question: &str, context: &str) -> (QaExample, Vec<QaFeature>, Tensor, Tensor) {
-        let qa_example = QaExample::new(question, context);
-        let mut input_ids: Vec<Tensor> = vec!();
-        let mut attention_masks: Vec<Tensor> = vec!();
-
-        let features = self.generate_features(&qa_example, self.max_seq_len, self.doc_stride, self.max_query_length);
-        for feature in &features {
-            input_ids.push(Tensor::of_slice(&feature.input_ids).to(self.var_store.device()));
-            attention_masks.push(Tensor::of_slice(&feature.attention_mask).to(self.var_store.device()));
+    fn generate_batch_indices(&self, features: &Vec<QaFeature>, batch_size: usize) -> Vec<(usize, usize)> {
+        let mut example_features_length: HashMap<i64, usize> = HashMap::new();
+        for feature in features {
+            let count = example_features_length.entry(feature.example_index).or_insert(0);
+            *count += 1;
         }
-        let input_ids = Tensor::stack(&input_ids, 0);
-        let attention_masks = Tensor::stack(&attention_masks, 0);
-        (qa_example, features, input_ids, attention_masks)
-    }
-
-    pub fn predict(&self, question: &str, context: &str, top_k: i64) -> Vec<Answer> {
-        let (qa_example, qa_features, input_ids, attention_mask) = self.prepare_for_model(question, context);
-
-        let (start_logits, end_logits, _, _) = no_grad(|| {
-            self.distilbert_qa.forward_t(Some(input_ids), Some(attention_mask), None, false).unwrap()
-        });
-        let start_logits = start_logits.to(Device::Cpu);
-        let end_logits = end_logits.to(Device::Cpu);
-
-        let mut answers: Vec<Answer> = vec!();
-        for (feature_idx, feature) in (0..start_logits.size()[0]).zip(qa_features) {
-            let start = start_logits.get(feature_idx);
-            let end = end_logits.get(feature_idx);
-            let p_mask = (Tensor::of_slice(&feature.p_mask) - 1).abs();
-
-            let start: Tensor = start.exp() / start.exp().sum(Float) * &p_mask;
-            let end: Tensor = end.exp() / end.exp().sum(Float) * &p_mask;
-
-            let (starts, ends, scores) = self.decode(&start, &end, top_k);
-
-            for idx in 0..starts.len() {
-                let start_pos = feature.token_to_orig_map[&starts[idx]] as usize;
-                let end_pos = feature.token_to_orig_map[&ends[idx]] as usize;
-                let answer = qa_example.doc_tokens[start_pos..end_pos + 1].join(" ");
-
-                let start = qa_example.char_to_word_offset
-                    .iter()
-                    .position(|&v| v as usize == start_pos)
-                    .unwrap();
-
-                let end = qa_example.char_to_word_offset
-                    .iter()
-                    .rposition(|&v| v as usize == end_pos)
-                    .unwrap();
-
-                answers.push(
-                    Answer { score: scores[idx], start, end, answer })
+        let mut batch_indices: Vec<(usize, usize)> = Vec::with_capacity(features.len());
+
+        let mut batch_length = 0usize;
+        let mut start = 0usize;
+        let mut end = 0usize;
+
+        for &feature_length in example_features_length.values() {
+            if batch_length + feature_length <= batch_size {
+                end += feature_length;
+                batch_length += feature_length;
+            } else {
+                batch_indices.push((start, end));
+                start = end;
+                end += feature_length;
+                batch_length = 1usize;
             }
         }
-        answers.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
-
-        answers[..(top_k as usize)].to_vec()
+        batch_indices.push((start, end));
+        batch_indices
     }

+    pub fn predict(&self, qa_inputs: &[QaInput], top_k: i64, batch_size: usize) -> Vec<Vec<Answer>> {
+        let examples: Vec<QaExample> = qa_inputs
+            .iter()
+            .map(|qa_input| QaExample::new(&qa_input.question, &qa_input.context))
+            .collect();
+
+        let features: Vec<QaFeature> = examples
+            .iter()
+            .enumerate()
+            .map(|(example_index, qa_example)| self.generate_features(&qa_example, self.max_seq_len, self.doc_stride, self.max_query_length, example_index as i64))
+            .flatten()
+            .collect();
+
+        let batch_indices = self.generate_batch_indices(&features, batch_size);
+        let mut all_answers = vec!();
+
+        for (start, end) in batch_indices {
+            let batch_features = &features[start..end];
+            let mut input_ids: Vec<Tensor> = vec!();
+            let mut attention_masks: Vec<Tensor> = vec!();
+            for feature in batch_features {
+                input_ids.push(Tensor::of_slice(&feature.input_ids));
+                attention_masks.push(Tensor::of_slice(&feature.attention_mask));
+            }
+            let input_ids = Tensor::stack(&input_ids, 0).to(self.var_store.device());
+            let attention_masks = Tensor::stack(&attention_masks, 0).to(self.var_store.device());
+
+            let (start_logits, end_logits, _, _) = no_grad(|| {
+                self.distilbert_qa.forward_t(Some(input_ids), Some(attention_masks), None, false).unwrap()
+            });
+            let start_logits = start_logits.to(Device::Cpu);
+            let end_logits = end_logits.to(Device::Cpu);
+
+            let example_index_to_feature_end_position: Vec<(usize, i64)> = batch_features
+                .iter()
+                .enumerate()
+                .map(|(feature_index, feature)| (feature.example_index as usize, feature_index as i64 + 1))
+                .collect();
+
+            let mut feature_id_start = 0;
+            for (example_id, max_feature_id) in example_index_to_feature_end_position {
+                let mut answers: Vec<Answer> = vec!();
+                let example = &examples[example_id];
+                for feature_idx in feature_id_start..max_feature_id {
+                    let feature = &batch_features[feature_idx as usize];
+                    let start = start_logits.get(feature_idx);
+                    let end = end_logits.get(feature_idx);
+                    let p_mask = (Tensor::of_slice(&feature.p_mask) - 1).abs();
+
+                    let start: Tensor = start.exp() / start.exp().sum(Float) * &p_mask;
+                    let end: Tensor = end.exp() / end.exp().sum(Float) * &p_mask;
+
+                    let (starts, ends, scores) = self.decode(&start, &end, top_k);
+
+                    for idx in 0..starts.len() {
+                        let start_pos = feature.token_to_orig_map[&starts[idx]] as usize;
+                        let end_pos = feature.token_to_orig_map[&ends[idx]] as usize;
+                        let answer = example.doc_tokens[start_pos..end_pos + 1].join(" ");
+
+                        let start = example.char_to_word_offset
+                            .iter()
+                            .position(|&v| v as usize == start_pos)
+                            .unwrap();
+
+                        let end = example.char_to_word_offset
+                            .iter()
+                            .rposition(|&v| v as usize == end_pos)
+                            .unwrap();
+
+                        answers.push(Answer { score: scores[idx], start, end, answer });
+                    }
+                }
+                feature_id_start = max_feature_id;
+                all_answers.push(answers[..(top_k as usize)].to_vec());
+            }
+        }
+        all_answers
+    }

     fn decode(&self, start: &Tensor, end: &Tensor, top_k: i64) -> (Vec<i64>, Vec<i64>, Vec<f64>) {
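The core of the new batching is `generate_batch_indices`: it counts how many features each example produced, then cuts the flat feature list into `(start, end)` ranges so that every feature of a given example lands in the same forward pass. A standalone sketch of that grouping over plain per-example counts (a hypothetical helper, not in the crate; it assumes counts are visited in flat feature order, whereas the commit iterates a `HashMap` whose order is unspecified, and it carries the overflowing example's count where the commit resets `batch_length` to `1usize`):

```rust
// Given how many features each example produced, emit (start, end) ranges
// over the flat feature list so no example is split across two batches.
fn batch_ranges(features_per_example: &[usize], batch_size: usize) -> Vec<(usize, usize)> {
    let mut ranges = Vec::new();
    let (mut start, mut end, mut batch_length) = (0usize, 0usize, 0usize);
    for &feature_length in features_per_example {
        if batch_length + feature_length <= batch_size {
            end += feature_length;
            batch_length += feature_length;
        } else {
            // Close the current batch and start a new one with this example.
            ranges.push((start, end));
            start = end;
            end += feature_length;
            batch_length = feature_length;
        }
    }
    ranges.push((start, end));
    ranges
}

fn main() {
    // Examples yielding 2, 3 and 2 features with batch_size = 4:
    // no example is split, so the ranges are [(0, 2), (2, 5), (5, 7)].
    assert_eq!(batch_ranges(&[2, 3, 2], 4), vec![(0, 2), (2, 5), (5, 7)]);
}
```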
@@ -216,7 +271,7 @@ impl QuestionAnsweringModel {
     }


-    fn generate_features(&self, qa_example: &QaExample, max_seq_length: usize, doc_stride: usize, max_query_length: usize) -> Vec<QaFeature> {
+    fn generate_features(&self, qa_example: &QaExample, max_seq_length: usize, doc_stride: usize, max_query_length: usize, example_index: i64) -> Vec<QaFeature> {
         let mut tok_to_orig_index: Vec<i64> = vec!();
         let mut all_doc_tokens: Vec<String> = vec!();

@@ -251,7 +306,7 @@ impl QuestionAnsweringModel {

         let p_mask = self.get_mask(&encoded_span);

-        let qa_feature = QaFeature { input_ids: encoded_span.token_ids, attention_mask, token_to_orig_map, p_mask };
+        let qa_feature = QaFeature { input_ids: encoded_span.token_ids, attention_mask, token_to_orig_map, p_mask, example_index };

         spans.push(qa_feature);
         if encoded_span.num_truncated_tokens == 0 {
@@ -320,3 +375,27 @@ impl QuestionAnsweringModel {
         p_mask
     }
 }
+
+pub fn squad_processor(file_path: PathBuf) -> Vec<QaInput> {
+    let file = fs::File::open(file_path).expect("unable to open file");
+    let json: serde_json::Value = serde_json::from_reader(file).expect("JSON not properly formatted");
+    let data = json
+        .get("data").expect("SQuAD file does not contain data field")
+        .as_array().expect("Data array not properly formatted");
+
+    let mut qa_inputs: Vec<QaInput> = Vec::with_capacity(data.len());
+    for qa_input in data.iter() {
+        let qa_input = qa_input.as_object().unwrap();
+        let paragraphs = qa_input.get("paragraphs").unwrap().as_array().unwrap();
+        for paragraph in paragraphs.iter() {
+            let paragraph = paragraph.as_object().unwrap();
+            let context = paragraph.get("context").unwrap().as_str().unwrap();
+            let qas = paragraph.get("qas").unwrap().as_array().unwrap();
+            for qa in qas.iter() {
+                let question = qa.as_object().unwrap().get("question").unwrap().as_str().unwrap();
+                qa_inputs.push(QaInput { question: question.to_owned(), context: context.to_owned() });
+            }
+        }
+    }
+    qa_inputs
+}
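For reference, `squad_processor` walks the standard SQuAD layout, `data -> paragraphs -> (context, qas -> question)`, and flattens every question/context pair into a `QaInput`. A minimal stand-in for `dev-v2.0.json` showing only the fields the processor actually reads (real files also carry ids and answer spans, which this function ignores):

```rust
use serde_json::json;

fn main() {
    // One article, one paragraph, two questions -> two QaInput values sharing one context.
    let squad_like = json!({
        "data": [{
            "paragraphs": [{
                "context": "Amy lives in Amsterdam",
                "qas": [
                    { "question": "Where does Amy live ?" },
                    { "question": "Who lives in Amsterdam ?" }
                ]
            }]
        }]
    });
    println!("{}", squad_like);
}
```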
@@ -6,7 +6,7 @@ use rust_tokenizers::bert_tokenizer::BertTokenizer;
 use rust_tokenizers::preprocessing::vocab::base_vocab::Vocab;
 use rust_bert::{SentimentClassifier, SentimentPolarity};
 use rust_bert::common::config::Config;
-use rust_bert::pipelines::question_answering::QuestionAnsweringModel;
+use rust_bert::pipelines::question_answering::{QuestionAnsweringModel, QaInput};

 extern crate failure;
 extern crate dirs;
@@ -226,16 +226,18 @@ fn distilbert_question_answering() -> failure::Fallible<()> {
     let qa_model = QuestionAnsweringModel::new(vocab_path, config_path, weights_path, device)?;

     // Define input
-    let question = "Where does Amy live ?";
-    let context = "Amy lives in Amsterdam";
+    let question = String::from("Where does Amy live ?");
+    let context = String::from("Amy lives in Amsterdam");
+    let qa_input = QaInput { question, context };

-    let answers = qa_model.predict(question, context, 1);
+    let answers = qa_model.predict(&vec!(qa_input), 1, 32);

-    assert_eq!(answers.len(), 1 as usize);
-    assert_eq!(answers[0].start, 13);
-    assert_eq!(answers[0].end, 21);
-    assert!((answers[0].score - 0.9977).abs() < 1e-4);
-    assert_eq!(answers[0].answer, "Amsterdam");
+    assert_eq!(answers[0].len(), 1 as usize);
+    assert_eq!(answers[0][0].start, 13);
+    assert_eq!(answers[0][0].end, 21);
+    assert!((answers[0][0].score - 0.9977).abs() < 1e-4);
+    assert_eq!(answers[0][0].answer, "Amsterdam");

     Ok(())
 }