Addition of BERT option for QA

This commit is contained in:
Guillaume B 2020-07-08 21:59:31 +02:00
parent 2dad825586
commit e6938e1dcb
4 changed files with 151 additions and 0 deletions

View File

@@ -0,0 +1,57 @@
// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
// Copyright 2019 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
extern crate failure;
use rust_bert::bert::{BertConfigResources, BertModelResources, BertVocabResources};
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::question_answering::{
    QaInput, QuestionAnsweringConfig, QuestionAnsweringModel,
};
use rust_bert::resources::{RemoteResource, Resource};
fn main() -> failure::Fallible<()> {
    // Set up the Question Answering model
    let config = QuestionAnsweringConfig::new(
        ModelType::Bert,
        Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT_QA)),
        Resource::Remote(RemoteResource::from_pretrained(
            BertConfigResources::BERT_QA,
        )),
        Resource::Remote(RemoteResource::from_pretrained(
            BertVocabResources::BERT_QA,
        )),
        None,  //merges resource only relevant with ModelType::Roberta
        false, //lowercase
    );
    let qa_model = QuestionAnsweringModel::new(config)?;

    // Define input
    let question_1 = String::from("Where does Amy live?");
    let context_1 = String::from("Amy lives in Amsterdam");
    let question_2 = String::from("Where does Eric live?");
    let context_2 = String::from("While Amy lives in Amsterdam, Eric is in The Hague.");
    let qa_input_1 = QaInput {
        question: question_1,
        context: context_1,
    };
    let qa_input_2 = QaInput {
        question: question_2,
        context: context_2,
    };

    // Get answers
    let answers = qa_model.predict(&[qa_input_1, qa_input_2], 1, 32);
    println!("{:?}", answers);
    Ok(())
}
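
For reference, `predict` returns one `Vec` of answers per input, capped here at `top_k = 1`. A minimal sketch of unpacking the results, assuming the `Answer` struct exposed by the question answering pipeline with `answer` and `score` fields:

use rust_bert::pipelines::question_answering::Answer;

// Illustrative sketch only (not part of this commit): print each answer per input.
fn print_answers(answers: &[Vec<Answer>]) {
    for (i, input_answers) in answers.iter().enumerate() {
        for answer in input_answers {
            println!("Input {}: '{}' (score: {:.3})", i + 1, answer.answer, answer.score);
        }
    }
}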

View File

@@ -44,6 +44,11 @@ impl BertModelResources {
        "bert-ner/model.ot",
        "https://cdn.huggingface.co/dbmdz/bert-large-cased-finetuned-conll03-english/rust_model.ot",
    );
    /// Shared under Apache 2.0 license by Hugging Face Inc at https://github.com/huggingface/transformers/tree/master/examples/question-answering. Modified with conversion to C-array format.
    pub const BERT_QA: (&'static str, &'static str) = (
        "bert-qa/model.ot",
        "https://cdn.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad/rust_model.ot",
    );
}
impl BertConfigResources {
@@ -57,6 +62,11 @@ impl BertConfigResources {
        "bert-ner/config.json",
        "https://cdn.huggingface.co/dbmdz/bert-large-cased-finetuned-conll03-english/config.json",
    );
    /// Shared under Apache 2.0 license by Hugging Face Inc at https://github.com/huggingface/transformers/tree/master/examples/question-answering. Modified with conversion to C-array format.
    pub const BERT_QA: (&'static str, &'static str) = (
        "bert-qa/config.json",
        "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-config.json",
    );
}
impl BertVocabResources {
@@ -70,6 +80,11 @@ impl BertVocabResources {
        "bert-ner/vocab.txt",
        "https://cdn.huggingface.co/dbmdz/bert-large-cased-finetuned-conll03-english/vocab.txt",
    );
    /// Shared under Apache 2.0 license by Hugging Face Inc at https://github.com/huggingface/transformers/tree/master/examples/question-answering. Modified with conversion to C-array format.
    pub const BERT_QA: (&'static str, &'static str) = (
        "bert-qa/vocab.txt",
        "https://cdn.huggingface.co/bert-large-cased-whole-word-masking-finetuned-squad-vocab.txt",
    );
}
#[allow(non_camel_case_types)]
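
Each constant is a `(cache_subpath, url)` tuple consumed by `RemoteResource::from_pretrained`. A minimal sketch of resolving one of the new `BERT_QA` entries to a local file, assuming the `download_resource` helper the crate exposed at the time:

use rust_bert::bert::BertModelResources;
use rust_bert::resources::{download_resource, RemoteResource, Resource};

// Illustrative sketch: download (or reuse the cached copy of) the QA weights.
fn fetch_qa_weights() -> failure::Fallible<()> {
    let resource = Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT_QA));
    let weights_path = download_resource(&resource)?;
    println!("model weights cached at {:?}", weights_path);
    Ok(())
}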

View File

@@ -191,6 +191,37 @@ pub struct QuestionAnsweringConfig {
    pub lower_case: bool,
}
impl QuestionAnsweringConfig {
    /// Instantiate a new question answering configuration of the supplied type.
    ///
    /// # Arguments
    ///
    /// * `model_type` - `ModelType` indicating the model type to load (must match the actual data to be loaded!)
    /// * `model_resource` - The `Resource` pointing to the model to load (e.g. model.ot)
    /// * `config_resource` - The `Resource` pointing to the model configuration to load (e.g. config.json)
    /// * `vocab_resource` - The `Resource` pointing to the tokenizer's vocabulary to load (e.g. vocab.txt/vocab.json)
    /// * `merges_resource` - An optional `Resource` (`Option<Resource>`) pointing to the tokenizer's merges file to load (e.g. merges.txt), needed only for Roberta.
    /// * `lower_case` - A `bool` indicating whether the tokenizer should lower-case all input (in case of a lower-cased model)
    pub fn new(
        model_type: ModelType,
        model_resource: Resource,
        config_resource: Resource,
        vocab_resource: Resource,
        merges_resource: Option<Resource>,
        lower_case: bool,
    ) -> QuestionAnsweringConfig {
        QuestionAnsweringConfig {
            model_type,
            model_resource,
            config_resource,
            vocab_resource,
            merges_resource,
            lower_case,
            device: Device::cuda_if_available(),
        }
    }
}
impl Default for QuestionAnsweringConfig {
    fn default() -> QuestionAnsweringConfig {
        QuestionAnsweringConfig {

View File

@@ -0,0 +1,48 @@
from transformers import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BERT_PRETRAINED_MODEL_ARCHIVE_MAP
from transformers.tokenization_bert import PRETRAINED_VOCAB_FILES_MAP
from transformers.file_utils import get_from_cache
from pathlib import Path
import shutil
import os
import numpy as np
import torch
import subprocess

# Resolve the remote locations of the SQuAD fine-tuned BERT checkpoint.
config_path = BERT_PRETRAINED_CONFIG_ARCHIVE_MAP["bert-large-cased-whole-word-masking-finetuned-squad"]
vocab_path = PRETRAINED_VOCAB_FILES_MAP["vocab_file"]["bert-large-cased-whole-word-masking-finetuned-squad"]
weights_path = BERT_PRETRAINED_MODEL_ARCHIVE_MAP["bert-large-cased-whole-word-masking-finetuned-squad"]

# Download the files (or fetch them from the local cache) and copy them to ~/rustbert/bert-qa.
target_path = Path.home() / 'rustbert' / 'bert-qa'
temp_config = get_from_cache(config_path)
temp_vocab = get_from_cache(vocab_path)
temp_weights = get_from_cache(weights_path)
os.makedirs(str(target_path), exist_ok=True)

config_path = str(target_path / 'config.json')
vocab_path = str(target_path / 'vocab.txt')
model_path = str(target_path / 'model.bin')
shutil.copy(temp_config, config_path)
shutil.copy(temp_vocab, vocab_path)
shutil.copy(temp_weights, model_path)

# Convert the PyTorch state dict to a NumPy archive, renaming the LayerNorm
# parameters (gamma/beta) to the weight/bias names expected by rust-bert.
weights = torch.load(temp_weights, map_location='cpu')
nps = {}
for k, v in weights.items():
    k = k.replace("gamma", "weight").replace("beta", "bias")
    nps[k] = np.ascontiguousarray(v.cpu().numpy())
np.savez(target_path / 'model.npz', **nps)

# Convert the NumPy archive to the C-array format (model.ot) read by tch-rs,
# then remove the intermediate files.
source = str(target_path / 'model.npz')
target = str(target_path / 'model.ot')
toml_location = (Path(__file__).resolve() / '..' / '..' / 'Cargo.toml').resolve()
subprocess.call(
    ['cargo', 'run', '--bin=convert-tensor', '--manifest-path=%s' % toml_location, '--', source, target])
os.remove(str(target_path / 'model.bin'))
os.remove(str(target_path / 'model.npz'))
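
The generated model.ot holds named C-array tensors that tch-rs reads directly. A rough sketch of the Rust-side load, assuming a tch `VarStore` whose variables were registered under the same names (this mirrors what `QuestionAnsweringModel::new` does internally; the path is a placeholder for the script's output location):

use tch::{nn, Device};

// Illustrative sketch: the model graph must be built on vs.root() first so that
// variable names line up with the keys written by the script above.
fn load_converted_weights() -> failure::Fallible<()> {
    let mut vs = nn::VarStore::new(Device::cuda_if_available());
    // ... construct the BERT QA model on &vs.root() here ...
    vs.load("rustbert/bert-qa/model.ot")?;
    Ok(())
}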