Merge pull request #33 from guillaume-be/token_entity_parsing

Token entity parsing
Authored by guillaume-be on 2020-05-13 15:11:09 +00:00; committed by GitHub
commit 16753ea8fb
5 changed files with 278 additions and 57 deletions

@@ -30,7 +30,7 @@ all-tests = []
features = [ "doc-only" ]
[dependencies]
rust_tokenizers = "~3.0.0"
rust_tokenizers = "~3.0.2"
tch = "~0.1.7"
serde_json = "1.0.51"
serde = {version = "1.0.106", features = ["derive"]}

@@ -20,8 +20,8 @@ fn main() -> failure::Fallible<()> {
// Define input
let input = [
"My name is Amy. I live in Paris.",
"Paris is a city in France."
"My name is Amélie. I live in Москва.",
"Chongqing is a city in China."
];
// Run model

@@ -0,0 +1,43 @@
// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
// Copyright 2019 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use rust_bert::pipelines::token_classification::{TokenClassificationModel, TokenClassificationConfig, LabelAggregationOption};
use rust_bert::resources::{Resource, RemoteResource};
use rust_bert::bert::{BertModelResources, BertVocabResources, BertConfigResources};
use rust_bert::pipelines::common::ModelType;
fn main() -> failure::Fallible<()> {
// Load a configuration
let config = TokenClassificationConfig::new(ModelType::Bert,
Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT_NER)),
Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT_NER)),
Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT_NER)),
None, //merges resource only relevant with ModelType::Roberta
false, //lowercase
LabelAggregationOption::Mode,
);
// Create the model
let token_classification_model = TokenClassificationModel::new(config)?;
let input = [
"My name is Amélie. I live in Москва.",
"Chongqing is a city in China."
];
let token_outputs = token_classification_model.predict(&input, true, false); //consolidate_sub_tokens = true, return_special = false
for token in token_outputs {
println!("{:?}", token);
}
Ok(())
}

@@ -45,7 +45,7 @@
//!# ;
//! ```
use crate::pipelines::token_classification::{TokenClassificationModel,TokenClassificationConfig};
use crate::pipelines::token_classification::{TokenClassificationModel, TokenClassificationConfig};
#[derive(Debug)]
@@ -58,6 +58,7 @@ pub struct Entity {
/// Entity label (e.g. ORG, LOC...)
pub label: String,
}
//type alias for some backward compatibility
type NERConfig = TokenClassificationConfig;
@@ -86,7 +87,7 @@ impl NERModel {
///
pub fn new(ner_config: NERConfig) -> failure::Fallible<NERModel> {
let model = TokenClassificationModel::new(ner_config)?;
Ok(NERModel { token_classification_model: model})
Ok(NERModel { token_classification_model: model })
}
/// Extract entities from a text
@@ -116,12 +117,16 @@ impl NERModel {
/// ```
///
pub fn predict(&self, input: &[&str]) -> Vec<Entity> {
self.token_classification_model.predict(input, true).into_iter().map(|token| {
Entity {
word: token.text,
score: token.score,
label: token.label,
}
}).collect()
self.token_classification_model
.predict(input, true, false)
.into_iter()
.filter(|token| token.label != "O")
.map(|token| {
Entity {
word: token.text,
score: token.score,
label: token.label,
}
}).collect()
}
}
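With the change above, NERModel::predict now filters out tokens labelled "O" itself, so callers only receive actual entities. A minimal usage sketch (assuming the default pretrained BERT NER resources; illustrative, not taken verbatim from this diff):

use rust_bert::pipelines::ner::NERModel;

fn main() -> failure::Fallible<()> {
    // The default NERConfig resolves to the pretrained BERT NER resources.
    let ner_model = NERModel::new(Default::default())?;
    let input = ["My name is Amélie. I live in Paris."];
    // Tokens labelled "O" are filtered inside the pipeline, so only entities are returned.
    for entity in ner_model.predict(&input) {
        println!("{:?}", entity);
    }
    Ok(())
}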

@@ -22,12 +22,14 @@
//!# fn main() -> failure::Fallible<()> {
//!
//! //Load a configuration
//! use rust_bert::pipelines::token_classification::LabelAggregationOption;
//! let config = TokenClassificationConfig::new(ModelType::Bert,
//! Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT_NER)),
//! Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT_NER)),
//! Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT_NER)),
//! None, //merges resource only relevant with ModelType::Roberta
//! false, //lowercase
//! LabelAggregationOption::Mode
//! );
//!
//! //Create the model
@@ -37,25 +39,29 @@
//! "My name is Amy. I live in Paris.",
//! "Paris is a city in France."
//! ];
//! let output = token_classification_model.predict(&input, true); //ignore_first_label = true (only returns the NER parts, ignoring first label O)
//! let output = token_classification_model.predict(&input, true, true); //consolidate_sub_tokens = true, return_special = true
//!# Ok(())
//!# }
//! ```
//! Output: \
//! ```no_run
//!# use rust_bert::pipelines::token_classification::Token;
//! use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::Mask::Special;
//! use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::{Offset, Mask};
//!# let output =
//! [
//! Token { text: String::from("Amy"), score: 0.9986, label: String::from("I-PER"), sentence: 0, index: 0},
//! Token { text: String::from("Paris"), score: 0.9985, label: String::from("I-LOC"), sentence: 0, index: 9},
//! Token { text: String::from("Paris"), score: 0.9988, label: String::from("I-LOC"), sentence: 1, index: 1},
//! Token { text: String::from("France"), score: 0.9993, label: String::from("I-LOC"), sentence: 1, index: 6}
//! Token { text: String::from("[CLS]"), score: 0.9995001554489136, label: String::from("O"), label_index: 0, sentence: 0, index: 0, word_index: 0, offset: None, mask: Special },
//! Token { text: String::from("My"), score: 0.9980450868606567, label: String::from("O"), label_index: 0, sentence: 0, index: 1, word_index: 1, offset: Some(Offset { begin: 0, end: 2 }), mask: Mask::None },
//! Token { text: String::from("name"), score: 0.9995062351226807, label: String::from("O"), label_index: 0, sentence: 0, index: 2, word_index: 2, offset: Some(Offset { begin: 3, end: 7 }), mask: Mask::None },
//! Token { text: String::from("is"), score: 0.9997343420982361, label: String::from("O"), label_index: 0, sentence: 0, index: 3, word_index: 3, offset: Some(Offset { begin: 8, end: 10 }), mask: Mask::None },
//! Token { text: String::from("Amélie"), score: 0.9913727683112525, label: String::from("I-PER"), label_index: 4, sentence: 0, index: 4, word_index: 4, offset: Some(Offset { begin: 11, end: 17 }), mask: Mask::None }
//! // ...
//! ]
//!# ;
//! ```
use tch::nn::VarStore;
use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::{Tokenizer, TokenizedInput, TruncationStrategy};
use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::{Tokenizer, TokenizedInput, TruncationStrategy, Mask, Offset, ConsolidatableTokens, ConsolidatedTokenIterator, TokenTrait};
use std::collections::HashMap;
use tch::{Tensor, no_grad, Device};
use tch::kind::Kind::Float;
@@ -63,12 +69,14 @@ use crate::bert::{BertForTokenClassification, BertModelResources, BertConfigReso
use crate::roberta::RobertaForTokenClassification;
use crate::distilbert::DistilBertForTokenClassification;
use crate::common::resources::{Resource, RemoteResource, download_resource};
use serde::{Serialize, Deserialize};
use crate::pipelines::common::{ModelType, ConfigOption, TokenizerOption};
use crate::electra::ElectraForTokenClassification;
use itertools::Itertools;
use std::cmp::min;
use serde::{Serialize, Deserialize};
#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Clone, Serialize, Deserialize)]
/// # Token generated by a `TokenClassificationModel`
pub struct Token {
/// String representation of the Token
@@ -80,14 +88,59 @@ pub struct Token {
/// Token label (e.g. ORG, LOC in case of NER)
pub label: String,
/// Label index
pub label_index: i64,
/// Sentence index
#[serde(default)]
pub sentence: usize,
/// Token position index
pub index: u16,
/// Token word position index
pub word_index: u16,
/// Token offsets
pub offset: Option<Offset>,
/// Token mask
pub mask: Mask,
}
impl TokenTrait for Token {
fn offset(&self) -> Option<Offset> {
self.offset
}
fn mask(&self) -> Mask {
self.mask
}
fn as_str(&self) -> &str {
self.text.as_str()
}
}
impl ConsolidatableTokens<Token> for Vec<Token> {
fn iter_consolidate_tokens(&self) -> ConsolidatedTokenIterator<Token> {
ConsolidatedTokenIterator::new(self)
}
}
/// # Enum defining the label aggregation method for sub-tokens
/// Defines the label aggregation behaviour when sub-token consolidation is enabled.
pub enum LabelAggregationOption {
/// The label of the first sub token is assigned to the entire token
First,
/// The label of the last sub token is assigned to the entire token
Last,
/// The most frequent sub-token label is assigned to the entire token
Mode,
/// The user can provide a function mapping a `&[Token]` slice to a `(i64, String)` tuple corresponding to the label index and label string to return
Custom(Box<dyn Fn(&[Token]) -> (i64, String)>),
}
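For illustration only (this closure is not part of the diff), a Custom aggregation could keep the label of the highest-scoring sub-token, relying solely on the Token fields defined earlier in this file:

// Assumed example: retain the label of the sub-token with the highest score.
let highest_score = LabelAggregationOption::Custom(Box::new(|tokens: &[Token]| {
    let best = tokens
        .iter()
        .max_by(|a, b| a.score.partial_cmp(&b.score).unwrap())
        .expect("a consolidated group contains at least one sub-token");
    (best.label_index, best.label.clone())
}));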
/// # Configuration for TokenClassificationModel
/// Contains information regarding the model to load and device to place the model on.
pub struct TokenClassificationConfig {
@@ -105,6 +158,8 @@ pub struct TokenClassificationConfig {
pub lower_case: bool,
/// Device to place the model on (default: CUDA/GPU when available)
pub device: Device,
/// Sub-tokens aggregation method (default: `LabelAggregationOption::First`)
pub label_aggregation_function: LabelAggregationOption,
}
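Because the struct keeps a Default implementation (see below), overriding only the new field could look like this minimal sketch (assumption, not shown in the diff):

let config = TokenClassificationConfig {
    label_aggregation_function: LabelAggregationOption::Mode,
    ..Default::default()
};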
impl TokenClassificationConfig {
@@ -119,7 +174,13 @@ impl TokenClassificationConfig {
/// * merges - An optional `Resource` (`Option<Resource>`) pointing to the tokenizer's merges file to load (e.g. merges.txt), needed only for Roberta.
/// * lower_case - A `bool` indicating whether the tokenizer should lower case all input (in case of a lower-cased model)
///
pub fn new(model_type: ModelType, model_resource: Resource, config_resource: Resource, vocab_resource: Resource, merges_resource: Option<Resource>, lower_case: bool) -> TokenClassificationConfig {
pub fn new(model_type: ModelType,
model_resource: Resource,
config_resource: Resource,
vocab_resource: Resource,
merges_resource: Option<Resource>,
lower_case: bool,
label_aggregation_function: LabelAggregationOption) -> TokenClassificationConfig {
TokenClassificationConfig {
model_type,
model_resource,
@@ -128,6 +189,7 @@ impl TokenClassificationConfig {
merges_resource,
lower_case,
device: Device::cuda_if_available(),
label_aggregation_function,
}
}
}
@@ -143,6 +205,7 @@ impl Default for TokenClassificationConfig {
merges_resource: None,
lower_case: false,
device: Device::cuda_if_available(),
label_aggregation_function: LabelAggregationOption::First,
}
}
}
@@ -212,14 +275,13 @@ impl TokenClassificationOption {
}
}
/// Interface method to forward_t() of the particular models.
pub fn forward_t(&self,
input_ids: Option<Tensor>,
mask: Option<Tensor>,
token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
fn forward_t(&self,
input_ids: Option<Tensor>,
mask: Option<Tensor>,
token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
match *self {
Self::Bert(ref model) => model.forward_t(input_ids, mask, token_type_ids, position_ids, input_embeds, train),
Self::DistilBert(ref model) => model.forward_t(input_ids, mask, input_embeds, train).expect("Error in distilbert forward_t"),
@@ -234,9 +296,9 @@ impl TokenClassificationOption {
pub struct TokenClassificationModel {
tokenizer: TokenizerOption,
token_sequence_classifier: TokenClassificationOption,
//e.g. BertForTokenClassification,
label_mapping: HashMap<i64, String>,
var_store: VarStore,
label_aggregation_function: LabelAggregationOption,
}
impl TokenClassificationModel {
@@ -267,6 +329,7 @@ impl TokenClassificationModel {
None
};
let device = config.device;
let label_aggregation_function = config.label_aggregation_function;
let tokenizer = TokenizerOption::from_file(config.model_type, vocab_path.to_str().unwrap(), merges_path.map(|path| path.to_str().unwrap()), config.lower_case);
let mut var_store = VarStore::new(device);
@@ -274,10 +337,10 @@
let token_sequence_classifier = TokenClassificationOption::new(config.model_type, &var_store.root(), &model_config);
let label_mapping = model_config.get_label_mapping();
var_store.load(weights_path)?;
Ok(TokenClassificationModel { tokenizer, token_sequence_classifier, label_mapping, var_store })
Ok(TokenClassificationModel { tokenizer, token_sequence_classifier, label_mapping, var_store, label_aggregation_function })
}
fn prepare_for_model(&self, input: Vec<&str>) -> Tensor {
fn prepare_for_model(&self, input: Vec<&str>) -> (Vec<TokenizedInput>, Tensor) {
let tokenized_input: Vec<TokenizedInput> = self.tokenizer.encode_list(input.to_vec(),
128,
&TruncationStrategy::LongestFirst,
@@ -293,18 +356,20 @@ impl TokenClassificationModel {
map(|input|
Tensor::of_slice(&(input))).
collect::<Vec<_>>();
Tensor::stack(tokenized_input_tensors.as_slice(), 0).to(self.var_store.device())
(tokenized_input, Tensor::stack(tokenized_input_tensors.as_slice(), 0).to(self.var_store.device()))
}
/// Extract entities from a text
/// Classify tokens in a text sequence
///
/// # Arguments
///
/// * `input` - `&[&str]` Array of texts to classify.
/// * `consolidate_sub_tokens` - bool flag indicating if sub-tokens should be consolidated at the token level
/// * `return_special` - bool flag indicating if labels for special tokens should be returned
///
/// # Returns
///
/// * `Vec<Entity>` containing extracted entities
/// * `Vec<Token>` containing Tokens with associated labels (for example POS tags)
///
/// # Example
///
@@ -317,12 +382,12 @@ impl TokenClassificationModel {
/// "My name is Amy. I live in Paris.",
/// "Paris is a city in France."
/// ];
/// let output = ner_model.predict(&input, true);
/// let output = ner_model.predict(&input, true, true);
///# Ok(())
///# }
/// ```
pub fn predict(&self, input: &[&str], ignore_first_label: bool) -> Vec<Token> {
let input_tensor = self.prepare_for_model(input.to_vec());
pub fn predict(&self, input: &[&str], consolidate_sub_tokens: bool, return_special: bool) -> Vec<Token> {
let (tokenized_input, input_tensor) = self.prepare_for_model(input.to_vec());
let (output, _, _) = no_grad(|| {
self.token_sequence_classifier
.forward_t(Some(input_tensor.copy()),
@@ -335,39 +400,147 @@ impl TokenClassificationModel {
let output = output.detach().to(Device::Cpu);
let score: Tensor = output.exp() / output.exp().sum1(&[-1], true, Float);
let labels_idx = &score.argmax(-1, true);
let mut tokens: Vec<Token> = vec!();
for sentence_idx in 0..labels_idx.size()[0] {
let labels = labels_idx.get(sentence_idx);
//let mut word_idx: u8 = 0;
for position_idx in 0..labels.size()[0] {
let label_id = labels.int64_value(&[position_idx]);
let token = {
let token_id = input_tensor.int64_value(&[sentence_idx, position_idx]);
self.decode_token(token_id, label_id, &score, sentence_idx, position_idx) //, &mut word_idx)
};
if let Some(token) = token {
if !ignore_first_label || label_id != 0 {
tokens.push(token);
}
let sentence_tokens = &tokenized_input[sentence_idx as usize];
let original_chars = input[sentence_idx as usize].chars().collect_vec();
let mut word_idx: u16 = 0;
for position_idx in 0..sentence_tokens.token_ids.len() {
let mask = sentence_tokens.mask[position_idx];
if (mask == Mask::Special) & (!return_special) {
continue;
}
if !(mask == Mask::Continuation) & !(mask == Mask::InexactContinuation) {
word_idx += 1;
}
let token = {
self.decode_token(&original_chars, sentence_tokens, &input_tensor, &labels, &score, sentence_idx, position_idx as i64, word_idx - 1)
};
tokens.push(token);
}
}
if consolidate_sub_tokens {
self.consolidate_tokens(&mut tokens, &self.label_aggregation_function);
}
tokens
}
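Assuming a `model` and `input` already built as in the examples above, the effect of the two new flags can be sketched as follows (illustrative only, not part of the diff):

// consolidate_sub_tokens = false: every sub-token ("Am", "##élie", ...) is returned separately.
let sub_token_level = model.predict(&input, false, false);
// consolidate_sub_tokens = true: sub-tokens are merged into one Token per word.
let word_level = model.predict(&input, true, false);
assert!(word_level.len() <= sub_token_level.len());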
fn decode_token(&self, token_id: i64, label_id: i64, score: &Tensor, sentence_idx: i64, position_idx: i64) -> Option<Token> {
let text = match self.tokenizer {
TokenizerOption::Bert(ref tokenizer) => Tokenizer::decode(tokenizer, vec!(token_id), false, false),
TokenizerOption::Roberta(ref tokenizer) => Tokenizer::decode(tokenizer, vec!(token_id), false, false),
fn decode_token(&self, original_sentence_chars: &Vec<char>, sentence_tokens: &TokenizedInput, input_tensor: &Tensor,
labels: &Tensor, score: &Tensor, sentence_idx: i64, position_idx: i64, word_index: u16) -> Token {
let label_id = labels.int64_value(&[position_idx as i64]);
let token_id = input_tensor.int64_value(&[sentence_idx, position_idx as i64]);
let offsets = &sentence_tokens.token_offsets[position_idx as usize];
let text = match offsets {
None => match self.tokenizer {
TokenizerOption::Bert(ref tokenizer) => Tokenizer::decode(tokenizer, vec!(token_id), false, false),
TokenizerOption::Roberta(ref tokenizer) => Tokenizer::decode(tokenizer, vec!(token_id), false, false),
},
Some(offsets) => {
let (start_char, end_char) = (offsets.begin as usize, offsets.end as usize);
let end_char = min(end_char, original_sentence_chars.len());
let text = original_sentence_chars[start_char..end_char].iter().collect();
text
}
};
Some(Token {
Token {
text,
score: score.double_value(&[sentence_idx, position_idx, label_id]),
label: self.label_mapping.get(&label_id).expect("Index out of vocabulary bounds.").to_owned(),
label_index: label_id,
sentence: sentence_idx as usize,
index: position_idx as u16,
})
word_index,
offset: offsets.to_owned(),
mask: sentence_tokens.mask[position_idx as usize],
}
}
fn consolidate_tokens(&self, tokens: &mut Vec<Token>, label_aggregation_function: &LabelAggregationOption) {
let mut tokens_to_replace = vec!();
let mut token_iter = tokens.iter_consolidate_tokens();
let mut cursor = 0;
while let Some(sub_tokens) = token_iter.next() {
if sub_tokens.len() > 1 {
let (label_index, label) = self.consolidate_labels(sub_tokens, label_aggregation_function);
let sentence = (&sub_tokens[0]).sentence;
let index = (&sub_tokens[0]).index;
let word_index = (&sub_tokens[0]).word_index;
let offset_start = match &sub_tokens.first().unwrap().offset {
Some(offset) => Some(offset.begin),
None => None
};
let offset_end = match &sub_tokens.last().unwrap().offset {
Some(offset) => Some(offset.end),
None => None
};
let offset = if offset_start.is_some() & offset_end.is_some() {
Some(Offset::new(offset_start.unwrap(), offset_end.unwrap()))
} else {
None
};
let mut text = String::new();
let mut score = 1f64;
for current_sub_token in sub_tokens.into_iter() {
text.push_str(current_sub_token.text.as_str());
score *= if current_sub_token.label_index == label_index {
current_sub_token.score
} else {
1.0 - current_sub_token.score
};
}
let token = Token {
text,
score,
label,
label_index,
sentence,
index,
word_index,
offset,
mask: Default::default(),
};
tokens_to_replace.push(((cursor, cursor + sub_tokens.len()), token));
}
cursor += sub_tokens.len();
}
for ((start, end), token) in tokens_to_replace.into_iter().rev() {
tokens.splice(start..end, [token].iter().cloned());
}
}
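A worked illustration of the score consolidation above, using made-up sub-token scores: when all sub-tokens carry the retained label, their probabilities are multiplied.

// Assumed scores: "Am" = 0.99 and "##élie" = 0.98, both labelled I-PER.
let consolidated_score = 0.99_f64 * 0.98_f64; // ≈ 0.9702
assert!((consolidated_score - 0.9702).abs() < 1e-12);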
fn consolidate_labels(&self, tokens: &[Token], aggregation: &LabelAggregationOption) -> (i64, String) {
match aggregation {
LabelAggregationOption::First => {
let token = tokens.first().unwrap();
(token.label_index, token.label.clone())
}
LabelAggregationOption::Last => {
let token = tokens.last().unwrap();
(token.label_index, token.label.clone())
}
LabelAggregationOption::Mode => {
let counts = tokens
.iter()
.fold(
HashMap::new(),
|mut m, c| {
*m.entry((c.label_index, c.label.as_str())).or_insert(0) += 1;
m
},
);
counts
.into_iter()
.max_by(|a, b| a.1.cmp(&b.1))
.map(|((label_index, label), _)| (label_index, label.to_owned()))
.unwrap()
}
LabelAggregationOption::Custom(function) => function(tokens)
}
}
}