Updated entity consolidation logic

This commit is contained in:
Guillaume Becquin 2021-11-22 15:37:34 +01:00
parent 067bac0d55
commit 28fc22c70f

View File

@ -267,84 +267,99 @@ impl NERModel {
let tokens = self.token_classification_model.predict(input, true, false);
let mut entities: Vec<Vec<Entity>> = Vec::new();
for mut sequence_tokens in tokens {
entities.push(Self::consolidate_entities(&mut sequence_tokens));
for sequence_tokens in tokens {
entities.push(Self::consolidate_entities(&sequence_tokens));
}
entities
}
fn consolidate_entities(tokens: &mut Vec<Token>) -> Vec<Entity> {
fn consolidate_entities(tokens: &[Token]) -> Vec<Entity> {
let mut entities: Vec<Entity> = Vec::new();
let mut previous_tag = Tag::Outside;
let mut previous_label = "";
let mut current_tag: Tag;
let mut current_label: &str;
let mut begin_offset = 0;
tokens.push(Token {
text: "X".into(),
score: 1.0,
label: "O-X".to_string(),
label_index: 0,
sentence: 0,
index: 0,
word_index: 0,
offset: None,
mask: Default::default(),
});
let mut entity_builder = EntityBuilder::new();
for (position, token) in tokens.iter().enumerate() {
current_tag = token.get_tag();
current_label = token.get_label();
if (previous_tag == Tag::End)
| (previous_tag == Tag::Single)
| matches!(
(previous_tag, current_tag),
(Tag::Begin, Tag::Begin)
| (Tag::Begin, Tag::Outside)
| (Tag::Begin, Tag::Single)
| (Tag::Inside, Tag::Begin)
| (Tag::Inside, Tag::Outside)
| (Tag::Inside, Tag::Single)
)
| ((previous_label != current_label) & (previous_tag != Tag::Outside))
{
let entity_tokens = &tokens[begin_offset..position];
entities.push(Entity {
word: entity_tokens
.iter()
.map(|token| token.text.as_str())
.collect::<Vec<&str>>()
.join(" "),
score: entity_tokens.iter().map(|token| token.score).product(),
label: previous_label.to_string(),
})
let tag = token.get_tag();
let label = token.get_label();
if let Some(entity) = entity_builder.handle_current_tag(tag, label, position, tokens) {
entities.push(entity)
}
if (current_tag == Tag::Begin)
| (current_tag == Tag::Single)
| matches!(
(previous_tag, current_tag),
(Tag::End, Tag::End)
| (Tag::Single, Tag::End)
| (Tag::Outside, Tag::End)
| (Tag::End, Tag::Inside)
| (Tag::Single, Tag::Inside)
| (Tag::Outside, Tag::Inside)
)
| ((previous_label != current_label) & (previous_tag != Tag::Outside))
{
begin_offset = position;
};
previous_tag = current_tag;
previous_label = current_label;
}
if let Some(entity) = entity_builder.flush_and_reset(tokens.len(), tokens) {
entities.push(entity);
}
entities
}
}
struct EntityBuilder<'a> {
previous_node: Option<(usize, Tag, &'a str)>,
}
impl<'a> EntityBuilder<'a> {
fn new() -> Self {
EntityBuilder {
previous_node: None,
}
}
fn handle_current_tag(
&mut self,
tag: Tag,
label: &'a str,
position: usize,
tokens: &[Token],
) -> Option<Entity> {
match tag {
Tag::Outside => self.flush_and_reset(position, tokens),
Tag::Begin | Tag::Single => {
let entity = self.flush_and_reset(position, tokens);
self.start_new(position, tag, label);
entity
}
Tag::Inside | Tag::End => {
if let Some((_, previous_tag, previous_label)) = self.previous_node {
if (previous_tag == Tag::End)
| (previous_tag == Tag::Single)
| (previous_label != label)
{
let entity = self.flush_and_reset(position, tokens);
self.start_new(position, tag, label);
entity
} else {
None
}
} else {
self.start_new(position, tag, label);
None
}
}
}
}
fn flush_and_reset(&mut self, position: usize, tokens: &[Token]) -> Option<Entity> {
let entity = if let Some((start, _, label)) = self.previous_node {
let entity_tokens = &tokens[start..position];
Some(Entity {
word: entity_tokens
.iter()
.map(|token| token.text.as_str())
.collect::<Vec<&str>>()
.join(" "),
score: entity_tokens.iter().map(|token| token.score).product(),
label: label.to_string(),
})
} else {
None
};
self.previous_node = None;
entity
}
fn start_new(&mut self, position: usize, tag: Tag, label: &'a str) {
self.previous_node = Some((position, tag, label))
}
}
#[derive(Debug, Clone, Copy, PartialEq)]
enum Tag {
Begin,