Merge pull request #201 from guillaume-be/entity_consolidation

Entity consolidation
This commit is contained in:
guillaume-be 2021-11-24 15:42:43 +01:00 committed by GitHub
commit b444780c18
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 224 additions and 7 deletions

View File

@ -11,6 +11,7 @@ All notable changes to this project will be documented in this file. The format
- Support for half-precision mode for all models (reducing memory footprint). A model can be converted to half-precision by calling the `half()` method on the `VarStore` it is currently stored in. Half-precision Torch kernels are not available for CPU (limited to CUDA devices)
- (BREAKING) Extension of the generation options that can be provided at runtime (after a model has been instantiated with a `GenerateConfig`), allowing to update the generation options from one text generation to another with the same model. This feature is implemented at the `LanguageGenerator` trait level, the high-level `TextGeneration` pipeline API remains unchanged.
- Addition of the FNet language model and support for sequence, token and multiple choice classification, question answering
- Addition of a full entity prediction method supporting the IOBES scheme (merging entity tokens such as <New> + <York> -> <New York>)
## [0.16.0] - 2021-08-24
## Added

View File

@ -22,10 +22,11 @@ fn main() -> anyhow::Result<()> {
let input = [
"My name is Amélie. I live in Москва.",
"Chongqing is a city in China.",
"Asked John Smith about Acme Corp",
];
// Run model
let output = ner_model.predict(&input);
let output = ner_model.predict_full_entities(&input);
for entity in output {
println!("{:?}", entity);
}

View File

@ -117,7 +117,10 @@ impl BertLayer {
/// # let config = BertConfig::from_file(config_path);
/// let layer: BertLayer = BertLayer::new(&vs.root(), &config);
/// let (batch_size, sequence_length, hidden_size) = (64, 128, 512);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length, hidden_size], (Kind::Float, device));
/// let input_tensor = Tensor::rand(
/// &[batch_size, sequence_length, hidden_size],
/// (Kind::Float, device),
/// );
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Kind::Int64, device));
///
/// let layer_output = no_grad(|| layer.forward_t(&input_tensor, Some(&mask), None, None, false));
@ -242,7 +245,10 @@ impl BertEncoder {
/// # let config = BertConfig::from_file(config_path);
/// let encoder: BertEncoder = BertEncoder::new(&vs.root(), &config);
/// let (batch_size, sequence_length, hidden_size) = (64, 128, 512);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length, hidden_size], (Kind::Float, device));
/// let input_tensor = Tensor::rand(
/// &[batch_size, sequence_length, hidden_size],
/// (Kind::Float, device),
/// );
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Kind::Int8, device));
///
/// let encoder_output =
@ -368,7 +374,10 @@ impl BertPooler {
/// # let config = BertConfig::from_file(config_path);
/// let pooler: BertPooler = BertPooler::new(&vs.root(), &config);
/// let (batch_size, sequence_length, hidden_size) = (64, 128, 512);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length, hidden_size], (Kind::Float, device));
/// let input_tensor = Tensor::rand(
/// &[batch_size, sequence_length, hidden_size],
/// (Kind::Float, device),
/// );
///
/// let pooler_output = no_grad(|| pooler.forward(&input_tensor));
/// ```

View File

@ -21,8 +21,8 @@
//! #
//! use tch::{nn, Device};
//! # use std::path::PathBuf;
//! use rust_bert::resources::{LocalResource, RemoteResource, Resource};
//! use rust_bert::fnet::{FNetConfig, FNetForMaskedLM};
//! use rust_bert::resources::{LocalResource, RemoteResource, Resource};
//! use rust_bert::Config;
//! use rust_tokenizers::tokenizer::{BertTokenizer, FNetTokenizer};
//!

View File

@ -1,4 +1,5 @@
// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
// Copyright (c) 2018 chakki (https://github.com/chakki-works/seqeval/blob/master/seqeval/metrics/sequence_labeling.py)
// Copyright 2019 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@ -121,10 +122,12 @@
//! Dutch| XLM_ROBERTA_NER_NL |
use crate::common::error::RustBertError;
use crate::pipelines::token_classification::{TokenClassificationConfig, TokenClassificationModel};
use crate::pipelines::token_classification::{
Token, TokenClassificationConfig, TokenClassificationModel,
};
use serde::{Deserialize, Serialize};
#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Clone, Serialize, Deserialize)]
/// # Entity generated by a `NERModel`
pub struct Entity {
/// String representation of the Entity
@ -212,7 +215,182 @@ impl NERModel {
})
.collect::<Vec<Vec<Entity>>>()
}
/// Extract full entities from a text performing entity chunking. Follows the algorithm for entities
/// chunking described in [Erik F. Tjong Kim Sang, Jorn Veenstra, Representing Text Chunks](https://www.aclweb.org/anthology/E99-1023/)
/// The proposed implementation is inspired by the [Python seqeval library](https://github.com/chakki-works/seqeval) (shared under MIT license).
///
/// # Arguments
///
/// * `input` - `&[&str]` Array of texts to extract entities from.
///
/// # Returns
///
/// * `Vec<Vec<Entity>>` containing the consolidated extracted entities for each input text
///
/// # Example
///
/// ```no_run
/// # fn main() -> anyhow::Result<()> {
/// # use rust_bert::pipelines::ner::NERModel;
///
/// let ner_model = NERModel::new(Default::default())?;
/// let input = ["Asked John Smith about Acme Corp"];
/// let output = ner_model.predict_full_entities(&input);
/// # Ok(())
/// # }
/// ```
///
/// Outputs:
///
/// Output: \
/// ```no_run
/// # use rust_bert::pipelines::question_answering::Answer;
/// # use rust_bert::pipelines::ner::Entity;
/// # let output =
/// [[
///     Entity {
///         word: String::from("John Smith"),
///         score: 0.9747,
///         label: String::from("PER"),
///     },
///     Entity {
///         word: String::from("Acme Corp"),
///         score: 0.8847,
///         label: String::from("ORG"),
///     },
/// ]]
/// # ;
/// ```
///
pub fn predict_full_entities(&self, input: &[&str]) -> Vec<Vec<Entity>> {
    // Token-level predictions (IOBES labels) are consolidated into one
    // entity list per input sequence.
    self.token_classification_model
        .predict(input, true, false)
        .iter()
        .map(|sequence_tokens| Self::consolidate_entities(sequence_tokens))
        .collect()
}
/// Merges consecutive IOBES-tagged tokens of a single sequence into `Entity` values.
fn consolidate_entities(tokens: &[Token]) -> Vec<Entity> {
    let mut consolidated: Vec<Entity> = Vec::new();
    let mut builder = EntityBuilder::new();
    for (index, token) in tokens.iter().enumerate() {
        let completed = builder.handle_current_tag(token.get_tag(), token.get_label(), index, tokens);
        if let Some(entity) = completed {
            consolidated.push(entity);
        }
    }
    // Close any entity still open at the end of the sequence.
    consolidated.extend(builder.flush_and_reset(tokens.len(), tokens));
    consolidated
}
}
/// Incremental state machine consolidating IOBES-tagged tokens into entities.
///
/// `previous_node` holds the start position, tag and label of the entity
/// currently being built, or `None` when no entity is open.
struct EntityBuilder<'a> {
    previous_node: Option<(usize, Tag, &'a str)>,
}

