mirror of
https://github.com/guillaume-be/rust-bert.git
synced 2024-08-16 16:10:25 +03:00
Merge pull request #201 from guillaume-be/entity_consolidation
Entity consolidation
This commit is contained in:
commit
b444780c18
@ -11,6 +11,7 @@ All notable changes to this project will be documented in this file. The format
|
||||
- Support for half-precision mode for all models (reducing memory footprint). A model can be converted to half-precision by calling the `half()` method on the `VarStore` it is currently stored in. Half-precision Torch kernels are not available for CPU (limited to CUDA devices)
|
||||
- (BREAKING) Extension of the generation options that can be provided at runtime (after a model has been instantiated with a `GenerateConfig`), allowing the generation options to be updated from one text generation to another with the same model. This feature is implemented at the `LanguageGenerator` trait level; the high-level `TextGeneration` pipeline API remains unchanged.
|
||||
- Addition of the FNet language model and support for sequence, token and multiple choice classification, question answering
|
||||
- Addition of a full entity prediction method supporting the IOBES scheme (merging entity tokens such as <New> + <York> -> <New York>)
|
||||
|
||||
## [0.16.0] - 2021-08-24
|
||||
## Added
|
||||
|
@ -22,10 +22,11 @@ fn main() -> anyhow::Result<()> {
|
||||
let input = [
|
||||
"My name is Amélie. I live in Москва.",
|
||||
"Chongqing is a city in China.",
|
||||
"Asked John Smith about Acme Corp",
|
||||
];
|
||||
|
||||
// Run model
|
||||
let output = ner_model.predict(&input);
|
||||
let output = ner_model.predict_full_entities(&input);
|
||||
for entity in output {
|
||||
println!("{:?}", entity);
|
||||
}
|
||||
|
@ -117,7 +117,10 @@ impl BertLayer {
|
||||
/// # let config = BertConfig::from_file(config_path);
|
||||
/// let layer: BertLayer = BertLayer::new(&vs.root(), &config);
|
||||
/// let (batch_size, sequence_length, hidden_size) = (64, 128, 512);
|
||||
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length, hidden_size], (Kind::Float, device));
|
||||
/// let input_tensor = Tensor::rand(
|
||||
/// &[batch_size, sequence_length, hidden_size],
|
||||
/// (Kind::Float, device),
|
||||
/// );
|
||||
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Kind::Int64, device));
|
||||
///
|
||||
/// let layer_output = no_grad(|| layer.forward_t(&input_tensor, Some(&mask), None, None, false));
|
||||
@ -242,7 +245,10 @@ impl BertEncoder {
|
||||
/// # let config = BertConfig::from_file(config_path);
|
||||
/// let encoder: BertEncoder = BertEncoder::new(&vs.root(), &config);
|
||||
/// let (batch_size, sequence_length, hidden_size) = (64, 128, 512);
|
||||
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length, hidden_size], (Kind::Float, device));
|
||||
/// let input_tensor = Tensor::rand(
|
||||
/// &[batch_size, sequence_length, hidden_size],
|
||||
/// (Kind::Float, device),
|
||||
/// );
|
||||
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Kind::Int8, device));
|
||||
///
|
||||
/// let encoder_output =
|
||||
@ -368,7 +374,10 @@ impl BertPooler {
|
||||
/// # let config = BertConfig::from_file(config_path);
|
||||
/// let pooler: BertPooler = BertPooler::new(&vs.root(), &config);
|
||||
/// let (batch_size, sequence_length, hidden_size) = (64, 128, 512);
|
||||
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length, hidden_size], (Kind::Float, device));
|
||||
/// let input_tensor = Tensor::rand(
|
||||
/// &[batch_size, sequence_length, hidden_size],
|
||||
/// (Kind::Float, device),
|
||||
/// );
|
||||
///
|
||||
/// let pooler_output = no_grad(|| pooler.forward(&input_tensor));
|
||||
/// ```
|
||||
|
@ -21,8 +21,8 @@
|
||||
//! #
|
||||
//! use tch::{nn, Device};
|
||||
//! # use std::path::PathBuf;
|
||||
//! use rust_bert::resources::{LocalResource, RemoteResource, Resource};
|
||||
//! use rust_bert::fnet::{FNetConfig, FNetForMaskedLM};
|
||||
//! use rust_bert::resources::{LocalResource, RemoteResource, Resource};
|
||||
//! use rust_bert::Config;
|
||||
//! use rust_tokenizers::tokenizer::{BertTokenizer, FNetTokenizer};
|
||||
//!
|
||||
|
@ -1,4 +1,5 @@
|
||||
// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
|
||||
// Copyright (c) 2018 chakki (https://github.com/chakki-works/seqeval/blob/master/seqeval/metrics/sequence_labeling.py)
|
||||
// Copyright 2019 Guillaume Becquin
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
@ -121,10 +122,12 @@
|
||||
//! Dutch| XLM_ROBERTA_NER_NL |
|
||||
|
||||
use crate::common::error::RustBertError;
|
||||
use crate::pipelines::token_classification::{TokenClassificationConfig, TokenClassificationModel};
|
||||
use crate::pipelines::token_classification::{
|
||||
Token, TokenClassificationConfig, TokenClassificationModel,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
/// # Entity generated by a `NERModel`
|
||||
pub struct Entity {
|
||||
/// String representation of the Entity
|
||||
@ -212,7 +215,182 @@ impl NERModel {
|
||||
})
|
||||
.collect::<Vec<Vec<Entity>>>()
|
||||
}
|
||||
|
||||
/// Extract full entities from a text performing entity chunking. Follows the algorithm for entities
|
||||
/// chunking described in [Erik F. Tjong Kim Sang, Jorn Veenstra, Representing Text Chunks](https://www.aclweb.org/anthology/E99-1023/)
|
||||
/// The proposed implementation is inspired by the [Python seqeval library](https://github.com/chakki-works/seqeval) (shared under MIT license).
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `input` - `&[&str]` Array of texts to extract entities from.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// * `Vec<Entity>` containing consolidated extracted entities
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```no_run
|
||||
/// # fn main() -> anyhow::Result<()> {
|
||||
/// # use rust_bert::pipelines::ner::NERModel;
|
||||
///
|
||||
/// let ner_model = NERModel::new(Default::default())?;
|
||||
/// let input = ["Asked John Smith about Acme Corp"];
|
||||
/// let output = ner_model.predict_full_entities(&input);
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
///
|
||||
/// Outputs:
|
||||
///
|
||||
/// Output: \
|
||||
/// ```no_run
|
||||
/// # use rust_bert::pipelines::question_answering::Answer;
|
||||
/// # use rust_bert::pipelines::ner::Entity;
|
||||
/// # let output =
|
||||
/// [[
|
||||
/// Entity {
|
||||
/// word: String::from("John Smith"),
|
||||
/// score: 0.9747,
|
||||
/// label: String::from("PER"),
|
||||
/// },
|
||||
/// Entity {
|
||||
/// word: String::from("Acme Corp"),
|
||||
/// score: 0.8847,
|
||||
/// label: String::from("I-LOC"),
|
||||
/// },
|
||||
/// ]]
|
||||
/// # ;
|
||||
/// ```
|
||||
///
|
||||
pub fn predict_full_entities(&self, input: &[&str]) -> Vec<Vec<Entity>> {
|
||||
let tokens = self.token_classification_model.predict(input, true, false);
|
||||
let mut entities: Vec<Vec<Entity>> = Vec::new();
|
||||
|
||||
for sequence_tokens in tokens {
|
||||
entities.push(Self::consolidate_entities(&sequence_tokens));
|
||||
}
|
||||
entities
|
||||
}
|
||||
|
||||
fn consolidate_entities(tokens: &[Token]) -> Vec<Entity> {
|
||||
let mut entities: Vec<Entity> = Vec::new();
|
||||
|
||||
let mut entity_builder = EntityBuilder::new();
|
||||
for (position, token) in tokens.iter().enumerate() {
|
||||
let tag = token.get_tag();
|
||||
let label = token.get_label();
|
||||
if let Some(entity) = entity_builder.handle_current_tag(tag, label, position, tokens) {
|
||||
entities.push(entity)
|
||||
}
|
||||
}
|
||||
if let Some(entity) = entity_builder.flush_and_reset(tokens.len(), tokens) {
|
||||
entities.push(entity);
|
||||
}
|
||||
entities
|
||||
}
|
||||
}
|
||||
|
||||
struct EntityBuilder<'a> {
|
||||
previous_node: Option<(usize, Tag, &'a str)>,
|
||||
}
|
||||
|
||||
impl<'a> EntityBuilder<'a> {
|
||||
fn new() -> Self {
|
||||
EntityBuilder {
|
||||
previous_node: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_current_tag(
|
||||
&mut self,
|
||||
tag: Tag,
|
||||
label: &'a str,
|
||||
position: usize,
|
||||
tokens: &[Token],
|
||||
) -> Option<Entity> {
|
||||
match tag {
|
||||
Tag::Outside => self.flush_and_reset(position, tokens),
|
||||
Tag::Begin | Tag::Single => {
|
||||
let entity = self.flush_and_reset(position, tokens);
|
||||
self.start_new(position, tag, label);
|
||||
entity
|
||||
}
|
||||
Tag::Inside | Tag::End => {
|
||||
if let Some((_, previous_tag, previous_label)) = self.previous_node {
|
||||
if (previous_tag == Tag::End)
|
||||
| (previous_tag == Tag::Single)
|
||||
| (previous_label != label)
|
||||
{
|
||||
let entity = self.flush_and_reset(position, tokens);
|
||||
self.start_new(position, tag, label);
|
||||
entity
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
self.start_new(position, tag, label);
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn flush_and_reset(&mut self, position: usize, tokens: &[Token]) -> Option<Entity> {
|
||||
let entity = if let Some((start, _, label)) = self.previous_node {
|
||||
let entity_tokens = &tokens[start..position];
|
||||
Some(Entity {
|
||||
word: entity_tokens
|
||||
.iter()
|
||||
.map(|token| token.text.as_str())
|
||||
.collect::<Vec<&str>>()
|
||||
.join(" "),
|
||||
score: entity_tokens.iter().map(|token| token.score).product(),
|
||||
label: label.to_string(),
|
||||
})
|
||||
} else {
|
||||
None
|
||||
};
|
||||
self.previous_node = None;
|
||||
entity
|
||||
}
|
||||
|
||||
fn start_new(&mut self, position: usize, tag: Tag, label: &'a str) {
|
||||
self.previous_node = Some((position, tag, label))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq)]
|
||||
enum Tag {
|
||||
Begin,
|
||||
Inside,
|
||||
Outside,
|
||||
End,
|
||||
Single,
|
||||
}
|
||||
|
||||
impl Token {
|
||||
fn get_tag(&self) -> Tag {
|
||||
match self.label.split('-').collect::<Vec<&str>>()[0] {
|
||||
"B" => Tag::Begin,
|
||||
"I" => Tag::Inside,
|
||||
"O" => Tag::Outside,
|
||||
"E" => Tag::End,
|
||||
"S" => Tag::Single,
|
||||
_ => panic!("Invalid tag encountered for token {:?}", self),
|
||||
}
|
||||
}
|
||||
|
||||
fn get_label(&self) -> &str {
|
||||
let split_label = self.label.split('-').collect::<Vec<&str>>();
|
||||
if split_label.len() > 1 {
|
||||
split_label[1]
|
||||
} else {
|
||||
""
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
|
@ -389,6 +389,34 @@ fn bert_pre_trained_ner() -> anyhow::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn bert_pre_trained_ner_full_entities() -> anyhow::Result<()> {
|
||||
// Set-up model
|
||||
let ner_model = NERModel::new(Default::default())?;
|
||||
|
||||
// Define input
|
||||
let input = ["Asked John Smith about Acme Corp", "Let's go to New York!"];
|
||||
|
||||
// Run model
|
||||
let output = ner_model.predict_full_entities(&input);
|
||||
|
||||
assert_eq!(output.len(), 2);
|
||||
|
||||
assert_eq!(output[0][0].word, "John Smith");
|
||||
assert!((output[0][0].score - 0.9872).abs() < 1e-4);
|
||||
assert_eq!(output[0][0].label, "PER");
|
||||
|
||||
assert_eq!(output[0][1].word, "Acme Corp");
|
||||
assert!((output[0][1].score - 0.9622).abs() < 1e-4);
|
||||
assert_eq!(output[0][1].label, "ORG");
|
||||
|
||||
assert_eq!(output[1][0].word, "New York");
|
||||
assert!((output[1][0].score - 0.9991).abs() < 1e-4);
|
||||
assert_eq!(output[1][0].label, "LOC");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn bert_question_answering() -> anyhow::Result<()> {
|
||||
// Set-up question answering model
|
||||
|
Loading…
Reference in New Issue
Block a user