Addition of entity chunking

This commit is contained in:
Guillaume B 2020-09-22 21:07:28 +02:00
parent ed219f42e0
commit fb5a35cf4f
3 changed files with 189 additions and 3 deletions

View File

@ -22,10 +22,11 @@ fn main() -> anyhow::Result<()> {
let input = [
"My name is Amélie. I live in Москва.",
"Chongqing is a city in China.",
"Asked John Smith about Acme Corp",
];
// Run model
let output = ner_model.predict(&input);
let output = ner_model.predict_full_entities(&input);
for entity in output {
println!("{:?}", entity);
}

View File

@ -1,4 +1,5 @@
// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
// Copyright (c) 2018 chakki (https://github.com/chakki-works/seqeval/blob/master/seqeval/metrics/sequence_labeling.py)
// Copyright 2019 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@ -118,9 +119,12 @@
//! Dutch| XLM_ROBERTA_NER_NL |
use crate::common::error::RustBertError;
use crate::pipelines::token_classification::{TokenClassificationConfig, TokenClassificationModel};
use crate::pipelines::token_classification::{
Token, TokenClassificationConfig, TokenClassificationModel,
};
use std::borrow::BorrowMut;
#[derive(Debug)]
#[derive(Debug, Clone)]
/// # Entity generated by a `NERModel`
pub struct Entity {
/// String representation of the Entity
@ -200,4 +204,161 @@ impl NERModel {
})
.collect()
}
/// Extract full entities from a text performing entity chunking. Follows the algorithm for entities
/// chunking proposed by the [CoNLL-2000 shared task](https://www.clips.uantwerpen.be/conll2002/ner/bin/conlleval.txt)
/// and described in [Erik F. Tjong Kim Sang, Jorn Veenstra, Representing Text Chunks](https://www.aclweb.org/anthology/E99-1023/)
/// The proposed implementation is inspired by the [Python seqeval library](https://github.com/chakki-works/seqeval) (shared under MIT license).
///
/// # Arguments
///
/// * `input` - `&[&str]` Array of texts to extract entities from.
///
/// # Returns
///
/// * `Vec<Entity>` containing consolidated extracted entities
///
/// # Example
///
/// ```no_run
/// # fn main() -> anyhow::Result<()> {
/// # use rust_bert::pipelines::ner::NERModel;
///
/// let ner_model = NERModel::new(Default::default())?;
/// let input = [
/// "Asked John Smith about Acme Corp",
/// ];
/// let output = ner_model.predict_full_entities(&input);
/// # Ok(())
/// # }
/// ```
///
/// Outputs:
///
/// Output: \
/// ```no_run
/// # use rust_bert::pipelines::question_answering::Answer;
/// # use rust_bert::pipelines::ner::Entity;
/// # let output =
/// [
/// Entity {
/// word: String::from("John Smith"),
/// score: 0.9747,
/// label: String::from("PER"),
/// },
/// Entity {
/// word: String::from("Acme Corp"),
/// score: 0.8847,
/// label: String::from("I-LOC"),
/// },
/// ]
/// # ;
///```
///
///
pub fn predict_full_entities(&self, input: &[&str]) -> Vec<Entity> {
let mut tokens = self.token_classification_model.predict(input, true, false);
let mut entities: Vec<Entity> = Vec::new();
let mut current_entity: Option<Entity> = None;
let mut previous_tag = Tag::Outside;
let mut previous_label = "";
let mut current_tag: Tag;
let mut current_label: &str;
tokens.push(Token {
text: "X".into(),
score: 1.0,
label: "O-X".to_string(),
label_index: 0,
sentence: 0,
index: 0,
word_index: 0,
offset: None,
mask: Default::default(),
});
for token in tokens.iter() {
println!("{:?}", token);
current_tag = token.get_tag();
current_label = token.get_label();
if (previous_tag == Tag::End)
| (previous_tag == Tag::Single)
| match (previous_tag, current_tag) {
(Tag::Begin, Tag::Begin) => true,
(Tag::Begin, Tag::Outside) => true,
(Tag::Begin, Tag::Single) => true,
(Tag::Inside, Tag::Begin) => true,
(Tag::Inside, Tag::Outside) => true,
(Tag::Inside, Tag::Single) => true,
_ => false,
}
| ((previous_label != current_label) & (previous_tag != Tag::Outside))
{
if let Some(entity) = current_entity {
entities.push(entity.clone());
current_entity = None;
};
} else if let Some(current_entity_value) = current_entity.borrow_mut() {
current_entity_value.word.push(' ');
current_entity_value.word.push_str(token.text.as_str());
current_entity_value.score *= token.score;
};
if (current_tag == Tag::Begin)
| (current_tag == Tag::Single)
| match (previous_tag, current_tag) {
(Tag::End, Tag::End) => true,
(Tag::Single, Tag::End) => true,
(Tag::Outside, Tag::End) => true,
(Tag::End, Tag::Inside) => true,
(Tag::Single, Tag::Inside) => true,
(Tag::Outside, Tag::Inside) => true,
_ => false,
}
| ((previous_label != current_label) & (previous_tag != Tag::Outside))
{
current_entity = Some(Entity {
word: token.text.clone(),
score: token.score,
label: current_label.to_string(),
});
};
previous_tag = current_tag;
previous_label = current_label;
}
entities
}
}
#[derive(Debug, Clone, Copy, PartialEq)]
/// Position of a token within an entity chunk, decoded from the prefix of the
/// token label (IOBES tagging scheme, see `Token::get_tag`).
enum Tag {
    /// `B-` prefix: first token of a multi-token entity
    Begin,
    /// `I-` prefix: token inside an entity
    Inside,
    /// `O` prefix: token outside any entity
    Outside,
    /// `E-` prefix: last token of a multi-token entity
    End,
    /// `S-` prefix: single-token entity
    Single,
}
impl Token {
    /// Position tag (IOBES scheme) encoded by the prefix of the token label,
    /// e.g. the `B` of `B-PER`.
    ///
    /// # Panics
    ///
    /// Panics if the label prefix is not one of `B`, `I`, `O`, `E` or `S`.
    fn get_tag(&self) -> Tag {
        // `split` always yields at least one element, so `next` cannot fail;
        // iterating lazily avoids the per-token `Vec` allocation of `collect`.
        match self.label.split('-').next().unwrap() {
            "B" => Tag::Begin,
            "I" => Tag::Inside,
            "O" => Tag::Outside,
            "E" => Tag::End,
            "S" => Tag::Single,
            _ => panic!("Invalid tag encountered for token {:?}", self),
        }
    }

