Added sub-token consolidation

Guillaume B 2020-05-10 12:05:13 +02:00
parent 9d3a944051
commit 1a9e315edf
2 changed files with 64 additions and 3 deletions

@@ -20,7 +20,7 @@ fn main() -> failure::Fallible<()> {
    // Define input
    let input = [
-        "My name is Amélie. I live in 北京市.",
+        "My name is Amélie. I live in Москва.",
        "Chongqing is a city in China."
    ];

@@ -68,7 +68,7 @@ use crate::electra::ElectraForTokenClassification;
use itertools::Itertools;

-#[derive(Debug)]
+#[derive(Debug, Clone)]
/// # Token generated by a `TokenClassificationModel`
pub struct Token {
    /// String representation of the Token
@@ -91,6 +91,9 @@ pub struct Token {
    /// Token offsets
    pub offset: Option<Offset>,
+    /// Token mask
+    pub mask: Mask,
}

/// # Configuration for TokenClassificationModel
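
The new `mask` field is what drives consolidation further down: a sub-token is folded into the word opened by the token before it whenever its mask marks it as a continuation. A minimal sketch of that test, assuming only the two continuation variants this commit checks (the import path for `Mask` is a guess, not taken from the diff):

// Sketch only: the `Mask` import path is an assumption.
use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::Mask;

/// A sub-token extends the current word exactly when its mask is one of the
/// two continuation variants checked by `consolidate_tokens` below.
fn is_continuation(mask: &Mask) -> bool {
    matches!(mask, Mask::Continuation | Mask::InexactContinuation)
}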
@@ -326,7 +329,7 @@ impl TokenClassificationModel {
    ///# Ok(())
    ///# }
    /// ```
-    pub fn predict(&self, input: &[&str], return_tokens: bool, return_special: bool) -> Vec<Token> {
+    pub fn predict(&self, input: &[&str], consolidate_subtokens: bool, return_special: bool) -> Vec<Token> {
        let (tokenized_input, input_tensor) = self.prepare_for_model(input.to_vec());
        let (output, _, _) = no_grad(|| {
            self.token_sequence_classifier
@@ -360,6 +363,9 @@ impl TokenClassificationModel {
                }
            }
        }
+        if consolidate_subtokens {
+            tokens = self.consolidate_tokens(tokens);
+        }
        tokens
    }
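
For reference, a usage sketch of the changed signature; the constructor and default config here are assumptions inferred from the doc comments in this file, not verbatim crate code:

// Hedged sketch: `TokenClassificationConfig::default()` and the constructor
// signature are assumptions, inferred from the surrounding doc comments.
fn run_ner() -> failure::Fallible<()> {
    let config = TokenClassificationConfig::default();
    let model = TokenClassificationModel::new(config)?;

    let input = ["My name is Amélie. I live in Москва."];

    // consolidate_subtokens = true merges word pieces ("Am", "##élie", ...)
    // back into whole-word tokens; return_special = false presumably keeps
    // special markers such as [CLS]/[SEP] out of the output.
    let tokens = model.predict(&input, true, false);
    for token in tokens {
        println!("{:?}", token);
    }
    Ok(())
}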
@@ -390,6 +396,61 @@ impl TokenClassificationModel {
            index: position_idx as u16,
            word_index,
            offset: offsets.to_owned(),
+            mask: sentence_tokens.mask[position_idx as usize],
        }
    }
+
+    fn consolidate_tokens(&self, tokens: Vec<Token>) -> Vec<Token> {
+        let mut consolidated_tokens: Vec<Token> = vec![];
+        let mut current_token: Vec<Token> = vec![];
+        for sub_token in tokens.iter() {
+            if (sub_token.mask != Mask::Continuation)
+                && (sub_token.mask != Mask::InexactContinuation)
+            {
+                // A new word starts here: flush the sub-tokens gathered so far.
+                if let Some(token) = Self::merge_group(&current_token) {
+                    consolidated_tokens.push(token);
+                }
+                current_token = vec![sub_token.clone()];
+            } else {
+                current_token.push(sub_token.clone());
+            }
+        }
+        // The loop above only flushes when the next word begins, so the last
+        // word is still pending at this point; flush it as well.
+        if let Some(token) = Self::merge_group(&current_token) {
+            consolidated_tokens.push(token);
+        }
+        consolidated_tokens
+    }
+
+    /// Merges a group of sub-tokens into a single `Token`. An empty group
+    /// yields `None`; a single-element group is returned unchanged.
+    fn merge_group(group: &[Token]) -> Option<Token> {
+        match group {
+            [] => None,
+            [token] => Some(token.clone()),
+            _ => {
+                let first = &group[0];
+                // The consolidated offset runs from the start of the first
+                // sub-token to the end of the last, when both are known.
+                let offset = match (&first.offset, &group.last().unwrap().offset) {
+                    (Some(start), Some(end)) => Some(Offset::new(start.begin, end.end)),
+                    _ => None,
+                };
+                let mut text = String::new();
+                let mut score = 1f64;
+                for sub_token in group {
+                    text.push_str(sub_token.text.as_str());
+                    // Scores are probabilities, so the word-level score is the
+                    // product of its sub-token scores.
+                    score *= sub_token.score;
+                }
+                Some(Token {
+                    text,
+                    score,
+                    label: first.label.clone(),
+                    sentence: first.sentence,
+                    index: first.index,
+                    word_index: first.word_index,
+                    offset,
+                    mask: Default::default(),
+                })
+            }
+        }
+    }
}
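
To make the merge rule concrete, here is a self-contained sketch of the same grouping logic on a simplified stand-in for `Token` (field values invented for illustration; sub-token texts are shown without wordpiece markers):

// A minimal, runnable sketch of the merge rule above, using a simplified
// stand-in for `Token` (the real struct carries more fields).
#[derive(Debug, Clone)]
struct SubToken {
    text: String,
    score: f64,
    is_continuation: bool,
}

fn consolidate(sub_tokens: &[SubToken]) -> Vec<SubToken> {
    let mut words: Vec<SubToken> = vec![];
    for sub_token in sub_tokens {
        match words.last_mut() {
            // A continuation extends the previous word: texts concatenate and
            // scores (treated as probabilities) multiply, mirroring
            // `consolidate_tokens` above.
            Some(last) if sub_token.is_continuation => {
                last.text.push_str(&sub_token.text);
                last.score *= sub_token.score;
            }
            _ => words.push(sub_token.clone()),
        }
    }
    words
}

fn main() {
    let sub_tokens = [
        SubToken { text: "Chong".to_string(), score: 0.99, is_continuation: false },
        SubToken { text: "qing".to_string(), score: 0.98, is_continuation: true },
    ];
    // Prints a single consolidated word: "Chongqing", score 0.99 * 0.98 ≈ 0.9702.
    println!("{:?}", consolidate(&sub_tokens));
}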