Updated NER Pipeline

This commit is contained in:
Guillaume B 2020-04-26 09:13:13 +02:00
parent 2aeace4a6e
commit 0a85d5ba21
4 changed files with 42 additions and 71 deletions

View File

@ -11,35 +11,12 @@
// limitations under the License.
extern crate failure;
extern crate dirs;
use std::path::PathBuf;
use tch::Device;
use failure::err_msg;
use rust_bert::pipelines::ner::NERModel;
fn main() -> failure::Fallible<()> {
// Resources paths
let mut home: PathBuf = dirs::home_dir().unwrap();
home.push("rustbert");
home.push("bert-ner");
let config_path = &home.as_path().join("config.json");
let vocab_path = &home.as_path().join("vocab.txt");
let weights_path = &home.as_path().join("model.ot");
if !config_path.is_file() | !vocab_path.is_file() | !weights_path.is_file() {
return Err(
err_msg("Could not find required resources to run example. \
Please run ../utils/download_dependencies_bert_ner.py \
in a Python environment with dependencies listed in ../requirements.txt"));
}
// Set-up model
let device = Device::cuda_if_available();
let ner_model = NERModel::new(vocab_path,
config_path,
weights_path, device)?;
let ner_model = NERModel::new(Default::default())?;
// Define input
let input = [

View File

@ -56,14 +56,14 @@
//! ```
use rust_tokenizers::bert_tokenizer::BertTokenizer;
use std::path::Path;
use tch::nn::VarStore;
use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::{TruncationStrategy, MultiThreadedTokenizer};
use std::collections::HashMap;
use tch::{Tensor, no_grad, Device};
use tch::kind::Kind::Float;
use crate::bert::{BertForTokenClassification, BertConfig};
use crate::bert::{BertForTokenClassification, BertConfig, BertModelResources, BertConfigResources, BertVocabResources};
use crate::Config;
use crate::common::resources::{Resource, RemoteResource, download_resource};
#[derive(Debug)]
@ -77,6 +77,30 @@ pub struct Entity {
pub label: String,
}
/// # Configuration for NER
/// Groups the three resources a named entity recognition model is built from
/// (weights, configuration, vocabulary) together with the device the model is placed on.
/// The `Default` implementation points every resource at the pretrained `BERT_NER` files.
pub struct NERConfig {
    /// Resource holding the model weights (default: pretrained BERT model on CoNLL)
    pub model_resource: Resource,
    /// Resource holding the model configuration (default: pretrained BERT model on CoNLL)
    pub config_resource: Resource,
    /// Resource holding the tokenizer vocabulary (default: pretrained BERT model on CoNLL)
    pub vocab_resource: Resource,
    /// Device the model is placed on (default: chosen by `Device::cuda_if_available()`)
    pub device: Device,
}
impl Default for NERConfig {
    /// Builds the default NER configuration.
    ///
    /// The weights, configuration and vocabulary are all remote resources created from
    /// the `BERT_NER` pretrained entries; the device is picked by
    /// `Device::cuda_if_available()`.
    fn default() -> Self {
        // Resources are created in the same order as the struct fields are declared.
        let model_resource =
            Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT_NER));
        let config_resource =
            Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT_NER));
        let vocab_resource =
            Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT_NER));
        let device = Device::cuda_if_available();

        Self {
            model_resource,
            config_resource,
            vocab_resource,
            device,
        }
    }
}
/// # NERModel to extract named entities
pub struct NERModel {
tokenizer: BertTokenizer,
@ -90,34 +114,25 @@ impl NERModel {
///
/// # Arguments
///
/// * `vocab_path` - Path to the model vocabulary, expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers) convention
/// * `config_path` - Path to the model configuration, expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers) convention
/// * `weights_path` - Path to the model weight files. These need to be converted from the `.bin` to `.ot` format using the utility script provided.
/// * `device` - Device to run the model on, e.g. `Device::Cpu` or `Device::Cuda(0)`
/// * `ner_config` - `NERConfig` object containing the resource references (model, vocabulary, configuration) and device placement (CPU/GPU)
///
/// # Example
///
/// ```no_run
///# fn main() -> failure::Fallible<()> {
/// use tch::Device;
/// use std::path::{Path, PathBuf};
/// use rust_bert::pipelines::ner::NERModel;
///
/// let mut home: PathBuf = dirs::home_dir().unwrap();
/// let config_path = &home.as_path().join("config.json");
/// let vocab_path = &home.as_path().join("vocab.txt");
/// let weights_path = &home.as_path().join("model.ot");
/// let device = Device::Cpu;
/// let ner_model = NERModel::new(vocab_path,
/// config_path,
/// weights_path,
/// device)?;
/// let ner_model = NERModel::new(Default::default())?;
///# Ok(())
///# }
/// ```
///
pub fn new(vocab_path: &Path, config_path: &Path, weights_path: &Path, device: Device)
-> failure::Fallible<NERModel> {
pub fn new(ner_config: NERConfig) -> failure::Fallible<NERModel> {
let config_path = download_resource(&ner_config.config_resource)?;
let vocab_path = download_resource(&ner_config.vocab_resource)?;
let weights_path = download_resource(&ner_config.model_resource)?;
let device = ner_config.device;
let tokenizer = BertTokenizer::from_file(vocab_path.to_str().unwrap(), false);
let mut var_store = VarStore::new(device);
let config = BertConfig::from_file(config_path);
@ -160,19 +175,9 @@ impl NERModel {
///
/// ```no_run
///# fn main() -> failure::Fallible<()> {
///# use tch::Device;
///# use std::path::{Path, PathBuf};
///# use rust_bert::pipelines::ner::NERModel;
///#
///# let mut home: PathBuf = dirs::home_dir().unwrap();
///# let config_path = &home.as_path().join("config.json");
///# let vocab_path = &home.as_path().join("vocab.txt");
///# let weights_path = &home.as_path().join("model.ot");
///# let device = Device::Cpu;
/// let ner_model = NERModel::new(vocab_path,
/// config_path,
/// weights_path,
/// device)?;
///
/// let ner_model = NERModel::new(Default::default())?;
/// let input = [
/// "My name is Amy. I live in Paris.",
/// "Paris is a city in France."

View File

@ -145,14 +145,14 @@ impl QaExample {
}
}
/// # Configuration for sentiment classification
/// # Configuration for question answering
/// Contains information regarding the model to load and device to place the model on.
pub struct QuestionAnsweringConfig {
/// Model weights resource (default: pretrained DistilBERT model on SST-2)
/// Model weights resource (default: pretrained DistilBERT model on SQuAD)
pub model_resource: Resource,
/// Config resource (default: pretrained DistilBERT model on SST-2)
/// Config resource (default: pretrained DistilBERT model on SQuAD)
pub config_resource: Resource,
/// Vocab resource (default: pretrained DistilBERT model on SST-2)
/// Vocab resource (default: pretrained DistilBERT model on SQuAD)
pub vocab_resource: Resource,
/// Device to place the model on (default: CUDA/GPU when available)
pub device: Device,

View File

@ -296,19 +296,8 @@ fn bert_for_question_answering() -> failure::Fallible<()> {
#[test]
fn bert_pre_trained_ner() -> failure::Fallible<()> {
// Resources paths
let config_dependency = Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT_NER));
let vocab_dependency = Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT_NER));
let weights_dependency = Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT_NER));
let config_path = download_resource(&config_dependency)?;
let vocab_path = download_resource(&vocab_dependency)?;
let weights_path = download_resource(&weights_dependency)?;
// Set-up model
let device = Device::cuda_if_available();
let ner_model = NERModel::new(vocab_path,
config_path,
weights_path, device)?;
let ner_model = NERModel::new(Default::default())?;
// Define input
let input = [