mirror of
https://github.com/guillaume-be/rust-bert.git
synced 2024-11-10 02:12:30 +03:00
Merge pull request #33 from guillaume-be/token_entity_parsing
Token entity parsing
commit 16753ea8fb
@@ -30,7 +30,7 @@ all-tests = []
features = [ "doc-only" ]

[dependencies]
rust_tokenizers = "~3.0.0"
rust_tokenizers = "~3.0.2"
tch = "~0.1.7"
serde_json = "1.0.51"
serde = {version = "1.0.106", features = ["derive"]}
@@ -20,8 +20,8 @@ fn main() -> failure::Fallible<()> {

// Define input
let input = [
"My name is Amy. I live in Paris.",
"Paris is a city in France."
"My name is Amélie. I live in Москва.",
"Chongqing is a city in China."
];

// Run model
examples/token_classification.rs (new file, 43 lines)
@@ -0,0 +1,43 @@
// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
// Copyright 2019 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use rust_bert::pipelines::token_classification::{TokenClassificationModel, TokenClassificationConfig, LabelAggregationOption};
use rust_bert::resources::{Resource, RemoteResource};
use rust_bert::bert::{BertModelResources, BertVocabResources, BertConfigResources};
use rust_bert::pipelines::common::ModelType;

fn main() -> failure::Fallible<()> {

    // Load a configuration
    let config = TokenClassificationConfig::new(ModelType::Bert,
        Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT_NER)),
        Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT_NER)),
        Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT_NER)),
        None, //merges resource only relevant with ModelType::Roberta
        false, //lowercase
        LabelAggregationOption::Mode,
    );

    // Create the model
    let token_classification_model = TokenClassificationModel::new(config)?;
    let input = [
        "My name is Amélie. I live in Москва.",
        "Chongqing is a city in China."
    ];
    let token_outputs = token_classification_model.predict(&input, true, false); //ignore_first_label = true (only returns the NER parts, ignoring first label O)

    for token in token_outputs {
        println!("{:?}", token);
    }

    Ok(())
}
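Since the file above lives under examples/, it can be run the usual way for a Cargo example, e.g. cargo run --example token_classification; the remote BERT NER resources referenced in the configuration are fetched on first use.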
@@ -45,7 +45,7 @@
//!# ;
//! ```

use crate::pipelines::token_classification::{TokenClassificationModel,TokenClassificationConfig};
use crate::pipelines::token_classification::{TokenClassificationModel, TokenClassificationConfig};


#[derive(Debug)]
@@ -58,6 +58,7 @@ pub struct Entity {
/// Entity label (e.g. ORG, LOC...)
pub label: String,
}

//type alias for some backward compatibility
type NERConfig = TokenClassificationConfig;

@@ -86,7 +87,7 @@ impl NERModel {
///
pub fn new(ner_config: NERConfig) -> failure::Fallible<NERModel> {
let model = TokenClassificationModel::new(ner_config)?;
Ok(NERModel { token_classification_model: model})
Ok(NERModel { token_classification_model: model })
}

/// Extract entities from a text
@@ -116,12 +117,16 @@ impl NERModel {
/// ```
///
pub fn predict(&self, input: &[&str]) -> Vec<Entity> {
self.token_classification_model.predict(input, true).into_iter().map(|token| {
Entity {
word: token.text,
score: token.score,
label: token.label,
}
}).collect()
self.token_classification_model
.predict(input, true, false)
.into_iter()
.filter(|token| token.label != "O")
.map(|token| {
Entity {
word: token.text,
score: token.score,
label: token.label,
}
}).collect()
}
}
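For orientation, here is a minimal usage sketch of the NER pipeline after this change. It is not part of the diff: it assumes the rust_bert::pipelines::ner module path implied by the imports above and relies on the Default implementation of the underlying TokenClassificationConfig.

use rust_bert::pipelines::ner::NERModel;

fn main() -> failure::Fallible<()> {
    // The default configuration points at the pretrained BERT NER resources.
    let ner_model = NERModel::new(Default::default())?;
    let input = ["My name is Amélie. I live in Москва."];
    // Tokens labelled "O" are now filtered out inside NERModel::predict,
    // so only named entities come back.
    for entity in ner_model.predict(&input) {
        println!("{:?}", entity);
    }
    Ok(())
}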
@@ -22,12 +22,14 @@
//!# fn main() -> failure::Fallible<()> {
//!
//! //Load a configuration
//! use rust_bert::pipelines::token_classification::LabelAggregationOption;
//! let config = TokenClassificationConfig::new(ModelType::Bert,
//! Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT_NER)),
//! Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT_NER)),
//! Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT_NER)),
//! None, //merges resource only relevant with ModelType::Roberta
//! false, //lowercase
//! LabelAggregationOption::Mode
//! );
//!
//! //Create the model
@@ -37,25 +39,29 @@
//! "My name is Amy. I live in Paris.",
//! "Paris is a city in France."
//! ];
//! let output = token_classification_model.predict(&input, true); //ignore_first_label = true (only returns the NER parts, ignoring first label O)
//! let output = token_classification_model.predict(&input, true, true); //ignore_first_label = true (only returns the NER parts, ignoring first label O)
//!# Ok(())
//!# }
//! ```
//! Output: \
//! ```no_run
//!# use rust_bert::pipelines::token_classification::Token;
//! use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::Mask::Special;
//! use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::{Offset, Mask};
//!# let output =
//! [
//! Token { text: String::from("Amy"), score: 0.9986, label: String::from("I-PER"), sentence: 0, index: 0},
//! Token { text: String::from("Paris"), score: 0.9985, label: String::from("I-LOC"), sentence: 0, index: 9},
//! Token { text: String::from("Paris"), score: 0.9988, label: String::from("I-LOC"), sentence: 1, index: 1},
//! Token { text: String::from("France"), score: 0.9993, label: String::from("I-LOC"), sentence: 1, index: 6}
//! Token { text: String::from("[CLS]"), score: 0.9995001554489136, label: String::from("O"), label_index: 0, sentence: 0, index: 0, word_index: 0, offset: None, mask: Special },
//! Token { text: String::from("My"), score: 0.9980450868606567, label: String::from("O"), label_index: 0, sentence: 0, index: 1, word_index: 1, offset: Some(Offset { begin: 0, end: 2 }), mask: Mask::None },
//! Token { text: String::from("name"), score: 0.9995062351226807, label: String::from("O"), label_index: 0, sentence: 0, index: 2, word_index: 2, offset: Some(Offset { begin: 3, end: 7 }), mask: Mask::None },
//! Token { text: String::from("is"), score: 0.9997343420982361, label: String::from("O"), label_index: 0, sentence: 0, index: 3, word_index: 3, offset: Some(Offset { begin: 8, end: 10 }), mask: Mask::None },
//! Token { text: String::from("Amélie"), score: 0.9913727683112525, label: String::from("I-PER"), label_index: 4, sentence: 0, index: 4, word_index: 4, offset: Some(Offset { begin: 11, end: 17 }), mask: Mask::None }
//! // ...
//! ]
//!# ;
//! ```

use tch::nn::VarStore;
use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::{Tokenizer, TokenizedInput, TruncationStrategy};
use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::{Tokenizer, TokenizedInput, TruncationStrategy, Mask, Offset, ConsolidatableTokens, ConsolidatedTokenIterator, TokenTrait};
use std::collections::HashMap;
use tch::{Tensor, no_grad, Device};
use tch::kind::Kind::Float;
@@ -63,12 +69,14 @@ use crate::bert::{BertForTokenClassification, BertModelResources, BertConfigReso
use crate::roberta::RobertaForTokenClassification;
use crate::distilbert::DistilBertForTokenClassification;
use crate::common::resources::{Resource, RemoteResource, download_resource};
use serde::{Serialize, Deserialize};
use crate::pipelines::common::{ModelType, ConfigOption, TokenizerOption};
use crate::electra::ElectraForTokenClassification;
use itertools::Itertools;
use std::cmp::min;
use serde::{Serialize, Deserialize};


#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Clone, Serialize, Deserialize)]
/// # Token generated by a `TokenClassificationModel`
pub struct Token {
/// String representation of the Token
@@ -80,14 +88,59 @@ pub struct Token {
/// Token label (e.g. ORG, LOC in case of NER)
pub label: String,

/// Label index
pub label_index: i64,

/// Sentence index
#[serde(default)]
pub sentence: usize,

/// Token position index
pub index: u16,

/// Token word position index
pub word_index: u16,

/// Token offsets
pub offset: Option<Offset>,

/// Token mask
pub mask: Mask,
}

impl TokenTrait for Token {
fn offset(&self) -> Option<Offset> {
self.offset
}

fn mask(&self) -> Mask {
self.mask
}

fn as_str(&self) -> &str {
self.text.as_str()
}
}

impl ConsolidatableTokens<Token> for Vec<Token> {
fn iter_consolidate_tokens(&self) -> ConsolidatedTokenIterator<Token> {
ConsolidatedTokenIterator::new(self)
}
}

/// # Enum defining the label aggregation method for sub tokens
/// Defines the behaviour for labels aggregation if the consolidation of sub-tokens is enabled.
pub enum LabelAggregationOption {
/// The label of the first sub token is assigned to the entire token
First,
/// The label of the last sub token is assigned to the entire token
Last,
/// The most frequent sub- token is assigned to the entire token
Mode,
/// The user can provide a function mapping a `&Vec<Token>` to a `(i64, String)` tuple corresponding to the label index, label String to return
Custom(Box<dyn Fn(&[Token]) -> (i64, String)>),
}
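To illustrate the Custom variant above (a sketch, not part of the diff): a caller could supply a boxed closure that keeps the label of the highest-scoring sub-token, and pass it as the label_aggregation_function argument of TokenClassificationConfig::new.

// Sketch only: keep the label of the highest-scoring sub-token.
let highest_score = LabelAggregationOption::Custom(Box::new(|tokens: &[Token]| {
    let best = tokens
        .iter()
        .max_by(|a, b| a.score.partial_cmp(&b.score).unwrap())
        .expect("called on a non-empty group of sub-tokens");
    (best.label_index, best.label.clone())
}));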

/// # Configuration for TokenClassificationModel
/// Contains information regarding the model to load and device to place the model on.
pub struct TokenClassificationConfig {
@@ -105,6 +158,8 @@ pub struct TokenClassificationConfig {
pub lower_case: bool,
/// Device to place the model on (default: CUDA/GPU when available)
pub device: Device,
/// Sub-tokens aggregation method (default: `LabelAggregationOption::First`)
pub label_aggregation_function: LabelAggregationOption,
}

impl TokenClassificationConfig {
@@ -119,7 +174,13 @@ impl TokenClassificationConfig {
/// * vocab - An optional `Resource` tuple (`Option<Resource>`) pointing to the tokenizer's merge file to load (e.g. merges.txt), needed only for Roberta.
/// * lower_case - A `bool' indicating whether the tokenizer should lower case all input (in case of a lower-cased model)
///
pub fn new(model_type: ModelType, model_resource: Resource, config_resource: Resource, vocab_resource: Resource, merges_resource: Option<Resource>, lower_case: bool) -> TokenClassificationConfig {
pub fn new(model_type: ModelType,
model_resource: Resource,
config_resource: Resource,
vocab_resource: Resource,
merges_resource: Option<Resource>,
lower_case: bool,
label_aggregation_function: LabelAggregationOption) -> TokenClassificationConfig {
TokenClassificationConfig {
model_type,
model_resource,
@@ -128,6 +189,7 @@ impl TokenClassificationConfig {
merges_resource,
lower_case,
device: Device::cuda_if_available(),
label_aggregation_function,
}
}
}
@@ -143,6 +205,7 @@ impl Default for TokenClassificationConfig {
merges_resource: None,
lower_case: false,
device: Device::cuda_if_available(),
label_aggregation_function: LabelAggregationOption::First,
}
}
}
@@ -212,14 +275,13 @@ impl TokenClassificationOption {
}
}

/// Interface method to forward_t() of the particular models.
pub fn forward_t(&self,
input_ids: Option<Tensor>,
mask: Option<Tensor>,
token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
fn forward_t(&self,
input_ids: Option<Tensor>,
mask: Option<Tensor>,
token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
match *self {
Self::Bert(ref model) => model.forward_t(input_ids, mask, token_type_ids, position_ids, input_embeds, train),
Self::DistilBert(ref model) => model.forward_t(input_ids, mask, input_embeds, train).expect("Error in distilbert forward_t"),
@@ -234,9 +296,9 @@ impl TokenClassificationOption {
pub struct TokenClassificationModel {
tokenizer: TokenizerOption,
token_sequence_classifier: TokenClassificationOption,
//e.g. BertForTokenClassification,
label_mapping: HashMap<i64, String>,
var_store: VarStore,
label_aggregation_function: LabelAggregationOption,
}

impl TokenClassificationModel {
@@ -267,6 +329,7 @@ impl TokenClassificationModel {
None
};
let device = config.device;
let label_aggregation_function = config.label_aggregation_function;

let tokenizer = TokenizerOption::from_file(config.model_type, vocab_path.to_str().unwrap(), merges_path.map(|path| path.to_str().unwrap()), config.lower_case);
let mut var_store = VarStore::new(device);
@@ -274,10 +337,10 @@ impl TokenClassificationModel {
let token_sequence_classifier = TokenClassificationOption::new(config.model_type, &var_store.root(), &model_config);
let label_mapping = model_config.get_label_mapping();
var_store.load(weights_path)?;
Ok(TokenClassificationModel { tokenizer, token_sequence_classifier, label_mapping, var_store })
Ok(TokenClassificationModel { tokenizer, token_sequence_classifier, label_mapping, var_store, label_aggregation_function })
}

fn prepare_for_model(&self, input: Vec<&str>) -> Tensor {
fn prepare_for_model(&self, input: Vec<&str>) -> (Vec<TokenizedInput>, Tensor) {
let tokenized_input: Vec<TokenizedInput> = self.tokenizer.encode_list(input.to_vec(),
128,
&TruncationStrategy::LongestFirst,
@@ -293,18 +356,20 @@ impl TokenClassificationModel {
map(|input|
Tensor::of_slice(&(input))).
collect::<Vec<_>>();
Tensor::stack(tokenized_input_tensors.as_slice(), 0).to(self.var_store.device())
(tokenized_input, Tensor::stack(tokenized_input_tensors.as_slice(), 0).to(self.var_store.device()))
}

/// Extract entities from a text
/// Classify tokens in a text sequence
///
/// # Arguments
///
/// * `input` - `&[&str]` Array of texts to extract entities from.
/// * `consolidate_subtokens` - bool flag indicating if subtokens should be consolidated at the token level
/// * `return_special` - bool flag indicating if labels for special tokens should be returned
///
/// # Returns
///
/// * `Vec<Entity>` containing extracted entities
/// * `Vec<Token>` containing Tokens with associated labels (for example POS tags)
///
/// # Example
///
@@ -317,12 +382,12 @@ impl TokenClassificationModel {
/// "My name is Amy. I live in Paris.",
/// "Paris is a city in France."
/// ];
/// let output = ner_model.predict(&input, true);
/// let output = ner_model.predict(&input, true, true);
///# Ok(())
///# }
/// ```
pub fn predict(&self, input: &[&str], ignore_first_label: bool) -> Vec<Token> {
let input_tensor = self.prepare_for_model(input.to_vec());
pub fn predict(&self, input: &[&str], consolidate_sub_tokens: bool, return_special: bool) -> Vec<Token> {
let (tokenized_input, input_tensor) = self.prepare_for_model(input.to_vec());
let (output, _, _) = no_grad(|| {
self.token_sequence_classifier
.forward_t(Some(input_tensor.copy()),
@@ -335,39 +400,147 @@ impl TokenClassificationModel {
let output = output.detach().to(Device::Cpu);
let score: Tensor = output.exp() / output.exp().sum1(&[-1], true, Float);
let labels_idx = &score.argmax(-1, true);

let mut tokens: Vec<Token> = vec!();
for sentence_idx in 0..labels_idx.size()[0] {
let labels = labels_idx.get(sentence_idx);
//let mut word_idx: u8 = 0;
for position_idx in 0..labels.size()[0] {
let label_id = labels.int64_value(&[position_idx]);
let token = {
let token_id = input_tensor.int64_value(&[sentence_idx, position_idx]);
self.decode_token(token_id, label_id, &score, sentence_idx, position_idx) //, &mut word_idx)
};
if let Some(token) = token {
if !ignore_first_label || label_id != 0 {
tokens.push(token);
}
let sentence_tokens = &tokenized_input[sentence_idx as usize];
let original_chars = input[sentence_idx as usize].chars().collect_vec();
let mut word_idx: u16 = 0;
for position_idx in 0..sentence_tokens.token_ids.len() {
let mask = sentence_tokens.mask[position_idx];
if (mask == Mask::Special) & (!return_special) {
continue;
}
if !(mask == Mask::Continuation) & !(mask == Mask::InexactContinuation) {
word_idx += 1;
}
let token = {
self.decode_token(&original_chars, sentence_tokens, &input_tensor, &labels, &score, sentence_idx, position_idx as i64, word_idx - 1)
};
tokens.push(token);
}
}
if consolidate_sub_tokens {
self.consolidate_tokens(&mut tokens, &self.label_aggregation_function);
}
tokens
}

fn decode_token(&self, token_id: i64, label_id: i64, score: &Tensor, sentence_idx: i64, position_idx: i64) -> Option<Token> {
let text = match self.tokenizer {
TokenizerOption::Bert(ref tokenizer) => Tokenizer::decode(tokenizer, vec!(token_id), false, false),
TokenizerOption::Roberta(ref tokenizer) => Tokenizer::decode(tokenizer, vec!(token_id), false, false),
fn decode_token(&self, original_sentence_chars: &Vec<char>, sentence_tokens: &TokenizedInput, input_tensor: &Tensor,
labels: &Tensor, score: &Tensor, sentence_idx: i64, position_idx: i64, word_index: u16) -> Token {
let label_id = labels.int64_value(&[position_idx as i64]);
let token_id = input_tensor.int64_value(&[sentence_idx, position_idx as i64]);

let offsets = &sentence_tokens.token_offsets[position_idx as usize];

let text = match offsets {
None => match self.tokenizer {
TokenizerOption::Bert(ref tokenizer) => Tokenizer::decode(tokenizer, vec!(token_id), false, false),
TokenizerOption::Roberta(ref tokenizer) => Tokenizer::decode(tokenizer, vec!(token_id), false, false),
},
Some(offsets) => {
let (start_char, end_char) = (offsets.begin as usize, offsets.end as usize);
let end_char = min(end_char, original_sentence_chars.len());
let text = original_sentence_chars[start_char..end_char].iter().collect();
text
}
};

Some(Token {
Token {
text,
score: score.double_value(&[sentence_idx, position_idx, label_id]),
label: self.label_mapping.get(&label_id).expect("Index out of vocabulary bounds.").to_owned(),
label_index: label_id,
sentence: sentence_idx as usize,
index: position_idx as u16,
})
word_index,
offset: offsets.to_owned(),
mask: sentence_tokens.mask[position_idx as usize],
}
}

fn consolidate_tokens(&self, tokens: &mut Vec<Token>, label_aggregation_function: &LabelAggregationOption) {
let mut tokens_to_replace = vec!();
let mut token_iter = tokens.iter_consolidate_tokens();
let mut cursor = 0;

while let Some(sub_tokens) = token_iter.next() {
if sub_tokens.len() > 1 {
let (label_index, label) = self.consolidate_labels(sub_tokens, label_aggregation_function);
let sentence = (&sub_tokens[0]).sentence;
let index = (&sub_tokens[0]).index;
let word_index = (&sub_tokens[0]).word_index;
let offset_start = match &sub_tokens.first().unwrap().offset {
Some(offset) => Some(offset.begin),
None => None
};
let offset_end = match &sub_tokens.last().unwrap().offset {
Some(offset) => Some(offset.end),
None => None
};
let offset = if offset_start.is_some() & offset_end.is_some() {
Some(Offset::new(offset_start.unwrap(), offset_end.unwrap()))
} else {
None
};
let mut text = String::new();
let mut score = 1f64;
for current_sub_token in sub_tokens.into_iter() {
text.push_str(current_sub_token.text.as_str());
score *= if current_sub_token.label_index == label_index {
current_sub_token.score
} else {
1.0 - current_sub_token.score
};
}
let token = Token {
text,
score,
label,
label_index,
sentence,
index,
word_index,
offset,
mask: Default::default(),
};
tokens_to_replace.push(((cursor, cursor + sub_tokens.len()), token));
}
cursor += sub_tokens.len();
}
for ((start, end), token) in tokens_to_replace.into_iter().rev() {
tokens.splice(start..end, [token].iter().cloned());
}
}

fn consolidate_labels(&self, tokens: &[Token], aggregation: &LabelAggregationOption) -> (i64, String) {
match aggregation {
LabelAggregationOption::First => {
let token = tokens.first().unwrap();
(token.label_index, token.label.clone())
}
LabelAggregationOption::Last => {
let token = tokens.last().unwrap();
(token.label_index, token.label.clone())
}
LabelAggregationOption::Mode => {
let counts = tokens
.iter()
.fold(
HashMap::new(),
|mut m, c| {
*m.entry((c.label_index, c.label.as_str())).or_insert(0) += 1;
m
},
);
counts
.into_iter()
.max_by(|a, b| a.1.cmp(&b.1))
.map(|((label_index, label), _)| (label_index, label.to_owned()))
.unwrap()
}
LabelAggregationOption::Custom(function) => function(tokens)
}
}
}
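As a quick worked example of the consolidation logic above (figures invented for illustration): if a word is split into sub-tokens "Am" and "élie", both labelled I-PER with scores 0.98 and 0.95, LabelAggregationOption::Mode keeps I-PER and the consolidated score is 0.98 × 0.95 ≈ 0.93. A sub-token whose label disagrees with the aggregated label contributes 1.0 minus its score to the product instead (e.g. 1.0 − 0.95 = 0.05), so consolidated tokens spanning conflicting sub-token labels end up with low confidence.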