From fb5a35cf4f9420c2b647efcc50aa587d0dfab8a0 Mon Sep 17 00:00:00 2001
From: Guillaume B
Date: Tue, 22 Sep 2020 21:07:28 +0200
Subject: [PATCH 1/5] Addition of entity chunking

---
 examples/ner.rs      |   3 +-
 src/pipelines/ner.rs | 165 ++++++++++++++++++++++++++++++++++++++++++-
 tests/bert.rs        |  24 +++++++
 3 files changed, 189 insertions(+), 3 deletions(-)

diff --git a/examples/ner.rs b/examples/ner.rs
index 676175e..a25975b 100644
--- a/examples/ner.rs
+++ b/examples/ner.rs
@@ -22,10 +22,11 @@ fn main() -> anyhow::Result<()> {
     let input = [
         "My name is Amélie. I live in Москва.",
         "Chongqing is a city in China.",
+        "Asked John Smith about Acme Corp",
     ];
 
     // Run model
-    let output = ner_model.predict(&input);
+    let output = ner_model.predict_full_entities(&input);
     for entity in output {
         println!("{:?}", entity);
     }
diff --git a/src/pipelines/ner.rs b/src/pipelines/ner.rs
index 1cbc696..54d63a3 100644
--- a/src/pipelines/ner.rs
+++ b/src/pipelines/ner.rs
@@ -1,4 +1,5 @@
 // Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
+// Copyright (c) 2018 chakki (https://github.com/chakki-works/seqeval/blob/master/seqeval/metrics/sequence_labeling.py)
 // Copyright 2019 Guillaume Becquin
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -118,9 +119,12 @@
 //! Dutch| XLM_ROBERTA_NER_NL |
 
 use crate::common::error::RustBertError;
-use crate::pipelines::token_classification::{TokenClassificationConfig, TokenClassificationModel};
+use crate::pipelines::token_classification::{
+    Token, TokenClassificationConfig, TokenClassificationModel,
+};
+use std::borrow::BorrowMut;
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 /// # Entity generated by a `NERModel`
 pub struct Entity {
     /// String representation of the Entity
@@ -200,4 +204,161 @@ impl NERModel {
             })
             .collect()
     }
+
+    /// Extract full entities from a text by performing entity chunking. Follows the algorithm for entity
+    /// chunking proposed by the [CoNLL-2000 shared task](https://www.clips.uantwerpen.be/conll2002/ner/bin/conlleval.txt)
+    /// and described in [Erik F. Tjong Kim Sang, Jorn Veenstra, Representing Text Chunks](https://www.aclweb.org/anthology/E99-1023/).
+    /// The proposed implementation is inspired by the [Python seqeval library](https://github.com/chakki-works/seqeval) (shared under MIT license).
+    ///
+    /// # Arguments
+    ///
+    /// * `input` - `&[&str]` Array of texts to extract entities from.
+    ///
+    /// # Returns
+    ///
+    /// * `Vec<Entity>` containing consolidated extracted entities
+    ///
+    /// # Example
+    ///
+    /// ```no_run
+    /// # fn main() -> anyhow::Result<()> {
+    /// # use rust_bert::pipelines::ner::NERModel;
+    ///
+    /// let ner_model = NERModel::new(Default::default())?;
+    /// let input = [
+    ///     "Asked John Smith about Acme Corp",
+    /// ];
+    /// let output = ner_model.predict_full_entities(&input);
+    /// # Ok(())
+    /// # }
+    /// ```
+    ///
+    /// Outputs: \
+    /// ```no_run
+    /// # use rust_bert::pipelines::question_answering::Answer;
+    /// # use rust_bert::pipelines::ner::Entity;
+    /// # let output =
+    /// [
+    ///     Entity {
+    ///         word: String::from("John Smith"),
+    ///         score: 0.9747,
+    ///         label: String::from("PER"),
+    ///     },
+    ///     Entity {
+    ///         word: String::from("Acme Corp"),
+    ///         score: 0.8847,
+    ///         label: String::from("ORG"),
+    ///     },
+    /// ]
+    /// # ;
+    ///```
+    ///
+    ///
+    pub fn predict_full_entities(&self, input: &[&str]) -> Vec<Entity> {
+        let mut tokens = self.token_classification_model.predict(input, true, false);
+        let mut entities: Vec<Entity> = Vec::new();
+
+        let mut current_entity: Option<Entity> = None;
+        let mut previous_tag = Tag::Outside;
+        let mut previous_label = "";
+        let mut current_tag: Tag;
+        let mut current_label: &str;
+
+        tokens.push(Token {
+            text: "X".into(),
+            score: 1.0,
+            label: "O-X".to_string(),
+            label_index: 0,
+            sentence: 0,
+            index: 0,
+            word_index: 0,
+            offset: None,
+            mask: Default::default(),
+        });
+
+        for token in tokens.iter() {
+            println!("{:?}", token);
+            current_tag = token.get_tag();
+            current_label = token.get_label();
+
+            if (previous_tag == Tag::End)
+                | (previous_tag == Tag::Single)
+                | match (previous_tag, current_tag) {
+                    (Tag::Begin, Tag::Begin) => true,
+                    (Tag::Begin, Tag::Outside) => true,
+                    (Tag::Begin, Tag::Single) => true,
+                    (Tag::Inside, Tag::Begin) => true,
+                    (Tag::Inside, Tag::Outside) => true,
+                    (Tag::Inside, Tag::Single) => true,
+                    _ => false,
+                }
+                | ((previous_label != current_label) & (previous_tag != Tag::Outside))
+            {
+                if let Some(entity) = current_entity {
+                    entities.push(entity.clone());
+                    current_entity = None;
+                };
+            } else if let Some(current_entity_value) = current_entity.borrow_mut() {
+                current_entity_value.word.push(' ');
+                current_entity_value.word.push_str(token.text.as_str());
+                current_entity_value.score *= token.score;
+            };
+
+            if (current_tag == Tag::Begin)
+                | (current_tag == Tag::Single)
+                | match (previous_tag, current_tag) {
+                    (Tag::End, Tag::End) => true,
+                    (Tag::Single, Tag::End) => true,
+                    (Tag::Outside, Tag::End) => true,
+                    (Tag::End, Tag::Inside) => true,
+                    (Tag::Single, Tag::Inside) => true,
+                    (Tag::Outside, Tag::Inside) => true,
+                    _ => false,
+                }
+                | ((previous_label != current_label) & (previous_tag != Tag::Outside))
+            {
+                current_entity = Some(Entity {
+                    word: token.text.clone(),
+                    score: token.score,
+                    label: current_label.to_string(),
+                });
+            };
+            previous_tag = current_tag;
+            previous_label = current_label;
+        }
+        entities
+    }
+}
+
+#[derive(Debug, Clone, Copy, PartialEq)]
+enum Tag {
+    Begin,
+    Inside,
+    Outside,
+    End,
+    Single,
+}
+
+impl Token {
+    fn get_tag(&self) -> Tag {
+        match self.label.split('-').collect::<Vec<&str>>()[0] {
+            "B" => Tag::Begin,
+            "I" => Tag::Inside,
+            "O" => Tag::Outside,
+            "E" => Tag::End,
+            "S" => Tag::Single,
+            _ => panic!("Invalid tag encountered for token {:?}", self),
+        }
+    }
+
+    fn get_label(&self) -> &str {
+        let split_label = self.label.split('-').collect::<Vec<&str>>();
+        if split_label.len() > 1 {
+            split_label[1]
+        } else {
+            ""
+        }
+    }
+}
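The two boundary conditions above are easier to follow on a concrete tag sequence. The sketch below is not part of the patch (the toy sentence and all names are illustrative); it replays the same transition rules on a hand-tagged IOBES sequence:

```rust
// Standalone sketch of the IOBES chunking rules used by predict_full_entities.
#[derive(Debug, Clone, Copy, PartialEq)]
enum Tag {
    Begin,
    Inside,
    Outside,
    End,
    Single,
}

/// Splits an IOBES label such as "B-PER" into its tag and entity type.
fn parse(label: &str) -> (Tag, &str) {
    let mut parts = label.splitn(2, '-');
    let tag = match parts.next().unwrap() {
        "B" => Tag::Begin,
        "I" => Tag::Inside,
        "O" => Tag::Outside,
        "E" => Tag::End,
        "S" => Tag::Single,
        _ => panic!("invalid tag: {}", label),
    };
    (tag, parts.next().unwrap_or(""))
}

fn main() {
    // "John Smith" is a two-token PER chunk, "Paris" a one-token LOC chunk.
    let tagged = [
        ("John", "B-PER"),
        ("Smith", "E-PER"),
        ("visited", "O"),
        ("Paris", "S-LOC"),
        ("", "O"), // sentinel, playing the role of the "O-X" token pushed above
    ];

    let mut chunks: Vec<(String, String)> = Vec::new();
    let mut current: Option<(String, String)> = None;
    let mut previous = (Tag::Outside, "");

    for &(word, label) in tagged.iter() {
        let (tag, entity_type) = parse(label);

        // Rule 1: the running chunk ends after E-/S-, on a transition that
        // cannot continue a chunk, or when the entity type changes.
        if matches!(previous.0, Tag::End | Tag::Single)
            | matches!(
                (previous.0, tag),
                (Tag::Begin | Tag::Inside, Tag::Begin | Tag::Outside | Tag::Single)
            )
            | ((previous.1 != entity_type) & (previous.0 != Tag::Outside))
        {
            if let Some(chunk) = current.take() {
                chunks.push(chunk);
            }
        } else if let Some((words, _)) = current.as_mut() {
            words.push(' ');
            words.push_str(word);
        }

        // Rule 2: a new chunk starts on B-/S-, right after a closed chunk,
        // or when the entity type changes.
        if matches!(tag, Tag::Begin | Tag::Single)
            | matches!(
                (previous.0, tag),
                (Tag::End | Tag::Single | Tag::Outside, Tag::Inside | Tag::End)
            )
            | ((previous.1 != entity_type) & (previous.0 != Tag::Outside))
        {
            current = Some((word.to_string(), entity_type.to_string()));
        }
        previous = (tag, entity_type);
    }
    // Prints: [("John Smith", "PER"), ("Paris", "LOC")]
    println!("{:?}", chunks);
}
```

The trailing sentinel serves the same purpose as the `O-X` token pushed in `predict_full_entities`: it forces the last open chunk to be flushed before the loop ends.

diff --git 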
a/tests/bert.rs b/tests/bert.rs
index db8cd44..f2a44e2 100644
--- a/tests/bert.rs
+++ b/tests/bert.rs
@@ -391,6 +391,30 @@ fn bert_pre_trained_ner() -> anyhow::Result<()> {
     Ok(())
 }
 
+#[test]
+fn bert_pre_trained_ner_full_entities() -> anyhow::Result<()> {
+    // Set-up model
+    let ner_model = NERModel::new(Default::default())?;
+
+    // Define input
+    let input = ["Asked John Smith about Acme Corp."];
+
+    // Run model
+    let output = ner_model.predict_full_entities(&input);
+
+    assert_eq!(output.len(), 2);
+
+    assert_eq!(output[0].word, "John Smith ");
+    assert!((output[0].score - 0.9747).abs() < 1e-4);
+    assert_eq!(output[0].label, "PER");
+
+    assert_eq!(output[1].word, "Acme Corp");
+    assert!((output[1].score - 0.8847).abs() < 1e-4);
+    assert_eq!(output[1].label, "ORG");
+
+    Ok(())
+}
+
 #[test]
 fn bert_question_answering() -> anyhow::Result<()> {
     // Set-up question answering model

From 0fbf3f7156e92dd947fb02b74c048f0739fdd26c Mon Sep 17 00:00:00 2001
From: Guillaume B
Date: Wed, 6 Jan 2021 08:41:36 +0100
Subject: [PATCH 2/5] updated entity collection docstring

---
 src/pipelines/ner.rs | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/src/pipelines/ner.rs b/src/pipelines/ner.rs
index 54d63a3..8cf4960 100644
--- a/src/pipelines/ner.rs
+++ b/src/pipelines/ner.rs
@@ -206,8 +206,7 @@ impl NERModel {
     }
 
     /// Extract full entities from a text by performing entity chunking. Follows the algorithm for entity
-    /// chunking proposed by the [CoNLL-2000 shared task](https://www.clips.uantwerpen.be/conll2002/ner/bin/conlleval.txt)
-    /// and described in [Erik F. Tjong Kim Sang, Jorn Veenstra, Representing Text Chunks](https://www.aclweb.org/anthology/E99-1023/).
+    /// chunking described in [Erik F. Tjong Kim Sang, Jorn Veenstra, Representing Text Chunks](https://www.aclweb.org/anthology/E99-1023/).
     /// The proposed implementation is inspired by the [Python seqeval library](https://github.com/chakki-works/seqeval) (shared under MIT license).
/// /// # Arguments @@ -279,7 +278,6 @@ impl NERModel { }); for token in tokens.iter() { - println!("{:?}", token); current_tag = token.get_tag(); current_label = token.get_label(); From 1e6875019a426f6b4a7aa9fee2c47e7f7022d314 Mon Sep 17 00:00:00 2001 From: Guillaume Becquin Date: Sat, 20 Nov 2021 11:48:56 +0100 Subject: [PATCH 3/5] Updated based on latest main branch changes --- src/bert/encoder.rs | 15 ++++++-- src/fnet/mod.rs | 2 +- src/pipelines/ner.rs | 81 ++++++++++++++++++++++++-------------------- tests/bert.rs | 18 ++++++---- 4 files changed, 68 insertions(+), 48 deletions(-) diff --git a/src/bert/encoder.rs b/src/bert/encoder.rs index b88c13f..6a7c89d 100644 --- a/src/bert/encoder.rs +++ b/src/bert/encoder.rs @@ -117,7 +117,10 @@ impl BertLayer { /// # let config = BertConfig::from_file(config_path); /// let layer: BertLayer = BertLayer::new(&vs.root(), &config); /// let (batch_size, sequence_length, hidden_size) = (64, 128, 512); - /// let input_tensor = Tensor::rand(&[batch_size, sequence_length, hidden_size], (Kind::Float, device)); + /// let input_tensor = Tensor::rand( + /// &[batch_size, sequence_length, hidden_size], + /// (Kind::Float, device), + /// ); /// let mask = Tensor::zeros(&[batch_size, sequence_length], (Kind::Int64, device)); /// /// let layer_output = no_grad(|| layer.forward_t(&input_tensor, Some(&mask), None, None, false)); @@ -242,7 +245,10 @@ impl BertEncoder { /// # let config = BertConfig::from_file(config_path); /// let encoder: BertEncoder = BertEncoder::new(&vs.root(), &config); /// let (batch_size, sequence_length, hidden_size) = (64, 128, 512); - /// let input_tensor = Tensor::rand(&[batch_size, sequence_length, hidden_size], (Kind::Float, device)); + /// let input_tensor = Tensor::rand( + /// &[batch_size, sequence_length, hidden_size], + /// (Kind::Float, device), + /// ); /// let mask = Tensor::zeros(&[batch_size, sequence_length], (Kind::Int8, device)); /// /// let encoder_output = @@ -368,7 +374,10 @@ impl BertPooler { /// # let config = BertConfig::from_file(config_path); /// let pooler: BertPooler = BertPooler::new(&vs.root(), &config); /// let (batch_size, sequence_length, hidden_size) = (64, 128, 512); - /// let input_tensor = Tensor::rand(&[batch_size, sequence_length, hidden_size], (Kind::Float, device)); + /// let input_tensor = Tensor::rand( + /// &[batch_size, sequence_length, hidden_size], + /// (Kind::Float, device), + /// ); /// /// let pooler_output = no_grad(|| pooler.forward(&input_tensor)); /// ``` diff --git a/src/fnet/mod.rs b/src/fnet/mod.rs index 7ac8914..388a665 100644 --- a/src/fnet/mod.rs +++ b/src/fnet/mod.rs @@ -21,8 +21,8 @@ //! # //! use tch::{nn, Device}; //! # use std::path::PathBuf; -//! use rust_bert::resources::{LocalResource, RemoteResource, Resource}; //! use rust_bert::fnet::{FNetConfig, FNetForMaskedLM}; +//! use rust_bert::resources::{LocalResource, RemoteResource, Resource}; //! use rust_bert::Config; //! use rust_tokenizers::tokenizer::{BertTokenizer, FNetTokenizer}; //! diff --git a/src/pipelines/ner.rs b/src/pipelines/ner.rs index 252ed70..a5da4c0 100644 --- a/src/pipelines/ner.rs +++ b/src/pipelines/ner.rs @@ -122,7 +122,9 @@ //! 
Dutch| XLM_ROBERTA_NER_NL |
 
 use crate::common::error::RustBertError;
-use crate::pipelines::token_classification::{TokenClassificationConfig, TokenClassificationModel};
+use crate::pipelines::token_classification::{
+    Token, TokenClassificationConfig, TokenClassificationModel,
+};
 use serde::{Deserialize, Serialize};
 
 #[derive(Debug, Clone, Serialize, Deserialize)]
@@ -213,17 +215,6 @@ impl NERModel {
             })
             .collect::<Vec<Vec<Entity>>>()
     }
-}
-#[cfg(test)]
-mod test {
-    use super::*;
-
-    #[test]
-    #[ignore] // no need to run, compilation is enough to verify it is Send
-    fn test() {
-        let config = NERConfig::default();
-        let _: Box<dyn Send> = Box::new(NERModel::new(config));
-    }
 
     /// Extract full entities from a text by performing entity chunking. Follows the algorithm for entity
     /// chunking described in [Erik F. Tjong Kim Sang, Jorn Veenstra, Representing Text Chunks](https://www.aclweb.org/anthology/E99-1023/).
     /// The proposed implementation is inspired by the [Python seqeval library](https://github.com/chakki-works/seqeval) (shared under MIT license).
@@ -244,9 +235,7 @@ mod test {
     /// # use rust_bert::pipelines::ner::NERModel;
     ///
     /// let ner_model = NERModel::new(Default::default())?;
-    /// let input = [
-    ///     "Asked John Smith about Acme Corp",
-    /// ];
+    /// let input = ["Asked John Smith about Acme Corp"];
     /// let output = ner_model.predict_full_entities(&input);
     /// # Ok(())
     /// # }
     /// ```
@@ -259,7 +248,7 @@ mod test {
     /// # use rust_bert::pipelines::question_answering::Answer;
     /// # use rust_bert::pipelines::ner::Entity;
     /// # let output =
-    /// [
+    /// [[
     ///     Entity {
     ///         word: String::from("John Smith"),
     ///         score: 0.9747,
     ///         label: String::from("PER"),
     ///     },
     ///     Entity {
     ///         word: String::from("Acme Corp"),
     ///         score: 0.8847,
     ///         label: String::from("ORG"),
     ///     },
-    /// ]
+    /// ]]
     /// # ;
-    ///```
+    /// ```
     ///
-    ///
-    pub fn predict_full_entities(&self, input: &[&str]) -> Vec<Entity> {
-        let mut tokens = self.token_classification_model.predict(input, true, false);
+    pub fn predict_full_entities(&self, input: &[&str]) -> Vec<Vec<Entity>> {
+        let tokens = self.token_classification_model.predict(input, true, false);
+        let mut entities: Vec<Vec<Entity>> = Vec::new();
+
+        for mut sequence_tokens in tokens {
+            entities.push(Self::consolidate_entities(&mut sequence_tokens));
+        }
+        entities
+    }
+
+    fn consolidate_entities(tokens: &mut Vec<Token>) -> Vec<Entity> {
         let mut entities: Vec<Entity> = Vec::new();
-        let mut current_entity: Option<Entity> = None;
         let mut previous_tag = Tag::Outside;
         let mut previous_label = "";
         let mut current_tag: Tag;
         let mut current_label: &str;
+        let mut begin_offset = 0;
 
         tokens.push(Token {
             text: "X".into(),
             score: 1.0,
             label: "O-X".to_string(),
             label_index: 0,
             sentence: 0,
             index: 0,
             word_index: 0,
             offset: None,
             mask: Default::default(),
         });
 
-        for token in tokens.iter() {
+        for (position, token) in tokens.iter().enumerate() {
             current_tag = token.get_tag();
             current_label = token.get_label();
 
@@ -314,15 +311,17 @@ mod test {
                 }
                 | ((previous_label != current_label) & (previous_tag != Tag::Outside))
             {
-                if let Some(entity) = current_entity {
-                    entities.push(entity.clone());
-                    current_entity = None;
-                };
-            } else if let Some(current_entity_value) = current_entity.borrow_mut() {
-                current_entity_value.word.push(' ');
-                current_entity_value.word.push_str(token.text.as_str());
-                current_entity_value.score *= token.score;
-            };
+                let entity_tokens = &tokens[begin_offset..position];
+                entities.push(Entity {
+                    word: entity_tokens
+                        .iter()
+                        .map(|token| token.text.as_str())
+                        .collect::<Vec<&str>>()
+                        .join(" "),
+                    score: entity_tokens.iter().map(|token| token.score).product(),
+                    label: previous_label.to_string(),
+                })
+            }
 
             if (current_tag == Tag::Begin)
                 | (current_tag == Tag::Single)
@@ -337,11 +336,7 @@ mod test {
                 }
                 | ((previous_label != current_label) & (previous_tag != Tag::Outside))
             {
-                current_entity = Some(Entity {
-                    word: token.text.clone(),
-                    score: token.score,
-                    label: current_label.to_string(),
-                });
+                begin_offset = position;
             };
             previous_tag = current_tag;
             previous_label = current_label;
         }
         entities
     }
 }
@@ -380,3 +375,15 @@ impl Token {
         }
     }
 }
+
+#[cfg(test)]
+mod test {
+    use super::*;
+
+    #[test]
+    #[ignore] // no need to run, compilation is enough to verify it is Send
+    fn test() {
+        let config = NERConfig::default();
+        let _: Box<dyn Send> = Box::new(NERModel::new(config));
+    }
+}
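After this revision `predict_full_entities` returns one `Vec<Entity>` per input sequence instead of a single flat list. A minimal usage sketch (illustrative only; the entities shown in the comments are the ones asserted by the updated test below):

```rust
use rust_bert::pipelines::ner::NERModel;

fn main() -> anyhow::Result<()> {
    let ner_model = NERModel::new(Default::default())?;
    let input = ["Asked John Smith about Acme Corp", "Let's go to New York!"];

    // One Vec<Entity> per input sequence.
    let output = ner_model.predict_full_entities(&input);
    for (sequence, entities) in input.iter().zip(output.iter()) {
        // First sequence: "John Smith" (PER) and "Acme Corp" (ORG);
        // second sequence: "New York" (LOC).
        println!("{}: {:?}", sequence, entities);
    }
    Ok(())
}
```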
diff --git a/tests/bert.rs b/tests/bert.rs
index 27f09fa..caf20bb 100644
--- a/tests/bert.rs
+++ b/tests/bert.rs
@@ -395,20 +395,24 @@ fn bert_pre_trained_ner_full_entities() -> anyhow::Result<()> {
     let ner_model = NERModel::new(Default::default())?;
 
     // Define input
-    let input = ["Asked John Smith about Acme Corp."];
+    let input = ["Asked John Smith about Acme Corp", "Let's go to New York!"];
 
     // Run model
     let output = ner_model.predict_full_entities(&input);
 
     assert_eq!(output.len(), 2);
 
-    assert_eq!(output[0].word, "John Smith ");
-    assert!((output[0].score - 0.9747).abs() < 1e-4);
-    assert_eq!(output[0].label, "PER");
+    assert_eq!(output[0][0].word, "John Smith");
+    assert!((output[0][0].score - 0.9872).abs() < 1e-4);
+    assert_eq!(output[0][0].label, "PER");
 
-    assert_eq!(output[1].word, "Acme Corp");
-    assert!((output[1].score - 0.8847).abs() < 1e-4);
-    assert_eq!(output[1].label, "ORG");
+    assert_eq!(output[0][1].word, "Acme Corp");
+    assert!((output[0][1].score - 0.9622).abs() < 1e-4);
+    assert_eq!(output[0][1].label, "ORG");
+
+    assert_eq!(output[1][0].word, "New York");
+    assert!((output[1][0].score - 0.9991).abs() < 1e-4);
+    assert_eq!(output[1][0].label, "LOC");
 
     Ok(())
 }

From 067bac0d55b816bcd19150785078d9a5d30aaa4c Mon Sep 17 00:00:00 2001
From: Guillaume Becquin
Date: Sun, 21 Nov 2021 12:39:13 +0100
Subject: [PATCH 4/5] Updated changelog, fixed Clippy warning

---
 CHANGELOG.md         |  1 +
 src/pipelines/ner.rs | 36 ++++++++++++++++++------------------
 2 files changed, 19 insertions(+), 18 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 87d1451..6518ce3 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -11,6 +11,7 @@ All notable changes to this project will be documented in this file. The format
 - Support for half-precision mode for all models (reducing memory footprint). A model can be converted to half-precision by calling the `half()` method on the `VarStore` it is currently stored in. Half-precision Torch kernels are not available for CPU (limited to CUDA devices)
 - (BREAKING) Extension of the generation options that can be provided at runtime (after a model has been instantiated with a `GenerateConfig`), allowing the generation options to be updated from one text generation to another with the same model. This feature is implemented at the `LanguageGenerator` trait level, the high-level `TextGeneration` pipeline API remains unchanged.
 - Addition of the FNet language model and support for sequence, token and multiple choice classification, question answering
+- Addition of a full entities prediction method supporting the IOBES scheme (merging entity tokens such as `New` + `York` -> `New York`)
 
 ## [0.16.0] - 2021-08-24
 ## Added
diff --git a/src/pipelines/ner.rs b/src/pipelines/ner.rs
index a5da4c0..a23109e 100644
--- a/src/pipelines/ner.rs
+++ b/src/pipelines/ner.rs
@@ -300,15 +300,15 @@ impl NERModel {
 
             if (previous_tag == Tag::End)
                 | (previous_tag == Tag::Single)
-                | match (previous_tag, current_tag) {
-                    (Tag::Begin, Tag::Begin) => true,
-                    (Tag::Begin, Tag::Outside) => true,
-                    (Tag::Begin, Tag::Single) => true,
-                    (Tag::Inside, Tag::Begin) => true,
-                    (Tag::Inside, Tag::Outside) => true,
-                    (Tag::Inside, Tag::Single) => true,
-                    _ => false,
-                }
+                | matches!(
+                    (previous_tag, current_tag),
+                    (Tag::Begin, Tag::Begin)
+                        | (Tag::Begin, Tag::Outside)
+                        | (Tag::Begin, Tag::Single)
+                        | (Tag::Inside, Tag::Begin)
+                        | (Tag::Inside, Tag::Outside)
+                        | (Tag::Inside, Tag::Single)
+                )
                 | ((previous_label != current_label) & (previous_tag != Tag::Outside))
             {
                 let entity_tokens = &tokens[begin_offset..position];
@@ -325,15 +325,15 @@ impl NERModel {
 
             if (current_tag == Tag::Begin)
                 | (current_tag == Tag::Single)
-                | match (previous_tag, current_tag) {
-                    (Tag::End, Tag::End) => true,
-                    (Tag::Single, Tag::End) => true,
-                    (Tag::Outside, Tag::End) => true,
-                    (Tag::End, Tag::Inside) => true,
-                    (Tag::Single, Tag::Inside) => true,
-                    (Tag::Outside, Tag::Inside) => true,
-                    _ => false,
-                }
+                | matches!(
+                    (previous_tag, current_tag),
+                    (Tag::End, Tag::End)
+                        | (Tag::Single, Tag::End)
+                        | (Tag::Outside, Tag::End)
+                        | (Tag::End, Tag::Inside)
+                        | (Tag::Single, Tag::Inside)
+                        | (Tag::Outside, Tag::Inside)
+                )
                 | ((previous_label != current_label) & (previous_tag != Tag::Outside))
             {
                 begin_offset = position;

From 28fc22c70f6b54861a3529f4370f38160e817260 Mon Sep 17 00:00:00 2001
From: Guillaume Becquin
Date: Mon, 22 Nov 2021 15:37:34 +0100
Subject: [PATCH 5/5] Updated entity consolidation logic

---
 src/pipelines/ner.rs | 145 ++++++++++++++++++++++++-------------------
 1 file changed, 80 insertions(+), 65 deletions(-)

diff --git a/src/pipelines/ner.rs b/src/pipelines/ner.rs
index a23109e..90fa05c 100644
--- a/src/pipelines/ner.rs
+++ b/src/pipelines/ner.rs
@@ -267,84 +267,99 @@ impl NERModel {
         let tokens = self.token_classification_model.predict(input, true, false);
         let mut entities: Vec<Vec<Entity>> = Vec::new();
 
-        for mut sequence_tokens in tokens {
-            entities.push(Self::consolidate_entities(&mut sequence_tokens));
+        for sequence_tokens in tokens {
+            entities.push(Self::consolidate_entities(&sequence_tokens));
         }
         entities
     }
 
-    fn consolidate_entities(tokens: &mut Vec<Token>) -> Vec<Entity> {
+    fn consolidate_entities(tokens: &[Token]) -> Vec<Entity> {
         let mut entities: Vec<Entity> = Vec::new();
-        let mut previous_tag = Tag::Outside;
-        let mut previous_label = "";
-        let mut current_tag: Tag;
-        let mut current_label: &str;
-        let mut begin_offset = 0;
-
-        tokens.push(Token {
-            text: "X".into(),
-            score: 1.0,
-            label: "O-X".to_string(),
-            label_index: 0,
-            sentence: 0,
-            index: 0,
-            word_index: 0,
-            offset: None,
-            mask: Default::default(),
-        });
-
+        let mut entity_builder = EntityBuilder::new();
         for (position, token) in tokens.iter().enumerate() {
-            current_tag = token.get_tag();
-            current_label = token.get_label();
-
-            if (previous_tag == Tag::End)
-                | (previous_tag == Tag::Single)
-                | matches!(
-                    (previous_tag, current_tag),
-                    (Tag::Begin, Tag::Begin)
-                        | (Tag::Begin, Tag::Outside)
-                        | (Tag::Begin, Tag::Single)
-                        | (Tag::Inside, Tag::Begin)
-                        | (Tag::Inside, Tag::Outside)
-                        | (Tag::Inside, Tag::Single)
-                )
-                | ((previous_label != current_label) & (previous_tag != Tag::Outside))
-            {
-                let entity_tokens = &tokens[begin_offset..position];
-                entities.push(Entity {
-                    word: entity_tokens
-                        .iter()
-                        .map(|token| token.text.as_str())
-                        .collect::<Vec<&str>>()
-                        .join(" "),
-                    score: entity_tokens.iter().map(|token| token.score).product(),
-                    label: previous_label.to_string(),
-                })
+            let tag = token.get_tag();
+            let label = token.get_label();
+            if let Some(entity) = entity_builder.handle_current_tag(tag, label, position, tokens) {
+                entities.push(entity)
             }
-
-            if (current_tag == Tag::Begin)
-                | (current_tag == Tag::Single)
-                | matches!(
-                    (previous_tag, current_tag),
-                    (Tag::End, Tag::End)
-                        | (Tag::Single, Tag::End)
-                        | (Tag::Outside, Tag::End)
-                        | (Tag::End, Tag::Inside)
-                        | (Tag::Single, Tag::Inside)
-                        | (Tag::Outside, Tag::Inside)
-                )
-                | ((previous_label != current_label) & (previous_tag != Tag::Outside))
-            {
-                begin_offset = position;
-            };
-            previous_tag = current_tag;
-            previous_label = current_label;
+        }
+        if let Some(entity) = entity_builder.flush_and_reset(tokens.len(), tokens) {
+            entities.push(entity);
         }
         entities
     }
 }
 
+struct EntityBuilder<'a> {
+    previous_node: Option<(usize, Tag, &'a str)>,
+}
+
+impl<'a> EntityBuilder<'a> {
+    fn new() -> Self {
+        EntityBuilder {
+            previous_node: None,
+        }
+    }
+
+    fn handle_current_tag(
+        &mut self,
+        tag: Tag,
+        label: &'a str,
+        position: usize,
+        tokens: &[Token],
+    ) -> Option<Entity> {
+        match tag {
+            Tag::Outside => self.flush_and_reset(position, tokens),
+            Tag::Begin | Tag::Single => {
+                let entity = self.flush_and_reset(position, tokens);
+                self.start_new(position, tag, label);
+                entity
+            }
+            Tag::Inside | Tag::End => {
+                if let Some((_, previous_tag, previous_label)) = self.previous_node {
+                    if (previous_tag == Tag::End)
+                        | (previous_tag == Tag::Single)
+                        | (previous_label != label)
+                    {
+                        let entity = self.flush_and_reset(position, tokens);
+                        self.start_new(position, tag, label);
+                        entity
+                    } else {
+                        None
+                    }
+                } else {
+                    self.start_new(position, tag, label);
+                    None
+                }
+            }
+        }
+    }
+
+    fn flush_and_reset(&mut self, position: usize, tokens: &[Token]) -> Option<Entity> {
+        let entity = if let Some((start, _, label)) = self.previous_node {
+            let entity_tokens = &tokens[start..position];
+            Some(Entity {
+                word: entity_tokens
+                    .iter()
+                    .map(|token| token.text.as_str())
+                    .collect::<Vec<&str>>()
+                    .join(" "),
+                score: entity_tokens.iter().map(|token| token.score).product(),
+                label: label.to_string(),
+            })
+        } else {
+            None
+        };
+        self.previous_node = None;
+        entity
+    }
+
+    fn start_new(&mut self, position: usize, tag: Tag, label: &'a str) {
+        self.previous_node = Some((position, tag, label))
+    }
+}
+
 #[derive(Debug, Clone, Copy, PartialEq)]
 enum Tag {
     Begin,
     Inside,
     Outside,
     End,
     Single,
 }
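The final consolidation is driven by this small builder state machine; a test-style sanity check of `consolidate_entities` could look as follows. This is a sketch only, not part of the patch: it assumes the `Token` field types match the struct literal used in the earlier revisions of this series (`score` as `f64`, `offset: None`, `mask: Default::default()`), and it calls the private `consolidate_entities` helper, which is possible from a test submodule of the same file.

```rust
#[cfg(test)]
mod consolidation_test {
    use super::*;

    // Hypothetical helper mirroring the Token struct literal used in the
    // earlier revisions of this patch series (field types assumed).
    fn token(text: &str, label: &str, score: f64) -> Token {
        Token {
            text: text.into(),
            score,
            label: label.to_string(),
            label_index: 0,
            sentence: 0,
            index: 0,
            word_index: 0,
            offset: None,
            mask: Default::default(),
        }
    }

    #[test]
    fn consolidates_iobes_tags() {
        let tokens = vec![
            token("John", "B-PER", 0.99),
            token("Smith", "E-PER", 0.98),
            token("visited", "O", 1.0),
            token("Paris", "S-LOC", 0.97),
        ];

        let entities = NERModel::consolidate_entities(&tokens);

        assert_eq!(entities.len(), 2);
        assert_eq!(entities[0].word, "John Smith");
        assert_eq!(entities[0].label, "PER");
        // The entity score is the product of its token scores: 0.99 * 0.98.
        assert!((entities[0].score - 0.9702).abs() < 1e-6);
        assert_eq!(entities[1].word, "Paris");
        assert_eq!(entities[1].label, "LOC");
    }
}
```

Note how the trailing flush (`flush_and_reset(tokens.len(), tokens)`) replaces the sentinel `O-X` token that earlier revisions pushed onto the token list: the last open entity ("Paris") is emitted without mutating the input.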