From 9d3a944051257b36e8c4b3fd041daf1869abc05d Mon Sep 17 00:00:00 2001
From: Guillaume B
Date: Sun, 10 May 2020 11:38:28 +0200
Subject: [PATCH 1/6] Updated token classification pipeline to use next
 tokenization features, reference to original text

---
 examples/ner.rs                       |  4 +-
 src/pipelines/ner.rs                  | 23 +++++----
 src/pipelines/token_classification.rs | 70 ++++++++++++++++++---------
 3 files changed, 62 insertions(+), 35 deletions(-)

diff --git a/examples/ner.rs b/examples/ner.rs
index 074699e..448e427 100644
--- a/examples/ner.rs
+++ b/examples/ner.rs
@@ -20,8 +20,8 @@ fn main() -> failure::Fallible<()> {
 
     // Define input
     let input = [
-        "My name is Amy. I live in Paris.",
-        "Paris is a city in France."
+        "My name is Amélie. I live in 北京市.",
+        "Chongqing is a city in China."
     ];
 
     // Run model
diff --git a/src/pipelines/ner.rs b/src/pipelines/ner.rs
index 549e82d..785b852 100644
--- a/src/pipelines/ner.rs
+++ b/src/pipelines/ner.rs
@@ -45,7 +45,7 @@
 //!# ;
 //! ```
 
-use crate::pipelines::token_classification::{TokenClassificationModel,TokenClassificationConfig};
+use crate::pipelines::token_classification::{TokenClassificationModel, TokenClassificationConfig};
 
 #[derive(Debug)]
@@ -58,6 +58,7 @@ pub struct Entity {
     /// Entity label (e.g. ORG, LOC...)
     pub label: String,
 }
+
 //type alias for some backward compatibility
 type NERConfig = TokenClassificationConfig;
 
@@ -86,7 +87,7 @@ impl NERModel {
     ///
     pub fn new(ner_config: NERConfig) -> failure::Fallible<NERModel> {
         let model = TokenClassificationModel::new(ner_config)?;
-        Ok(NERModel { token_classification_model: model})
+        Ok(NERModel { token_classification_model: model })
     }
 
     /// Extract entities from a text
@@ -116,12 +117,16 @@ impl NERModel {
     /// ```
     ///
     pub fn predict(&self, input: &[&str]) -> Vec<Entity> {
-        self.token_classification_model.predict(input, true).into_iter().map(|token| {
-            Entity {
-                word: token.text,
-                score: token.score,
-                label: token.label,
-            }
-        }).collect()
+        self.token_classification_model
+            .predict(input, true, false)
+            .into_iter()
+            .filter(|token| token.label != "O")
+            .map(|token| {
+                Entity {
+                    word: token.text,
+                    score: token.score,
+                    label: token.label,
+                }
+            }).collect()
     }
 }
diff --git a/src/pipelines/token_classification.rs b/src/pipelines/token_classification.rs
index 544ba55..11be5e3 100644
--- a/src/pipelines/token_classification.rs
+++ b/src/pipelines/token_classification.rs
@@ -55,7 +55,7 @@
 //! ```
 
 use tch::nn::VarStore;
-use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::{Tokenizer, TokenizedInput, TruncationStrategy};
+use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::{Tokenizer, TokenizedInput, TruncationStrategy, Mask, Offset};
 use std::collections::HashMap;
 use tch::{Tensor, no_grad, Device};
 use tch::kind::Kind::Float;
@@ -63,12 +63,12 @@ use crate::bert::{BertForTokenClassification, BertModelResources, BertConfigReso
 use crate::roberta::RobertaForTokenClassification;
 use crate::distilbert::DistilBertForTokenClassification;
 use crate::common::resources::{Resource, RemoteResource, download_resource};
-use serde::{Serialize, Deserialize};
 use crate::pipelines::common::{ModelType, ConfigOption, TokenizerOption};
 use crate::electra::ElectraForTokenClassification;
+use itertools::Itertools;
 
 
-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Debug)]
 /// # Token generated by a `TokenClassificationModel`
 pub struct Token {
     /// String representation of the Token
     pub text: String,
@@ -81,11 +81,16 @@ pub struct Token {
     pub label: String,
 
     /// Sentence index
-    #[serde(default)]
     pub sentence: usize,
 
     /// Token position index
     pub index: u16,
+
+    /// Token word position index
+    pub word_index: u16,
+
+    /// Token offsets
+    pub offset: Option<Offset>,
 }
 
 /// # Configuration for TokenClassificationModel
@@ -277,7 +282,7 @@ impl TokenClassificationModel {
         Ok(TokenClassificationModel { tokenizer, token_sequence_classifier, label_mapping, var_store })
     }
 
-    fn prepare_for_model(&self, input: Vec<&str>) -> Tensor {
+    fn prepare_for_model(&self, input: Vec<&str>) -> (Vec<TokenizedInput>, Tensor) {
         let tokenized_input: Vec<TokenizedInput> = self.tokenizer.encode_list(input.to_vec(),
                                                                               128,
                                                                               &TruncationStrategy::LongestFirst,
@@ -293,7 +298,7 @@
             map(|input| Tensor::of_slice(&(input))).
             collect::<Vec<Tensor>>();
-        Tensor::stack(tokenized_input_tensors.as_slice(), 0).to(self.var_store.device())
+        (tokenized_input, Tensor::stack(tokenized_input_tensors.as_slice(), 0).to(self.var_store.device()))
     }
 
     /// Extract entities from a text
@@ -321,8 +326,8 @@
     ///# Ok(())
     ///# }
     /// ```
-    pub fn predict(&self, input: &[&str], ignore_first_label: bool) -> Vec<Token> {
-        let input_tensor = self.prepare_for_model(input.to_vec());
+    pub fn predict(&self, input: &[&str], return_tokens: bool, return_special: bool) -> Vec<Token> {
+        let (tokenized_input, input_tensor) = self.prepare_for_model(input.to_vec());
         let (output, _, _) = no_grad(|| {
             self.token_sequence_classifier
                 .forward_t(Some(input_tensor.copy()),
@@ -335,39 +340,56 @@
         let output = output.detach().to(Device::Cpu);
         let score: Tensor = output.exp() / output.exp().sum1(&[-1], true, Float);
         let labels_idx = &score.argmax(-1, true);
         let mut tokens: Vec<Token> = vec!();
         for sentence_idx in 0..labels_idx.size()[0] {
             let labels = labels_idx.get(sentence_idx);
-            //let mut word_idx: u8 = 0;
-            for position_idx in 0..labels.size()[0] {
-                let label_id = labels.int64_value(&[position_idx]);
+            let sentence_tokens = &tokenized_input[sentence_idx as usize];
+            let original_chars = input[sentence_idx as usize].chars().collect_vec();
+            let mut word_idx: u16 = 0;
+            for position_idx in 0..sentence_tokens.token_ids.len() {
+                let mask = sentence_tokens.mask[position_idx];
+                if (mask == Mask::Special) & (!return_special) {
+                    continue;
+                }
                 let token = {
-                    let token_id = input_tensor.int64_value(&[sentence_idx, position_idx]);
-                    self.decode_token(token_id, label_id, &score, sentence_idx, position_idx) //, &mut word_idx)
+                    self.decode_token(&original_chars, sentence_tokens, &input_tensor, &labels, &score, sentence_idx, position_idx as i64, word_idx)
                 };
-                if let Some(token) = token {
-                    if !ignore_first_label || label_id != 0 {
-                        tokens.push(token);
-                    }
+                tokens.push(token);
+                if !(mask == Mask::Continuation) || !(mask == Mask::InexactContinuation) {
+                    word_idx += 1;
                 }
             }
         }
         tokens
     }
 
-    fn decode_token(&self, token_id: i64, label_id: i64, score: &Tensor, sentence_idx: i64, position_idx: i64) -> Option<Token> {
-        let text = match self.tokenizer {
-            TokenizerOption::Bert(ref tokenizer) => Tokenizer::decode(tokenizer, vec!(token_id), false, false),
-            TokenizerOption::Roberta(ref tokenizer) => Tokenizer::decode(tokenizer, vec!(token_id), false, false),
+    fn decode_token(&self, original_sentence_chars: &Vec<char>, sentence_tokens: &TokenizedInput, input_tensor: &Tensor,
+                    labels: &Tensor, score: &Tensor, sentence_idx: i64, position_idx: i64, word_index: u16) -> Token {
+        let label_id = labels.int64_value(&[position_idx as i64]);
+        let token_id = input_tensor.int64_value(&[sentence_idx, position_idx as i64]);
+
+        let offsets = &sentence_tokens.token_offsets[position_idx as usize];
+
+        let text = match offsets {
+            None => match self.tokenizer {
+                TokenizerOption::Bert(ref tokenizer) => Tokenizer::decode(tokenizer, vec!(token_id), false, false),
+                TokenizerOption::Roberta(ref tokenizer) => Tokenizer::decode(tokenizer, vec!(token_id), false, false),
+            },
+            Some(offsets) => {
+                let (start_char, end_char) = (offsets.begin as usize, offsets.end as usize);
+                let text = original_sentence_chars[start_char..end_char].iter().collect();
+                text
+            }
         };
-        Some(Token {
+        Token {
             text,
             score: score.double_value(&[sentence_idx, position_idx, label_id]),
             label: self.label_mapping.get(&label_id).expect("Index out of vocabulary bounds.").to_owned(),
             sentence: sentence_idx as usize,
             index: position_idx as u16,
-        })
+            word_index,
+            offset: offsets.to_owned(),
+        }
     }
 }

From 1a9e315edfdb527ce60430cbc7355e7c57a4912b Mon Sep 17 00:00:00 2001
From: Guillaume B
Date: Sun, 10 May 2020 12:05:13 +0200
Subject: [PATCH 2/6] Added sub-token consolidation

---
 examples/ner.rs                       |  2 +-
 src/pipelines/token_classification.rs | 65 ++++++++++++++++++++++++++-
 2 files changed, 64 insertions(+), 3 deletions(-)

diff --git a/examples/ner.rs b/examples/ner.rs
index 448e427..d7aa11d 100644
--- a/examples/ner.rs
+++ b/examples/ner.rs
@@ -20,7 +20,7 @@ fn main() -> failure::Fallible<()> {
 
     // Define input
     let input = [
-        "My name is Amélie. I live in 北京市.",
+        "My name is Amélie. I live in Москва.",
         "Chongqing is a city in China."
     ];
 
diff --git a/src/pipelines/token_classification.rs b/src/pipelines/token_classification.rs
index 11be5e3..5213ed2 100644
--- a/src/pipelines/token_classification.rs
+++ b/src/pipelines/token_classification.rs
@@ -68,7 +68,7 @@ use crate::electra::ElectraForTokenClassification;
 use itertools::Itertools;
 
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 /// # Token generated by a `TokenClassificationModel`
 pub struct Token {
     /// String representation of the Token
@@ -91,6 +91,9 @@ pub struct Token {
 
     /// Token offsets
     pub offset: Option<Offset>,
+
+    /// Token mask
+    pub mask: Mask,
 }
 
 /// # Configuration for TokenClassificationModel
@@ -326,7 +329,7 @@ impl TokenClassificationModel {
     ///# Ok(())
     ///# }
     /// ```
-    pub fn predict(&self, input: &[&str], return_tokens: bool, return_special: bool) -> Vec<Token> {
+    pub fn predict(&self, input: &[&str], consolidate_subtokens: bool, return_special: bool) -> Vec<Token> {
         let (tokenized_input, input_tensor) = self.prepare_for_model(input.to_vec());
         let (output, _, _) = no_grad(|| {
             self.token_sequence_classifier
@@ -360,6 +363,9 @@
                 }
             }
         }
+        if consolidate_subtokens {
+            tokens = self.consolidate_tokens(tokens);
+        }
         tokens
     }
 
@@ -390,6 +396,61 @@
             index: position_idx as u16,
             word_index,
             offset: offsets.to_owned(),
+            mask: sentence_tokens.mask[position_idx as usize],
         }
     }
+
+    fn consolidate_tokens(&self, tokens: Vec<Token>) -> Vec<Token> {
+        let mut consolidated_tokens: Vec<Token> = vec!();
+        let mut current_token: Vec<Token> = vec!();
+        for sub_token in tokens.iter() {
+            if (sub_token.mask != Mask::Continuation) & (sub_token.mask != Mask::InexactContinuation) {
+                match current_token.len() {
+                    0 => {}
+                    1 => consolidated_tokens.push(current_token[0].clone()),
+                    _ => {
+                        let mut text = String::new();
+                        let mut score = 1f64;
+                        let label: String = (&current_token[0]).label.clone();
+                        let sentence = (&current_token[0]).sentence;
+                        let index = (&current_token[0]).index;
+                        let word_index = (&current_token[0]).word_index;
+                        let offset_start = match &current_token.first().unwrap().offset {
+                            Some(offset) => Some(offset.begin),
+                            None => None
+                        };
+                        let offset_end = match &current_token.last().unwrap().offset {
+                            Some(offset) => Some(offset.end),
+                            None => None
+                        };
+                        let offset = if offset_start.is_some() & offset_end.is_some() {
+                            Some(Offset::new(offset_start.unwrap(), offset_end.unwrap()))
+                        } else {
+                            None
+                        };
+                        for current_sub_token in current_token.into_iter() {
+                            text.push_str(current_sub_token.text.as_str());
+                            score *= current_sub_token.score;
+                        }
+                        consolidated_tokens.push(
+                            Token {
+                                text,
+                                score,
+                                label,
+                                sentence,
+                                index,
+                                word_index,
+                                offset,
+                                mask: Default::default(),
+                            }
+                        )
+                    }
+                };
+                current_token = vec!(sub_token.clone());
+            } else {
+                current_token.push(sub_token.clone());
+            }
+        }
+        consolidated_tokens
+    }
 }
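Reviewer note on patches 1 and 2: the point of carrying `Offset`s through the pipeline is that a consolidated token can be sliced straight out of the original sentence by character index. A minimal standalone sketch of that offset arithmetic, using a simplified `Offset` struct rather than the rust_tokenizers type (all names here are illustrative, not the crate's API):

```rust
#[derive(Debug, Clone, Copy)]
struct Offset { begin: usize, end: usize }

// Merge contiguous sub-token offsets and slice the sentence by character
// index (not byte index), which is why the pipeline collects `chars()` first.
fn consolidated_text(original: &str, sub_offsets: &[Offset]) -> Option<String> {
    let begin = sub_offsets.first()?.begin;
    let end = sub_offsets.last()?.end;
    let chars: Vec<char> = original.chars().collect();
    Some(chars[begin..end.min(chars.len())].iter().collect())
}

fn main() {
    // "Amélie" arrives as two word pieces; the multi-byte é makes byte slicing unsafe here.
    let sentence = "My name is Amélie.";
    let pieces = [Offset { begin: 11, end: 14 }, Offset { begin: 14, end: 17 }];
    assert_eq!(consolidated_text(sentence, &pieces).unwrap(), "Amélie");
}
```

Slicing a `Vec<char>` rather than the `&str` is the same design choice `decode_token` makes above: character offsets stay valid for non-ASCII inputs such as the Amélie/北京市/Москва examples.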
From 6d61074f7fb23c96ffa655de5f5d90651541418d Mon Sep 17 00:00:00 2001
From: Guillaume B
Date: Sun, 10 May 2020 12:27:42 +0200
Subject: [PATCH 3/6] Updated consolidation and documentation

---
 examples/token_classification.rs      | 42 +++++++++++++++++++++++++++
 src/pipelines/token_classification.rs | 33 ++++++++++++---------
 2 files changed, 62 insertions(+), 13 deletions(-)
 create mode 100644 examples/token_classification.rs

diff --git a/examples/token_classification.rs b/examples/token_classification.rs
new file mode 100644
index 0000000..82da0b8
--- /dev/null
+++ b/examples/token_classification.rs
@@ -0,0 +1,42 @@
+// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
+// Copyright 2019 Guillaume Becquin
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//     http://www.apache.org/licenses/LICENSE-2.0
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use rust_bert::pipelines::token_classification::{TokenClassificationModel, TokenClassificationConfig};
+use rust_bert::resources::{Resource, RemoteResource};
+use rust_bert::bert::{BertModelResources, BertVocabResources, BertConfigResources};
+use rust_bert::pipelines::common::ModelType;
+
+fn main() -> failure::Fallible<()> {
+
+//    Load a configuration
+    let config = TokenClassificationConfig::new(ModelType::Bert,
+                                                Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT_NER)),
+                                                Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT_NER)),
+                                                Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT_NER)),
+                                                None, //merges resource only relevant with ModelType::Roberta
+                                                false, //lowercase
+    );
+
+//    Create the model
+    let token_classification_model = TokenClassificationModel::new(config)?;
+    let input = [
+        "My name is Amélie. I live in Москва.",
+        "Chongqing is a city in China."
+    ];
+    let token_outputs = token_classification_model.predict(&input, true, false); //consolidate sub-tokens, skip special tokens
+
+    for token in token_outputs {
+        println!("{:?}", token);
+    }
+
+    Ok(())
+}
\ No newline at end of file
diff --git a/src/pipelines/token_classification.rs b/src/pipelines/token_classification.rs
index 5213ed2..a8e7bf9 100644
--- a/src/pipelines/token_classification.rs
+++ b/src/pipelines/token_classification.rs
@@ -37,19 +37,23 @@
 //!     "My name is Amy. I live in Paris.",
 //!     "Paris is a city in France."
 //! ];
-//! let output = token_classification_model.predict(&input, true); //ignore_first_label = true (only returns the NER parts, ignoring first label O)
+//! let output = token_classification_model.predict(&input, true, true); //consolidate sub-tokens, return special tokens
 //!# Ok(())
 //!# }
 //! ```
 //! Output: \
 //! ```no_run
 //!# use rust_bert::pipelines::token_classification::Token;
+//! use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::Mask::Special;
+//! use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::{Offset, Mask};
 //!# let output =
 //! [
-//!    Token { text: String::from("Amy"), score: 0.9986, label: String::from("I-PER"), sentence: 0, index: 0},
-//!    Token { text: String::from("Paris"), score: 0.9985, label: String::from("I-LOC"), sentence: 0, index: 9},
-//!    Token { text: String::from("Paris"), score: 0.9988, label: String::from("I-LOC"), sentence: 1, index: 1},
-//!    Token { text: String::from("France"), score: 0.9993, label: String::from("I-LOC"), sentence: 1, index: 6}
+//!    Token { text: String::from("[CLS]"), score: 0.9995001554489136, label: String::from("O"), sentence: 0, index: 0, word_index: 0, offset: None, mask: Special },
+//!    Token { text: String::from("My"), score: 0.9980450868606567, label: String::from("O"), sentence: 0, index: 1, word_index: 1, offset: Some(Offset { begin: 0, end: 2 }), mask: Mask::None },
+//!    Token { text: String::from("name"), score: 0.9995062351226807, label: String::from("O"), sentence: 0, index: 2, word_index: 2, offset: Some(Offset { begin: 3, end: 7 }), mask: Mask::None },
+//!    Token { text: String::from("is"), score: 0.9997343420982361, label: String::from("O"), sentence: 0, index: 3, word_index: 3, offset: Some(Offset { begin: 8, end: 10 }), mask: Mask::None },
+//!    Token { text: String::from("Amélie"), score: 0.9913727683112525, label: String::from("I-PER"), sentence: 0, index: 4, word_index: 4, offset: Some(Offset { begin: 11, end: 17 }), mask: Mask::None }
+//! // ...
 //! ]
 //!# ;
 //! ```
@@ -146,8 +150,8 @@ impl Default for TokenClassificationConfig {
         TokenClassificationConfig {
             model_type: ModelType::Bert,
             model_resource: Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT_NER)),
-            config_resource: Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT_NER)),
-            vocab_resource: Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT_NER)),
+            config_resource: Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT_NER)),
+            vocab_resource: Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT_NER)),
             merges_resource: None,
             lower_case: false,
             device: Device::cuda_if_available(),
@@ -309,6 +313,8 @@
     /// # Arguments
     ///
     /// * `input` - `&[&str]` Array of texts to extract entities from.
+    /// * `consolidate_subtokens` - bool flag indicating if subtokens should be consolidated at the token level
+    /// * `return_special` - bool flag indicating if labels for special tokens should be returned
     ///
     /// # Returns
     ///
@@ -325,7 +331,7 @@
     ///     "My name is Amy. I live in Paris.",
     ///     "Paris is a city in France."
     /// ];
-    /// let output = ner_model.predict(&input, true);
+    /// let output = ner_model.predict(&input, true, true);
     ///# Ok(())
     ///# }
     /// ```
@@ -354,13 +360,14 @@
                 if (mask == Mask::Special) & (!return_special) {
                     continue;
                 }
-                let token = {
-                    self.decode_token(&original_chars, sentence_tokens, &input_tensor, &labels, &score, sentence_idx, position_idx as i64, word_idx)
-                };
-                tokens.push(token);
-                if !(mask == Mask::Continuation) || !(mask == Mask::InexactContinuation) {
+                if !(mask == Mask::Continuation) & !(mask == Mask::InexactContinuation) {
                    word_idx += 1;
                 }
+                let token = {
+                    self.decode_token(&original_chars, sentence_tokens, &input_tensor, &labels, &score, sentence_idx, position_idx as i64, word_idx - 1)
+                };
+                tokens.push(token);
             }
         }
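Reviewer note on patch 3: the reordered `word_idx` bookkeeping is subtle, since the counter is now incremented before decoding and the stored value is `word_idx - 1` (the previous `||` condition was always true, which this patch also fixes by switching to `&`). A small self-contained sketch of the same rule over a plain mask sequence; the `Mask` variants mirror rust_tokenizers, but the function itself is illustrative and not part of the crate:

```rust
#[allow(dead_code)]
#[derive(PartialEq)]
enum Mask { None, Special, Continuation, InexactContinuation }

// Mirrors the patch: bump the word counter on every non-continuation
// sub-token, then record `word_idx - 1` as that token's word index.
fn word_indices(masks: &[Mask]) -> Vec<u16> {
    let mut word_idx: u16 = 0;
    let mut indices = Vec::new();
    for mask in masks {
        if (*mask != Mask::Continuation) & (*mask != Mask::InexactContinuation) {
            word_idx += 1;
        }
        indices.push(word_idx - 1);
    }
    indices
}

fn main() {
    use Mask::*;
    // "My" "na" "##me": the continuation piece "##me" keeps the word index of "na".
    assert_eq!(word_indices(&[None, None, Continuation]), vec![0, 1, 1]);
}
```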
From 0bbb47d1db480d7bbdc0b8b6eacc5170d9496461 Mon Sep 17 00:00:00 2001
From: Guillaume B
Date: Mon, 11 May 2020 16:40:25 +0200
Subject: [PATCH 4/6] Added options for label consolidation for sub tokens

---
 examples/token_classification.rs      |   3 +-
 src/pipelines/token_classification.rs | 142 +++++++++++++++++++-------
 2 files changed, 107 insertions(+), 38 deletions(-)

diff --git a/examples/token_classification.rs b/examples/token_classification.rs
index 82da0b8..d697065 100644
--- a/examples/token_classification.rs
+++ b/examples/token_classification.rs
@@ -10,7 +10,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-use rust_bert::pipelines::token_classification::{TokenClassificationModel, TokenClassificationConfig};
+use rust_bert::pipelines::token_classification::{TokenClassificationModel, TokenClassificationConfig, LabelAggregationOption};
 use rust_bert::resources::{Resource, RemoteResource};
 use rust_bert::bert::{BertModelResources, BertVocabResources, BertConfigResources};
 use rust_bert::pipelines::common::ModelType;
@@ -24,6 +24,7 @@ fn main() -> failure::Fallible<()> {
                                                 Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT_NER)),
                                                 None, //merges resource only relevant with ModelType::Roberta
                                                 false, //lowercase
+                                                LabelAggregationOption::Mode,
     );
 
 //    Create the model
diff --git a/src/pipelines/token_classification.rs b/src/pipelines/token_classification.rs
index a8e7bf9..ebddfb0 100644
--- a/src/pipelines/token_classification.rs
+++ b/src/pipelines/token_classification.rs
@@ -22,12 +22,14 @@
 //!# fn main() -> failure::Fallible<()> {
 //!
 //! //Load a configuration
+//! use rust_bert::pipelines::token_classification::LabelAggregationOption;
 //! let config = TokenClassificationConfig::new(ModelType::Bert,
 //!                                             Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT_NER)),
 //!                                             Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT_NER)),
 //!                                             Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT_NER)),
 //!                                             None, //merges resource only relevant with ModelType::Roberta
 //!                                             false, //lowercase
+//!                                             LabelAggregationOption::Mode
 //! );
 //!
 //! //Create the model
@@ -48,11 +50,11 @@
 //! use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::{Offset, Mask};
 //!# let output =
 //! [
-//!    Token { text: String::from("[CLS]"), score: 0.9995001554489136, label: String::from("O"), sentence: 0, index: 0, word_index: 0, offset: None, mask: Special },
-//!    Token { text: String::from("My"), score: 0.9980450868606567, label: String::from("O"), sentence: 0, index: 1, word_index: 1, offset: Some(Offset { begin: 0, end: 2 }), mask: Mask::None },
-//!    Token { text: String::from("name"), score: 0.9995062351226807, label: String::from("O"), sentence: 0, index: 2, word_index: 2, offset: Some(Offset { begin: 3, end: 7 }), mask: Mask::None },
-//!    Token { text: String::from("is"), score: 0.9997343420982361, label: String::from("O"), sentence: 0, index: 3, word_index: 3, offset: Some(Offset { begin: 8, end: 10 }), mask: Mask::None },
-//!    Token { text: String::from("Amélie"), score: 0.9913727683112525, label: String::from("I-PER"), sentence: 0, index: 4, word_index: 4, offset: Some(Offset { begin: 11, end: 17 }), mask: Mask::None }
+//!    Token { text: String::from("[CLS]"), score: 0.9995001554489136, label: String::from("O"), label_index: 0, sentence: 0, index: 0, word_index: 0, offset: None, mask: Special },
+//!    Token { text: String::from("My"), score: 0.9980450868606567, label: String::from("O"), label_index: 0, sentence: 0, index: 1, word_index: 1, offset: Some(Offset { begin: 0, end: 2 }), mask: Mask::None },
+//!    Token { text: String::from("name"), score: 0.9995062351226807, label: String::from("O"), label_index: 0, sentence: 0, index: 2, word_index: 2, offset: Some(Offset { begin: 3, end: 7 }), mask: Mask::None },
+//!    Token { text: String::from("is"), score: 0.9997343420982361, label: String::from("O"), label_index: 0, sentence: 0, index: 3, word_index: 3, offset: Some(Offset { begin: 8, end: 10 }), mask: Mask::None },
+//!    Token { text: String::from("Amélie"), score: 0.9913727683112525, label: String::from("I-PER"), label_index: 4, sentence: 0, index: 4, word_index: 4, offset: Some(Offset { begin: 11, end: 17 }), mask: Mask::None }
 //! // ...
 //! ]
 //!# ;
 //! ```
@@ -70,6 +72,7 @@
 use crate::pipelines::common::{ModelType, ConfigOption, TokenizerOption};
 use crate::electra::ElectraForTokenClassification;
 use itertools::Itertools;
+use std::cmp::min;
 
 
 #[derive(Debug, Clone)]
@@ -84,6 +87,9 @@ pub struct Token {
     /// Token label (e.g. ORG, LOC in case of NER)
     pub label: String,
 
+    /// Label index
+    pub label_index: i64,
+
     /// Sentence index
     pub sentence: usize,
 
@@ -100,6 +106,21 @@ pub struct Token {
     /// Token mask
     pub mask: Mask,
 }
 
+/// # Enum defining the label aggregation method for sub tokens
+/// Defines the behaviour for labels aggregation if the consolidation of sub-tokens is enabled.
+/// Some default options are provided:
+/// - First (the label of the first sub token is assigned to the entire token)
+/// - Last (the label of the last sub token is assigned to the entire token)
+/// - Mode (the most frequent sub-token label is assigned to the entire token)
+/// - Custom: the user can provide a function mapping a `&Vec<Token>` to a `(i64, String)` tuple corresponding to the label index, label String to return
+pub enum LabelAggregationOption {
+    First,
+    Last,
+    Mode,
+    Custom(Box<Fn(&Vec<Token>) -> (i64, String)>),
+}
+
 /// # Configuration for TokenClassificationModel
 /// Contains information regarding the model to load and device to place the model on.
 pub struct TokenClassificationConfig {
@@ -117,6 +138,8 @@ pub struct TokenClassificationConfig {
     pub lower_case: bool,
     /// Device to place the model on (default: CUDA/GPU when available)
     pub device: Device,
+    /// Sub-tokens aggregation method (default: `LabelAggregationOption::First`)
+    pub label_aggregation_function: LabelAggregationOption,
 }
 
@@ -131,7 +154,13 @@ impl TokenClassificationConfig {
     /// * merges - An optional `Resource` tuple (`Option<Resource>`) pointing to the tokenizer's merges file to load (e.g. merges.txt), needed only for Roberta.
     /// * lower_case - A `bool` indicating whether the tokenizer should lower case all input (in case of a lower-cased model)
     ///
-    pub fn new(model_type: ModelType, model_resource: Resource, config_resource: Resource, vocab_resource: Resource, merges_resource: Option<Resource>, lower_case: bool) -> TokenClassificationConfig {
+    pub fn new(model_type: ModelType,
+               model_resource: Resource,
+               config_resource: Resource,
+               vocab_resource: Resource,
+               merges_resource: Option<Resource>,
+               lower_case: bool,
+               label_aggregation_function: LabelAggregationOption) -> TokenClassificationConfig {
         TokenClassificationConfig {
             model_type,
             model_resource,
@@ -140,6 +169,7 @@
             merges_resource,
             lower_case,
             device: Device::cuda_if_available(),
+            label_aggregation_function,
         }
     }
 }
@@ -150,11 +180,12 @@ impl Default for TokenClassificationConfig {
         TokenClassificationConfig {
             model_type: ModelType::Bert,
             model_resource: Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT_NER)),
-            config_resource: Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT_NER)),
-            vocab_resource: Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT_NER)),
+            config_resource: Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT_NER)),
+            vocab_resource: Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT_NER)),
             merges_resource: None,
             lower_case: false,
             device: Device::cuda_if_available(),
+            label_aggregation_function: LabelAggregationOption::First,
         }
     }
 }
@@ -224,14 +255,13 @@ impl TokenClassificationOption {
         }
     }
 
-    /// Interface method to forward_t() of the particular models.
-    pub fn forward_t(&self,
-                     input_ids: Option<Tensor>,
-                     mask: Option<Tensor>,
-                     token_type_ids: Option<Tensor>,
-                     position_ids: Option<Tensor>,
-                     input_embeds: Option<Tensor>,
-                     train: bool) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
+    fn forward_t(&self,
+                 input_ids: Option<Tensor>,
+                 mask: Option<Tensor>,
+                 token_type_ids: Option<Tensor>,
+                 position_ids: Option<Tensor>,
+                 input_embeds: Option<Tensor>,
+                 train: bool) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
         match *self {
             Self::Bert(ref model) => model.forward_t(input_ids, mask, token_type_ids, position_ids, input_embeds, train),
             Self::DistilBert(ref model) => model.forward_t(input_ids, mask, input_embeds, train).expect("Error in distilbert forward_t"),
@@ -246,9 +276,9 @@
 pub struct TokenClassificationModel {
     tokenizer: TokenizerOption,
     token_sequence_classifier: TokenClassificationOption,
-    //e.g. BertForTokenClassification
     label_mapping: HashMap<i64, String>,
     var_store: VarStore,
+    label_aggregation_function: LabelAggregationOption,
 }
 
 impl TokenClassificationModel {
@@ -279,6 +309,7 @@ impl TokenClassificationModel {
             None
         };
         let device = config.device;
+        let label_aggregation_function = config.label_aggregation_function;
         let tokenizer = TokenizerOption::from_file(config.model_type, vocab_path.to_str().unwrap(), merges_path.map(|path| path.to_str().unwrap()), config.lower_case);
         let mut var_store = VarStore::new(device);
@@ -286,7 +317,7 @@
         let token_sequence_classifier = TokenClassificationOption::new(config.model_type, &var_store.root(), &model_config);
         let label_mapping = model_config.get_label_mapping();
         var_store.load(weights_path)?;
-        Ok(TokenClassificationModel { tokenizer, token_sequence_classifier, label_mapping, var_store })
+        Ok(TokenClassificationModel { tokenizer, token_sequence_classifier, label_mapping, var_store, label_aggregation_function })
     }
 
     fn prepare_for_model(&self, input: Vec<&str>) -> (Vec<TokenizedInput>, Tensor) {
@@ -339,7 +370,7 @@
         (tokenized_input, Tensor::stack(tokenized_input_tensors.as_slice(), 0).to(self.var_store.device()))
     }
 
-    /// Extract entities from a text
+    /// Classify tokens in a text sequence
     ///
     /// # Arguments
     ///
@@ -349,7 +380,7 @@
     ///
     /// # Returns
     ///
-    /// * `Vec<Token>` containing extracted entities
+    /// * `Vec<Token>` containing Tokens with associated labels (for example POS tags)
     ///
     /// # Example
     ///
@@ -366,7 +397,7 @@
     ///# Ok(())
     ///# }
     /// ```
-    pub fn predict(&self, input: &[&str], consolidate_subtokens: bool, return_special: bool) -> Vec<Token> {
+    pub fn predict(&self, input: &[&str], consolidate_sub_tokens: bool, return_special: bool) -> Vec<Token> {
         let (tokenized_input, input_tensor) = self.prepare_for_model(input.to_vec());
         let (output, _, _) = no_grad(|| {
             self.token_sequence_classifier
@@ -401,8 +432,8 @@
             }
         }
-        if consolidate_subtokens {
-            tokens = self.consolidate_tokens(tokens);
+        if consolidate_sub_tokens {
+            tokens = self.consolidate_tokens(tokens, &self.label_aggregation_function);
         }
         tokens
     }
@@ -421,6 +452,7 @@
             Some(offsets) => {
                 let (start_char, end_char) = (offsets.begin as usize, offsets.end as usize);
+                let end_char = min(end_char, original_sentence_chars.len());
                 let text = original_sentence_chars[start_char..end_char].iter().collect();
                 text
             }
@@ -431,6 +463,7 @@
             text,
             score: score.double_value(&[sentence_idx, position_idx, label_id]),
             label: self.label_mapping.get(&label_id).expect("Index out of vocabulary bounds.").to_owned(),
+            label_index: label_id,
             sentence: sentence_idx as usize,
             index: position_idx as u16,
             word_index,
@@ -440,26 +473,26 @@
         }
     }
 
-    fn consolidate_tokens(&self, tokens: Vec<Token>) -> Vec<Token> {
+    fn consolidate_tokens(&self, tokens: Vec<Token>, label_aggregation_function: &LabelAggregationOption) -> Vec<Token> {
         let mut consolidated_tokens: Vec<Token> = vec!();
-        let mut current_token: Vec<Token> = vec!();
+        let mut current_tokens: Vec<Token> = vec!();
         for sub_token in tokens.iter() {
             if (sub_token.mask != Mask::Continuation) & (sub_token.mask != Mask::InexactContinuation) {
-                match current_token.len() {
+                match current_tokens.len() {
                     0 => {}
-                    1 => consolidated_tokens.push(current_token[0].clone()),
+                    1 => consolidated_tokens.push(current_tokens[0].clone()),
                     _ => {
                         let mut text = String::new();
                         let mut score = 1f64;
-                        let label: String = (&current_token[0]).label.clone();
-                        let sentence = (&current_token[0]).sentence;
-                        let index = (&current_token[0]).index;
-                        let word_index = (&current_token[0]).word_index;
-                        let offset_start = match &current_token.first().unwrap().offset {
+                        let (label_index, label) = self.consolidate_labels(&current_tokens, label_aggregation_function);
+                        let sentence = (&current_tokens[0]).sentence;
+                        let index = (&current_tokens[0]).index;
+                        let word_index = (&current_tokens[0]).word_index;
+                        let offset_start = match &current_tokens.first().unwrap().offset {
                             Some(offset) => Some(offset.begin),
                             None => None
                         };
-                        let offset_end = match &current_token.last().unwrap().offset {
+                        let offset_end = match &current_tokens.last().unwrap().offset {
                             Some(offset) => Some(offset.end),
                             None => None
                         };
@@ -468,15 +501,20 @@
                         } else {
                             None
                         };
-                        for current_sub_token in current_token.into_iter() {
+                        for current_sub_token in current_tokens.into_iter() {
                             text.push_str(current_sub_token.text.as_str());
-                            score *= current_sub_token.score;
+                            score *= if current_sub_token.label_index == label_index {
+                                current_sub_token.score
+                            } else {
+                                1.0 - current_sub_token.score
+                            };
                         }
                         consolidated_tokens.push(
                             Token {
                                 text,
                                 score,
                                 label,
+                                label_index,
                                 sentence,
                                 index,
                                 word_index,
@@ -487,11 +525,41 @@
                                 offset,
                                 mask: Default::default(),
                             }
                         )
                     }
                 };
-                current_token = vec!(sub_token.clone());
+                current_tokens = vec!(sub_token.clone());
             } else {
-                current_token.push(sub_token.clone());
+                current_tokens.push(sub_token.clone());
             }
         }
         consolidated_tokens
     }
+
+    fn consolidate_labels(&self, tokens: &Vec<Token>, aggregation: &LabelAggregationOption) -> (i64, String) {
+        match aggregation {
+            LabelAggregationOption::First => {
+                let token = tokens.first().unwrap();
+                (token.label_index, token.label.clone())
+            }
+            LabelAggregationOption::Last => {
+                let token = tokens.last().unwrap();
+                (token.label_index, token.label.clone())
+            }
+            LabelAggregationOption::Mode => {
+                let counts = tokens
+                    .iter()
+                    .fold(
+                        HashMap::new(),
+                        |mut m, c| {
+                            *m.entry((c.label_index, c.label.as_str())).or_insert(0) += 1;
+                            m
+                        },
+                    );
+                counts
+                    .into_iter()
+                    .max_by(|a, b| a.1.cmp(&b.1))
+                    .map(|((label_index, label), _)| (label_index, label.to_owned()))
+                    .unwrap()
+            }
+            LabelAggregationOption::Custom(function) => function(tokens)
+        }
+    }
 }
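Reviewer note on patch 4: `LabelAggregationOption::Custom` accepts a boxed closure, so callers can define their own aggregation rule. A hedged sketch of one possible closure, taking the label of the highest-scoring sub-token; it assumes the crate as patched here, where the signature is `Box<Fn(&Vec<Token>) -> (i64, String)>` (patch 6 later changes the argument to `&[Token]`):

```rust
use rust_bert::pipelines::token_classification::{LabelAggregationOption, Token};

fn main() {
    // Illustrative custom aggregation: assign the whole word the label of
    // its most confident sub-token rather than First/Last/Mode.
    let _max_score_label = LabelAggregationOption::Custom(Box::new(|tokens: &Vec<Token>| {
        tokens
            .iter()
            .max_by(|a, b| a.score.partial_cmp(&b.score).unwrap())
            .map(|token| (token.label_index, token.label.clone()))
            .unwrap()
    }));
}
```

Note also how the built-in options interact with the consolidated score above: sub-tokens whose label disagrees with the aggregated label contribute `1.0 - score`, so a split decision lowers the confidence of the merged token.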
From b31a569e506012802048c15a3b4a5a1880a265f4 Mon Sep 17 00:00:00 2001
From: Guillaume B
Date: Mon, 11 May 2020 22:05:11 +0200
Subject: [PATCH 5/6] Added documentation

---
 src/pipelines/token_classification.rs | 17 ++++++++---------
 1 file changed, 8 insertions(+), 9 deletions(-)

diff --git a/src/pipelines/token_classification.rs b/src/pipelines/token_classification.rs
index ebddfb0..b86a580 100644
--- a/src/pipelines/token_classification.rs
+++ b/src/pipelines/token_classification.rs
@@ -73,9 +73,9 @@
 use crate::pipelines::common::{ModelType, ConfigOption, TokenizerOption};
 use crate::electra::ElectraForTokenClassification;
 use itertools::Itertools;
 use std::cmp::min;
+use serde::{Serialize, Deserialize};
 
-
-#[derive(Debug, Clone)]
+#[derive(Debug, Clone, Serialize, Deserialize)]
 /// # Token generated by a `TokenClassificationModel`
 pub struct Token {
     /// String representation of the Token
@@ -108,15 +108,14 @@
 /// # Enum defining the label aggregation method for sub tokens
 /// Defines the behaviour for labels aggregation if the consolidation of sub-tokens is enabled.
-/// Some default options are provided:
-/// - First (the label of the first sub token is assigned to the entire token)
-/// - Last (the label of the last sub token is assigned to the entire token)
-/// - Mode (the most frequent sub-token label is assigned to the entire token)
-/// - Custom: the user can provide a function mapping a `&Vec<Token>` to a `(i64, String)` tuple corresponding to the label index, label String to return
 pub enum LabelAggregationOption {
+    /// The label of the first sub token is assigned to the entire token
     First,
+    /// The label of the last sub token is assigned to the entire token
     Last,
+    /// The most frequent sub-token label is assigned to the entire token
     Mode,
+    /// The user can provide a function mapping a `&Vec<Token>` to a `(i64, String)` tuple corresponding to the label index, label String to return
     Custom(Box<Fn(&Vec<Token>) -> (i64, String)>),
 }
 
@@ -449,8 +448,6 @@
                     0 => {}
                     1 => consolidated_tokens.push(current_tokens[0].clone()),
                     _ => {
-                        let mut text = String::new();
-                        let mut score = 1f64;
                         let (label_index, label) = self.consolidate_labels(&current_tokens, label_aggregation_function);
                         let sentence = (&current_tokens[0]).sentence;
                         let index = (&current_tokens[0]).index;
@@ -468,6 +465,8 @@
                         } else {
                             None
                         };
+                        let mut text = String::new();
+                        let mut score = 1f64;
                         for current_sub_token in current_tokens.into_iter() {
                             text.push_str(current_sub_token.text.as_str());
                             score *= if current_sub_token.label_index == label_index {

From 39454161924dcb07b13f8f831bd5600321e61c2d Mon Sep 17 00:00:00 2001
From: Guillaume B
Date: Mon, 11 May 2020 23:24:33 +0200
Subject: [PATCH 6/6] Updated token consolidation avoiding copy

---
 Cargo.toml                            |   2 +-
 src/pipelines/token_classification.rs | 132 +++++++++++++++-----------
 2 files changed, 75 insertions(+), 59 deletions(-)

diff --git a/Cargo.toml b/Cargo.toml
index 959cf1c..520fa81 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -30,7 +30,7 @@ all-tests = []
 features = [ "doc-only" ]
 
 [dependencies]
-rust_tokenizers = "~3.0.0"
+rust_tokenizers = "~3.0.2"
 tch = "~0.1.7"
 serde_json = "1.0.51"
 serde = {version = "1.0.106", features = ["derive"]}
diff --git a/src/pipelines/token_classification.rs b/src/pipelines/token_classification.rs
index b86a580..61e5839 100644
--- a/src/pipelines/token_classification.rs
+++ b/src/pipelines/token_classification.rs
@@ -61,7 +61,7 @@
 //! ```
 
 use tch::nn::VarStore;
-use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::{Tokenizer, TokenizedInput, TruncationStrategy, Mask, Offset};
+use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::{Tokenizer, TokenizedInput, TruncationStrategy, Mask, Offset, ConsolidatableTokens, ConsolidatedTokenIterator, TokenTrait};
 use std::collections::HashMap;
 use tch::{Tensor, no_grad, Device};
 use tch::kind::Kind::Float;
@@ -75,6 +75,7 @@ use itertools::Itertools;
 use std::cmp::min;
 use serde::{Serialize, Deserialize};
 
+
 #[derive(Debug, Clone, Serialize, Deserialize)]
 /// # Token generated by a `TokenClassificationModel`
 pub struct Token {
@@ -106,6 +107,26 @@ pub struct Token {
     /// Token mask
     pub mask: Mask,
 }
 
+impl TokenTrait for Token {
+    fn offset(&self) -> Option<Offset> {
+        self.offset
+    }
+
+    fn mask(&self) -> Mask {
+        self.mask
+    }
+
+    fn as_str(&self) -> &str {
+        self.text.as_str()
+    }
+}
+
+impl ConsolidatableTokens<Token> for Vec<Token> {
+    fn iter_consolidate_tokens(&self) -> ConsolidatedTokenIterator<Token> {
+        ConsolidatedTokenIterator::new(self)
+    }
+}
+
 /// # Enum defining the label aggregation method for sub tokens
 /// Defines the behaviour for labels aggregation if the consolidation of sub-tokens is enabled.
 pub enum LabelAggregationOption {
@@ -116,7 +137,7 @@ pub enum LabelAggregationOption {
     /// The most frequent sub-token label is assigned to the entire token
     Mode,
     /// The user can provide a function mapping a `&Vec<Token>` to a `(i64, String)` tuple corresponding to the label index, label String to return
-    Custom(Box<Fn(&Vec<Token>) -> (i64, String)>),
+    Custom(Box<Fn(&[Token]) -> (i64, String)>),
 }
 
@@ -401,7 +422,7 @@
             }
         }
         if consolidate_sub_tokens {
-            tokens = self.consolidate_tokens(tokens, &self.label_aggregation_function);
+            self.consolidate_tokens(&mut tokens, &self.label_aggregation_function);
         }
         tokens
     }
@@ -439,26 +460,61 @@
         }
     }
 
-    fn consolidate_tokens(&self, tokens: Vec<Token>, label_aggregation_function: &LabelAggregationOption) -> Vec<Token> {
-        let mut consolidated_tokens: Vec<Token> = vec!();
-        let mut current_tokens: Vec<Token> = vec!();
-        for sub_token in tokens.iter() {
-            if (sub_token.mask != Mask::Continuation) & (sub_token.mask != Mask::InexactContinuation) {
-                match current_tokens.len() {
-                    0 => {}
-                    1 => consolidated_tokens.push(current_tokens[0].clone()),
-                    _ => {
-                        let (label_index, label) = self.consolidate_labels(&current_tokens, label_aggregation_function);
-                        let sentence = (&current_tokens[0]).sentence;
-                        let index = (&current_tokens[0]).index;
-                        let word_index = (&current_tokens[0]).word_index;
-                        let offset_start = match &current_tokens.first().unwrap().offset {
-                            Some(offset) => Some(offset.begin),
-                            None => None
-                        };
-                        let offset_end = match &current_tokens.last().unwrap().offset {
-                            Some(offset) => Some(offset.end),
-                            None => None
-                        };
+    fn consolidate_tokens(&self, tokens: &mut Vec<Token>, label_aggregation_function: &LabelAggregationOption) {
+        let mut tokens_to_replace = vec!();
+        let mut token_iter = tokens.iter_consolidate_tokens();
+        let mut cursor = 0;
+
+        while let Some(sub_tokens) = token_iter.next() {
+            if sub_tokens.len() > 1 {
+                let (label_index, label) = self.consolidate_labels(sub_tokens, label_aggregation_function);
+                let sentence = (&sub_tokens[0]).sentence;
+                let index = (&sub_tokens[0]).index;
+                let word_index = (&sub_tokens[0]).word_index;
+                let offset_start = match &sub_tokens.first().unwrap().offset {
+                    Some(offset) => Some(offset.begin),
+                    None => None
+                };
+                let offset_end = match &sub_tokens.last().unwrap().offset {
+                    Some(offset) => Some(offset.end),
+                    None => None
+                };
+                let offset = if offset_start.is_some() & offset_end.is_some() {
+                    Some(Offset::new(offset_start.unwrap(), offset_end.unwrap()))
+                } else {
+                    None
+                };
+                let mut text = String::new();
+                let mut score = 1f64;
+                for current_sub_token in sub_tokens.into_iter() {
+                    text.push_str(current_sub_token.text.as_str());
+                    score *= if current_sub_token.label_index == label_index {
+                        current_sub_token.score
+                    } else {
+                        1.0 - current_sub_token.score
+                    };
+                }
+                let token = Token {
+                    text,
+                    score,
+                    label,
+                    label_index,
+                    sentence,
+                    index,
+                    word_index,
+                    offset,
+                    mask: Default::default(),
+                };
+                tokens_to_replace.push(((cursor, cursor + sub_tokens.len()), token));
+            }
+            cursor += sub_tokens.len();
+        }
+        for ((start, end), token) in tokens_to_replace.into_iter().rev() {
+            tokens.splice(start..end, [token].iter().cloned());
+        }
     }
 
-    fn consolidate_labels(&self, tokens: &Vec<Token>, aggregation: &LabelAggregationOption) -> (i64, String) {
+    fn consolidate_labels(&self, tokens: &[Token], aggregation: &LabelAggregationOption) -> (i64, String) {
         match aggregation {
             LabelAggregationOption::First => {
                 let token = tokens.first().unwrap();
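Reviewer note: the mailbox text is cut off above, mid-way through patch 6's `consolidate_labels` hunk. The splice pattern patch 6 introduces (record `(range, replacement)` pairs while iterating, then splice back-to-front so earlier indices stay valid) can be seen in isolation in this standalone sketch with illustrative data, not the crate's API:

```rust
fn main() {
    // Stand-in for the consolidated-token replacement in patch 6: replace
    // each marked range with a single merged element, iterating in reverse
    // so that ranges recorded earlier are not invalidated by the splices.
    let mut tokens = vec!["My", "na", "##me", "is", "Am", "##é", "##lie"];
    let replacements = vec![((1usize, 3usize), "name"), ((4, 7), "Amélie")];
    for ((start, end), merged) in replacements.into_iter().rev() {
        tokens.splice(start..end, std::iter::once(merged));
    }
    assert_eq!(tokens, vec!["My", "name", "is", "Amélie"]);
}
```

Replacing in place this way avoids rebuilding and cloning the full token vector on every call, which is the stated goal of the patch ("avoiding copy").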