Added pre-trained NER model to pipelines

This commit is contained in:
Guillaume B 2020-02-24 14:19:14 +01:00
parent 6282536bbc
commit 6a3bfee4a3
11 changed files with 231 additions and 6 deletions

1
.gitignore vendored
View File

@ -1,3 +1,4 @@
.idea/*
# Generated by Cargo
# will have compiled files and executables
/target/

37
examples/ner.rs Normal file
View File

@ -0,0 +1,37 @@
extern crate failure;
extern crate dirs;
use std::path::PathBuf;
use rust_bert::pipelines::ner::NERModel;
use tch::Device;
fn main() -> failure::Fallible<()> {
// Resources paths
let mut home: PathBuf = dirs::home_dir().unwrap();
home.push("rustbert");
home.push("bert-ner");
let config_path = &home.as_path().join("config.json");
let vocab_path = &home.as_path().join("vocab.txt");
let weights_path = &home.as_path().join("model.ot");
// Set-up model
let device = Device::cuda_if_available();
let ner_model = NERModel::new(vocab_path,
config_path,
weights_path, device)?;
// Define input
let input = [
"My name is Amy. I live in Paris.",
"Paris is a city in France."
];
// Run model
let output = ner_model.predict(input.to_vec());
for entity in output {
println!("{:?}", entity);
}
Ok(())
}

37
examples/sentiment.rs Normal file
View File

@ -0,0 +1,37 @@
extern crate failure;
extern crate dirs;
use std::path::PathBuf;
use rust_bert::pipelines::ner::NERModel;
use tch::Device;
fn main() -> failure::Fallible<()> {
// Resources paths
let mut home: PathBuf = dirs::home_dir().unwrap();
home.push("rustbert");
home.push("bert-ner");
let config_path = &home.as_path().join("config.json");
let vocab_path = &home.as_path().join("vocab.txt");
let weights_path = &home.as_path().join("model.ot");
// Set-up model
let device = Device::cuda_if_available();
let ner_model = NERModel::new(vocab_path,
config_path,
weights_path, device)?;
// Define input
let input = [
"My name is Amy. I live in Paris.",
"Paris is a city in France."
];
// Run model
let output = ner_model.predict(input.to_vec());
for entity in output {
println!("{:?}", entity);
}
Ok(())
}

View File

@ -47,8 +47,8 @@ pub struct BertConfig {
pub output_attentions: Option<bool>,
pub output_hidden_states: Option<bool>,
pub is_decoder: Option<bool>,
pub id2label: Option<HashMap<i32, String>>,
pub label2id: Option<HashMap<String, i32>>,
pub id2label: Option<HashMap<i64, String>>,
pub label2id: Option<HashMap<String, i64>>,
pub num_labels: Option<i64>,
}

View File

@ -2,4 +2,3 @@ pub mod distilbert;
mod embeddings;
mod attention;
mod transformer;
pub mod sentiment;

View File

@ -2,11 +2,13 @@ pub mod distilbert;
pub mod bert;
pub mod roberta;
pub mod common;
pub mod pipelines;
pub use distilbert::distilbert::{DistilBertConfig, DistilBertModel, DistilBertModelClassifier, DistilBertModelMaskedLM, DistilBertForTokenClassification, DistilBertForQuestionAnswering};
pub use distilbert::sentiment::{Sentiment, SentimentPolarity, SentimentClassifier};
pub use bert::bert::BertConfig;
pub use bert::bert::{BertModel, BertForSequenceClassification, BertForMaskedLM, BertForQuestionAnswering, BertForTokenClassification, BertForMultipleChoice};
pub use roberta::roberta::{RobertaForSequenceClassification, RobertaForMaskedLM, RobertaForQuestionAnswering, RobertaForTokenClassification, RobertaForMultipleChoice};
pub use roberta::roberta::{RobertaForSequenceClassification, RobertaForMaskedLM, RobertaForQuestionAnswering, RobertaForTokenClassification, RobertaForMultipleChoice};
pub use pipelines::sentiment::{Sentiment, SentimentPolarity, SentimentClassifier};

View File

@ -1,6 +1,6 @@
use std::path::PathBuf;
use tch::Device;
use rust_bert::distilbert::sentiment::SentimentClassifier;
use rust_bert::pipelines::sentiment::SentimentClassifier;
extern crate failure;
extern crate dirs;

2
src/pipelines/mod.rs Normal file
View File

@ -0,0 +1,2 @@
pub mod sentiment;
pub mod ner;

101
src/pipelines/ner.rs Normal file
View File

@ -0,0 +1,101 @@
// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
// Copyright 2019 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use rust_tokenizers::bert_tokenizer::BertTokenizer;
use std::path::Path;
use tch::nn::VarStore;
use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::{TruncationStrategy, MultiThreadedTokenizer};
use crate::{BertForTokenClassification, BertConfig};
use std::collections::HashMap;
use crate::common::config::Config;
use tch::{Tensor, no_grad, Device};
use tch::kind::Kind::Float;
#[derive(Debug)]
pub struct Entity {
pub word: String,
pub score: f64,
pub label: String,
}
pub struct NERModel {
tokenizer: BertTokenizer,
bert_sequence_classifier: BertForTokenClassification,
label_mapping: HashMap<i64, String>,
var_store: VarStore,
}
impl NERModel {
pub fn new(vocab_path: &Path, model_config_path: &Path, model_weight_path: &Path, device: Device)
-> failure::Fallible<NERModel> {
let tokenizer = BertTokenizer::from_file(vocab_path.to_str().unwrap(), false);
let mut var_store = VarStore::new(device);
let config = BertConfig::from_file(model_config_path);
let bert_sequence_classifier = BertForTokenClassification::new(&var_store.root(), &config);
let label_mapping = config.id2label.expect("No label dictionary (id2label) provided in configuration file");
var_store.load(model_weight_path)?;
Ok(NERModel { tokenizer, bert_sequence_classifier, label_mapping, var_store })
}
fn prepare_for_model(&self, input: Vec<&str>) -> Tensor {
let tokenized_input = self.tokenizer.encode_list(input.to_vec(),
128,
&TruncationStrategy::LongestFirst,
0);
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
let tokenized_input = tokenized_input.
iter().
map(|input| input.token_ids.clone()).
map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input
}).
map(|input|
Tensor::of_slice(&(input))).
collect::<Vec<_>>();
Tensor::stack(tokenized_input.as_slice(), 0).to(self.var_store.device())
}
pub fn predict(&self, input: Vec<&str>) -> Vec<Entity> {
let input_tensor = self.prepare_for_model(input);
let (output, _, _) = no_grad(|| {
self.bert_sequence_classifier
.forward_t(Some(input_tensor.copy()),
None,
None,
None,
None,
false)
});
let output = output.detach().to(Device::Cpu);
let score: Tensor = output.exp() / output.exp().sum1(&[-1], true, Float);
let labels_idx = &score.argmax(-1, true);
let mut entities: Vec<Entity> = vec!();
for sentence_idx in 0..labels_idx.size()[0] {
let labels = labels_idx.get(sentence_idx);
for position_idx in 0..labels.size()[0] {
let label = labels.int64_value(&[position_idx]);
if label != 0 {
entities.push(Entity {
word: rust_tokenizers::preprocessing::tokenizer::base_tokenizer::Tokenizer::decode(&self.tokenizer, vec!(input_tensor.int64_value(&[sentence_idx, position_idx])), true, true),
score: score.double_value(&[sentence_idx, position_idx, label]),
label: self.label_mapping.get(&label).expect("Index out of vocabulary bounds.").to_owned(),
});
}
}
}
entities
}
}

View File

@ -0,0 +1,46 @@
from transformers.file_utils import get_from_cache, S3_BUCKET_PREFIX
from transformers.pipelines import SUPPORTED_TASKS
from pathlib import Path
import shutil
import os
import numpy as np
import torch
import subprocess
ROOT_PATH = S3_BUCKET_PREFIX + '/' + SUPPORTED_TASKS['ner']['default']['model']['pt']
config_path = ROOT_PATH + '/config.json'
vocab_path = ROOT_PATH + '/vocab.txt'
weights_path = ROOT_PATH + '/pytorch_model.bin'
target_path = Path.home() / 'rustbert' / 'bert-ner'
temp_config = get_from_cache(config_path)
temp_vocab = get_from_cache(vocab_path)
temp_weights = get_from_cache(weights_path)
os.makedirs(str(target_path), exist_ok=True)
config_path = str(target_path / 'config.json')
vocab_path = str(target_path / 'vocab.txt')
model_path = str(target_path / 'model.bin')
shutil.copy(temp_config, config_path)
shutil.copy(temp_vocab, vocab_path)
shutil.copy(temp_weights, model_path)
weights = torch.load(temp_weights, map_location='cpu')
nps = {}
for k, v in weights.items():
k = k.replace("gamma", "weight").replace("beta", "bias")
nps[k] = np.ascontiguousarray(v.cpu().numpy())
np.savez(target_path / 'model.npz', **nps)
source = str(target_path / 'model.npz')
target = str(target_path / 'model.ot')
toml_location = (Path(__file__).resolve() / '..' / '..' / 'Cargo.toml').resolve()
subprocess.call(
['cargo', 'run', '--bin=convert-tensor', '--manifest-path=%s' % toml_location, '--', source, target])