    /// Entity label following the scheme prefix of the token label, e.g. the
    /// `PER` of `B-PER`, or an empty string for plain tags such as `O`.
    fn get_label(&self) -> &str {
        // Second '-'-separated component (same as indexing the collected `Vec`
        // at [1] when it exists), without the intermediate allocation.
        self.label.split('-').nth(1).unwrap_or("")
    }
}

View File

@ -391,6 +391,30 @@ fn bert_pre_trained_ner() -> anyhow::Result<()> {
Ok(())
}
#[test]
fn bert_pre_trained_ner_full_entities() -> anyhow::Result<()> {
    // Set-up model (default configuration downloads the pre-trained BERT NER weights)
    let ner_model = NERModel::new(Default::default())?;
    // Define input
    let input = ["Asked John Smith about Acme Corp."];
    // Run model: chunking consolidates per-token predictions into one entity
    // per span, with the score accumulated as a product of token scores.
    let output = ner_model.predict_full_entities(&input);
    assert_eq!(output.len(), 2);
    // NOTE(review): the expected word carries a trailing space — presumably an
    // artifact of how sub-token texts are joined during chunking; confirm
    // whether this is intended or should be trimmed.
    assert_eq!(output[0].word, "John Smith ");
    assert!((output[0].score - 0.9747).abs() < 1e-4);
    assert_eq!(output[0].label, "PER");
    assert_eq!(output[1].word, "Acme Corp");
    assert!((output[1].score - 0.8847).abs() < 1e-4);
    assert_eq!(output[1].label, "ORG");
    Ok(())
}
#[test]
fn bert_question_answering() -> anyhow::Result<()> {
// Set-up question answering model