diff --git a/.gitignore b/.gitignore
index b4706ae..d9d3001 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,4 @@
+.idea/*
 # Generated by Cargo
 # will have compiled files and executables
 /target/
diff --git a/examples/ner.rs b/examples/ner.rs
new file mode 100644
index 0000000..32e04a1
--- /dev/null
+++ b/examples/ner.rs
@@ -0,0 +1,38 @@
+extern crate failure;
+extern crate dirs;
+
+use std::path::PathBuf;
+use rust_bert::pipelines::ner::NERModel;
+use tch::Device;
+
+/// Runs the NER pipeline on two sample sentences and prints each entity it returns.
+///
+/// Model files are read from `rustbert/bert-ner` under the user's home directory.
+fn main() -> failure::Fallible<()> {
+    // Resources paths
+    let mut home: PathBuf = dirs::home_dir()
+        .ok_or_else(|| failure::err_msg("Could not determine the home directory"))?;
+    home.push("rustbert");
+    home.push("bert-ner");
+    let config_path = home.join("config.json");
+    let vocab_path = home.join("vocab.txt");
+    let weights_path = home.join("model.ot");
+
+    // Set-up model
+    let device = Device::cuda_if_available();
+    let ner_model = NERModel::new(&vocab_path, &config_path, &weights_path, device)?;
+
+    // Define input
+    let input = [
+        "My name is Amy. I live in Paris.",
+        "Paris is a city in France.",
+    ];
+
+    // Run model
+    let output = ner_model.predict(input.to_vec());
+    for entity in output {
+        println!("{:?}", entity);
+    }
+
+    Ok(())
+}
\ No newline at end of file
diff --git a/examples/sentiment.rs b/examples/sentiment.rs
new file mode 100644
index 0000000..32e04a1
--- /dev/null
+++ b/examples/sentiment.rs
@@ -0,0 +1,41 @@
+extern crate failure;
+extern crate dirs;
+
+use std::path::PathBuf;
+use rust_bert::pipelines::ner::NERModel;
+use tch::Device;
+
+// NOTE(review): despite its file name, this example was copied from examples/ner.rs
+// and runs the NER pipeline. It should be rewritten around
+// `rust_bert::pipelines::sentiment::SentimentClassifier`.
+/// Runs the NER pipeline on two sample sentences and prints each entity it returns.
+///
+/// Model files are read from `rustbert/bert-ner` under the user's home directory.
+fn main() -> failure::Fallible<()> {
+    // Resources paths
+    let mut home: PathBuf = dirs::home_dir()
+        .ok_or_else(|| failure::err_msg("Could not determine the home directory"))?;
+    home.push("rustbert");
+    home.push("bert-ner");
+    let config_path = home.join("config.json");
+    let vocab_path = home.join("vocab.txt");
+    let weights_path = home.join("model.ot");
+
+    // Set-up model
+    let device = Device::cuda_if_available();
+    let ner_model = NERModel::new(&vocab_path, &config_path, &weights_path, device)?;
+
+    // Define input
+    let input = [
+        "My name is Amy. I live in Paris.",
+        "Paris is a city in France.",
+    ];
+
+    // Run model
+    let output = ner_model.predict(input.to_vec());
+    for entity in output {
+        println!("{:?}", entity);
+    }
+
+    Ok(())
+}
\ No newline at end of file
diff --git a/src/bert/bert.rs b/src/bert/bert.rs
index 7328719..a8a2307 100644
--- a/src/bert/bert.rs
+++ b/src/bert/bert.rs
@@ -47,8 +47,8 @@ pub struct BertConfig {
     pub output_attentions: Option,
     pub output_hidden_states: Option,
     pub is_decoder: Option,
-    pub id2label: Option>,
-    pub label2id: Option>,
+    pub id2label: Option>,
+    pub label2id: Option>,
     pub num_labels: Option,
 }
 
diff --git a/src/distilbert/mod.rs b/src/distilbert/mod.rs
index 105f611..cffea9f 100644
--- a/src/distilbert/mod.rs
+++ b/src/distilbert/mod.rs
@@ -2,4 +2,3 @@ pub mod distilbert;
 mod embeddings;
 mod attention;
 mod transformer;
-pub mod sentiment;
\ No newline at end of file
diff --git a/src/lib.rs b/src/lib.rs
index 5b18c5b..6448ec1 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -2,11 +2,13 @@ pub mod distilbert;
 pub mod bert;
 pub mod roberta;
 pub mod common;
+pub mod pipelines;
 
 pub use distilbert::distilbert::{DistilBertConfig, DistilBertModel, DistilBertModelClassifier, DistilBertModelMaskedLM, DistilBertForTokenClassification, DistilBertForQuestionAnswering};
-pub use distilbert::sentiment::{Sentiment, SentimentPolarity, SentimentClassifier};
 
 pub use bert::bert::BertConfig;
 pub use bert::bert::{BertModel, BertForSequenceClassification, BertForMaskedLM, BertForQuestionAnswering, BertForTokenClassification, BertForMultipleChoice};
 
-pub use roberta::roberta::{RobertaForSequenceClassification, RobertaForMaskedLM, RobertaForQuestionAnswering, RobertaForTokenClassification, RobertaForMultipleChoice};
\ No newline at end of file
+pub use roberta::roberta::{RobertaForSequenceClassification, RobertaForMaskedLM, RobertaForQuestionAnswering, RobertaForTokenClassification, RobertaForMultipleChoice};
+
+pub use pipelines::sentiment::{Sentiment, SentimentPolarity, SentimentClassifier};
\ No newline at end of file
diff --git 
a/src/main.rs b/src/main.rs
index 180d47a..6b347de 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,6 +1,6 @@
 use std::path::PathBuf;
 use tch::Device;
-use rust_bert::distilbert::sentiment::SentimentClassifier;
+use rust_bert::pipelines::sentiment::SentimentClassifier;
 
 extern crate failure;
 extern crate dirs;
diff --git a/src/pipelines/mod.rs b/src/pipelines/mod.rs
new file mode 100644
index 0000000..3198847
--- /dev/null
+++ b/src/pipelines/mod.rs
@@ -0,0 +1,2 @@
+pub mod sentiment;
+pub mod ner;
\ No newline at end of file
diff --git a/src/pipelines/ner.rs b/src/pipelines/ner.rs
new file mode 100644
index 0000000..a075069
--- /dev/null
+++ b/src/pipelines/ner.rs
@@ -0,0 +1,130 @@
+// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
+// Copyright 2019 Guillaume Becquin
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//     http://www.apache.org/licenses/LICENSE-2.0
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+
+use rust_tokenizers::bert_tokenizer::BertTokenizer;
+use std::path::Path;
+use tch::nn::VarStore;
+use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::{TruncationStrategy, MultiThreadedTokenizer};
+use crate::{BertForTokenClassification, BertConfig};
+use std::collections::HashMap;
+use crate::common::config::Config;
+use tch::{Tensor, no_grad, Device};
+use tch::kind::Kind::Float;
+
+
+/// A single token that the model labelled as part of a named entity.
+#[derive(Debug)]
+pub struct Entity {
+    /// Text decoded from the token id found at the labelled position.
+    pub word: String,
+    /// Softmax probability of the predicted label.
+    pub score: f64,
+    /// Label name looked up in the configuration's `id2label` mapping.
+    pub label: String,
+}
+
+/// Named-entity-recognition pipeline built on a BERT token-classification model.
+pub struct NERModel {
+    tokenizer: BertTokenizer,
+    bert_sequence_classifier: BertForTokenClassification,
+    label_mapping: HashMap<i64, String>,
+    var_store: VarStore,
+}
+
+impl NERModel {
+    /// Builds the pipeline from a vocabulary file, a model configuration file and a weights file.
+    ///
+    /// # Errors
+    ///
+    /// Returns an error if `vocab_path` is not valid UTF-8, if the configuration has no
+    /// `id2label` dictionary, or if the weights cannot be loaded into the variable store.
+    pub fn new(vocab_path: &Path, model_config_path: &Path, model_weight_path: &Path, device: Device)
+               -> failure::Fallible<NERModel> {
+        let vocab_file = vocab_path.to_str()
+            .ok_or_else(|| failure::err_msg("Vocabulary path is not valid UTF-8"))?;
+        let tokenizer = BertTokenizer::from_file(vocab_file, false);
+        let mut var_store = VarStore::new(device);
+        let config = BertConfig::from_file(model_config_path);
+        let bert_sequence_classifier = BertForTokenClassification::new(&var_store.root(), &config);
+        let label_mapping = config.id2label
+            .ok_or_else(|| failure::err_msg("No label dictionary (id2label) provided in configuration file"))?;
+        var_store.load(model_weight_path)?;
+        Ok(NERModel { tokenizer, bert_sequence_classifier, label_mapping, var_store })
+    }
+
+    /// Tokenizes `input` (`encode_list` is called with 128 and `TruncationStrategy::LongestFirst`),
+    /// pads every sequence with token id 0 up to the longest one and stacks them into a single
+    /// tensor moved to the variable store's device.
+    ///
+    /// `input` must not be empty; `predict` checks this before calling.
+    fn prepare_for_model(&self, input: Vec<&str>) -> Tensor {
+        let tokenized_input = self.tokenizer.encode_list(input,
+                                                         128,
+                                                         &TruncationStrategy::LongestFirst,
+                                                         0);
+        let max_len = tokenized_input.iter()
+            .map(|input| input.token_ids.len())
+            .max()
+            .expect("prepare_for_model requires at least one input sentence");
+        let tokenized_input = tokenized_input
+            .iter()
+            .map(|input| {
+                let mut token_ids = input.token_ids.clone();
+                token_ids.resize(max_len, 0);
+                Tensor::of_slice(&token_ids)
+            })
+            .collect::<Vec<Tensor>>();
+        Tensor::stack(tokenized_input.as_slice(), 0).to(self.var_store.device())
+    }
+
+    /// Predicts a label for every token position of every sentence and returns one `Entity`
+    /// per position whose predicted label id is not 0.
+    ///
+    /// Returns an empty vector when `input` is empty.
+    pub fn predict(&self, input: Vec<&str>) -> Vec<Entity> {
+        // Nothing to tokenize or stack: avoid the panic an empty batch would trigger.
+        if input.is_empty() {
+            return Vec::new();
+        }
+        let input_tensor = self.prepare_for_model(input);
+        let (output, _, _) = no_grad(|| {
+            self.bert_sequence_classifier
+                .forward_t(Some(input_tensor.copy()),
+                           None,
+                           None,
+                           None,
+                           None,
+                           false)
+        });
+        let output = output.detach().to(Device::Cpu);
+        // Built-in softmax instead of exp(x) / sum(exp(x)), which overflows for large logits.
+        let score: Tensor = output.softmax(-1, Float);
+        let labels_idx = &score.argmax(-1, true);
+
+        let mut entities: Vec<Entity> = vec!();
+        for sentence_idx in 0..labels_idx.size()[0] {
+            let labels = labels_idx.get(sentence_idx);
+            for position_idx in 0..labels.size()[0] {
+                let label = labels.int64_value(&[position_idx]);
+                // NOTE(review): assumes label id 0 is the "no entity" class - confirm against id2label.
+                if label != 0 {
+                    entities.push(Entity {
+                        word: rust_tokenizers::preprocessing::tokenizer::base_tokenizer::Tokenizer::decode(&self.tokenizer, vec!(input_tensor.int64_value(&[sentence_idx, position_idx])), true, true),
+                        score: score.double_value(&[sentence_idx, position_idx, label]),
+                        label: self.label_mapping.get(&label).expect("Index out of vocabulary bounds.").to_owned(),
+                    });
+                }
+            }
+        }
+        entities
+    }
+}
\ No newline at end of file
diff --git a/src/distilbert/sentiment.rs b/src/pipelines/sentiment.rs
similarity index 100%
rename from src/distilbert/sentiment.rs
rename to src/pipelines/sentiment.rs
diff --git a/utils/download-dependencies_bert_ner.py b/utils/download-dependencies_bert_ner.py
new file mode 100644
index 0000000..a9aba58
--- /dev/null
+++ b/utils/download-dependencies_bert_ner.py
@@ -0,0 +1,46 @@
+from transformers.file_utils import get_from_cache, S3_BUCKET_PREFIX
+from transformers.pipelines import SUPPORTED_TASKS
+from pathlib import Path
+import shutil
+import os
+import numpy as np
+import torch
+import subprocess
+
+ROOT_PATH = S3_BUCKET_PREFIX + '/' + SUPPORTED_TASKS['ner']['default']['model']['pt']
+
+config_path = ROOT_PATH + '/config.json'
+vocab_path = ROOT_PATH + '/vocab.txt'
+weights_path = 
ROOT_PATH + '/pytorch_model.bin'
+
+target_path = Path.home() / 'rustbert' / 'bert-ner'
+# Fetch the three remote files through transformers' get_from_cache
+temp_config = get_from_cache(config_path)
+temp_vocab = get_from_cache(vocab_path)
+temp_weights = get_from_cache(weights_path)
+
+os.makedirs(str(target_path), exist_ok=True)
+# Local destinations (config_path / vocab_path are rebound from remote to local paths here)
+config_path = str(target_path / 'config.json')
+vocab_path = str(target_path / 'vocab.txt')
+model_path = str(target_path / 'model.bin')
+
+shutil.copy(temp_config, config_path)
+shutil.copy(temp_vocab, vocab_path)
+shutil.copy(temp_weights, model_path)
+# Re-key the checkpoint (gamma -> weight, beta -> bias) and export it as model.npz
+weights = torch.load(temp_weights, map_location='cpu')
+nps = {}
+for k, v in weights.items():
+    k = k.replace("gamma", "weight").replace("beta", "bias")
+    nps[k] = np.ascontiguousarray(v.cpu().numpy())
+
+np.savez(target_path / 'model.npz', **nps)
+# Convert model.npz to model.ot with the crate's convert-tensor binary
+source = str(target_path / 'model.npz')
+target = str(target_path / 'model.ot')
+
+toml_location = Path(__file__).resolve().parents[1] / 'Cargo.toml'
+# check_call raises CalledProcessError if the conversion fails (call() ignored the exit code)
+subprocess.check_call(
+    ['cargo', 'run', '--bin=convert-tensor', '--manifest-path=%s' % toml_location, '--', source, target])