impl<'a> EntityBuilder<'a> {
    /// Creates a builder with no entity in progress.
    fn new() -> Self {
        EntityBuilder {
            previous_node: None,
        }
    }

    /// Updates the builder with the tag/label observed at `position`.
    ///
    /// Returns a completed `Entity` when the current tag terminates the entity
    /// previously being built, `None` otherwise.
    fn handle_current_tag(
        &mut self,
        tag: Tag,
        label: &'a str,
        position: usize,
        tokens: &[Token],
    ) -> Option<Entity> {
        match tag {
            // "O": any open entity ends right before this token.
            Tag::Outside => self.flush_and_reset(position, tokens),
            // "B"/"S": a new entity starts here; close any open one first.
            Tag::Begin | Tag::Single => {
                let entity = self.flush_and_reset(position, tokens);
                self.start_new(position, tag, label);
                entity
            }
            Tag::Inside | Tag::End => {
                if let Some((_, previous_tag, previous_label)) = self.previous_node {
                    // An "I"/"E" continues the open entity only if the previous
                    // tag did not already terminate it and the labels agree.
                    // Short-circuit `||` replaces the original bitwise `|`.
                    if (previous_tag == Tag::End)
                        || (previous_tag == Tag::Single)
                        || (previous_label != label)
                    {
                        let entity = self.flush_and_reset(position, tokens);
                        self.start_new(position, tag, label);
                        entity
                    } else {
                        None
                    }
                } else {
                    // Dangling "I"/"E" without a preceding "B": treat it as a
                    // new entity start (lenient, seqeval-style behavior).
                    self.start_new(position, tag, label);
                    None
                }
            }
        }
    }

    /// Finalizes the entity in progress (if any), covering tokens from its start
    /// position up to (excluding) `position`, and clears the builder state.
    ///
    /// The entity `word` is the space-joined token texts; its `score` is the
    /// product of the individual token scores.
    fn flush_and_reset(&mut self, position: usize, tokens: &[Token]) -> Option<Entity> {
        let entity = if let Some((start, _, label)) = self.previous_node {
            let entity_tokens = &tokens[start..position];
            Some(Entity {
                word: entity_tokens
                    .iter()
                    .map(|token| token.text.as_str())
                    .collect::<Vec<&str>>()
                    .join(" "),
                score: entity_tokens.iter().map(|token| token.score).product(),
                label: label.to_string(),
            })
        } else {
            None
        };
        self.previous_node = None;
        entity
    }

    /// Marks the start of a new entity at `position` with the given tag and label.
    fn start_new(&mut self, position: usize, tag: Tag, label: &'a str) {
        self.previous_node = Some((position, tag, label))
    }
}
/// IOBES position indicator parsed from the prefix of a token label
/// (e.g. the `B` in `B-PER`).
// `Eq` derived alongside `PartialEq`: equality on this fieldless enum is total
// (clippy::derive_partial_eq_without_eq).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Tag {
    /// `B-`: first token of a multi-token entity
    Begin,
    /// `I-`: token inside a multi-token entity
    Inside,
    /// `O`: token outside any entity
    Outside,
    /// `E-`: last token of a multi-token entity
    End,
    /// `S-`: single-token entity
    Single,
}
impl Token {
    /// Parses the IOBES position indicator from the prefix of the token label
    /// (the part before the first `-`).
    ///
    /// # Panics
    /// Panics if the prefix is not one of `B`, `I`, `O`, `E` or `S`.
    fn get_tag(&self) -> Tag {
        // `split` always yields at least one item, so the fallback arm is
        // reached only through the `_` panic below.
        match self.label.split('-').next().unwrap_or("") {
            "B" => Tag::Begin,
            "I" => Tag::Inside,
            "O" => Tag::Outside,
            "E" => Tag::End,
            "S" => Tag::Single,
            _ => panic!("Invalid tag encountered for token {:?}", self),
        }
    }

    /// Returns the entity label without its IOBES prefix (e.g. `PER` from `B-PER`),
    /// or an empty string when the label contains no `-` separator.
    fn get_label(&self) -> &str {
        // `nth(1)` is the second `-`-separated segment, matching the original
        // `collect()[1]` behavior for labels with more than one dash.
        self.label.split('-').nth(1).unwrap_or("")
    }
}
#[cfg(test)]
mod test {
use super::*;

View File

@ -389,6 +389,34 @@ fn bert_pre_trained_ner() -> anyhow::Result<()> {
Ok(())
}
#[test]
// Integration test for consolidated (full-entity) NER prediction: multi-token
// entities must be merged into single `Entity` values with IOB prefixes
// stripped from the labels. Requires the default pretrained model; the score
// assertions pin the exact model outputs, so they will drift if the default
// weights change.
fn bert_pre_trained_ner_full_entities() -> anyhow::Result<()> {
// Set-up model
let ner_model = NERModel::new(Default::default())?;
// Define input
let input = ["Asked John Smith about Acme Corp", "Let's go to New York!"];
// Run model
let output = ner_model.predict_full_entities(&input);
// One entity list per input sequence
assert_eq!(output.len(), 2);
// "John" + "Smith" merged into a single PER entity
assert_eq!(output[0][0].word, "John Smith");
assert!((output[0][0].score - 0.9872).abs() < 1e-4);
assert_eq!(output[0][0].label, "PER");
// "Acme" + "Corp" merged into a single ORG entity
assert_eq!(output[0][1].word, "Acme Corp");
assert!((output[0][1].score - 0.9622).abs() < 1e-4);
assert_eq!(output[0][1].label, "ORG");
// "New" + "York" merged into a single LOC entity
assert_eq!(output[1][0].word, "New York");
assert!((output[1][0].score - 0.9991).abs() < 1e-4);
assert_eq!(output[1][0].label, "LOC");
Ok(())
}
#[test]
fn bert_question_answering() -> anyhow::Result<()> {
// Set-up question answering model