mirror of
https://github.com/guillaume-be/rust-bert.git
synced 2024-09-11 12:55:34 +03:00
Updated entity consolidation logic
This commit is contained in:
parent
067bac0d55
commit
28fc22c70f
@ -267,84 +267,99 @@ impl NERModel {
|
||||
let tokens = self.token_classification_model.predict(input, true, false);
|
||||
let mut entities: Vec<Vec<Entity>> = Vec::new();
|
||||
|
||||
for mut sequence_tokens in tokens {
|
||||
entities.push(Self::consolidate_entities(&mut sequence_tokens));
|
||||
for sequence_tokens in tokens {
|
||||
entities.push(Self::consolidate_entities(&sequence_tokens));
|
||||
}
|
||||
entities
|
||||
}
|
||||
|
||||
fn consolidate_entities(tokens: &mut Vec<Token>) -> Vec<Entity> {
|
||||
fn consolidate_entities(tokens: &[Token]) -> Vec<Entity> {
|
||||
let mut entities: Vec<Entity> = Vec::new();
|
||||
|
||||
let mut previous_tag = Tag::Outside;
|
||||
let mut previous_label = "";
|
||||
let mut current_tag: Tag;
|
||||
let mut current_label: &str;
|
||||
let mut begin_offset = 0;
|
||||
|
||||
tokens.push(Token {
|
||||
text: "X".into(),
|
||||
score: 1.0,
|
||||
label: "O-X".to_string(),
|
||||
label_index: 0,
|
||||
sentence: 0,
|
||||
index: 0,
|
||||
word_index: 0,
|
||||
offset: None,
|
||||
mask: Default::default(),
|
||||
});
|
||||
|
||||
let mut entity_builder = EntityBuilder::new();
|
||||
for (position, token) in tokens.iter().enumerate() {
|
||||
current_tag = token.get_tag();
|
||||
current_label = token.get_label();
|
||||
|
||||
if (previous_tag == Tag::End)
|
||||
| (previous_tag == Tag::Single)
|
||||
| matches!(
|
||||
(previous_tag, current_tag),
|
||||
(Tag::Begin, Tag::Begin)
|
||||
| (Tag::Begin, Tag::Outside)
|
||||
| (Tag::Begin, Tag::Single)
|
||||
| (Tag::Inside, Tag::Begin)
|
||||
| (Tag::Inside, Tag::Outside)
|
||||
| (Tag::Inside, Tag::Single)
|
||||
)
|
||||
| ((previous_label != current_label) & (previous_tag != Tag::Outside))
|
||||
{
|
||||
let entity_tokens = &tokens[begin_offset..position];
|
||||
entities.push(Entity {
|
||||
word: entity_tokens
|
||||
.iter()
|
||||
.map(|token| token.text.as_str())
|
||||
.collect::<Vec<&str>>()
|
||||
.join(" "),
|
||||
score: entity_tokens.iter().map(|token| token.score).product(),
|
||||
label: previous_label.to_string(),
|
||||
})
|
||||
let tag = token.get_tag();
|
||||
let label = token.get_label();
|
||||
if let Some(entity) = entity_builder.handle_current_tag(tag, label, position, tokens) {
|
||||
entities.push(entity)
|
||||
}
|
||||
|
||||
if (current_tag == Tag::Begin)
|
||||
| (current_tag == Tag::Single)
|
||||
| matches!(
|
||||
(previous_tag, current_tag),
|
||||
(Tag::End, Tag::End)
|
||||
| (Tag::Single, Tag::End)
|
||||
| (Tag::Outside, Tag::End)
|
||||
| (Tag::End, Tag::Inside)
|
||||
| (Tag::Single, Tag::Inside)
|
||||
| (Tag::Outside, Tag::Inside)
|
||||
)
|
||||
| ((previous_label != current_label) & (previous_tag != Tag::Outside))
|
||||
{
|
||||
begin_offset = position;
|
||||
};
|
||||
previous_tag = current_tag;
|
||||
previous_label = current_label;
|
||||
}
|
||||
if let Some(entity) = entity_builder.flush_and_reset(tokens.len(), tokens) {
|
||||
entities.push(entity);
|
||||
}
|
||||
entities
|
||||
}
|
||||
}
|
||||
|
||||
struct EntityBuilder<'a> {
|
||||
previous_node: Option<(usize, Tag, &'a str)>,
|
||||
}
|
||||
|
||||
impl<'a> EntityBuilder<'a> {
|
||||
fn new() -> Self {
|
||||
EntityBuilder {
|
||||
previous_node: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_current_tag(
|
||||
&mut self,
|
||||
tag: Tag,
|
||||
label: &'a str,
|
||||
position: usize,
|
||||
tokens: &[Token],
|
||||
) -> Option<Entity> {
|
||||
match tag {
|
||||
Tag::Outside => self.flush_and_reset(position, tokens),
|
||||
Tag::Begin | Tag::Single => {
|
||||
let entity = self.flush_and_reset(position, tokens);
|
||||
self.start_new(position, tag, label);
|
||||
entity
|
||||
}
|
||||
Tag::Inside | Tag::End => {
|
||||
if let Some((_, previous_tag, previous_label)) = self.previous_node {
|
||||
if (previous_tag == Tag::End)
|
||||
| (previous_tag == Tag::Single)
|
||||
| (previous_label != label)
|
||||
{
|
||||
let entity = self.flush_and_reset(position, tokens);
|
||||
self.start_new(position, tag, label);
|
||||
entity
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
self.start_new(position, tag, label);
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn flush_and_reset(&mut self, position: usize, tokens: &[Token]) -> Option<Entity> {
|
||||
let entity = if let Some((start, _, label)) = self.previous_node {
|
||||
let entity_tokens = &tokens[start..position];
|
||||
Some(Entity {
|
||||
word: entity_tokens
|
||||
.iter()
|
||||
.map(|token| token.text.as_str())
|
||||
.collect::<Vec<&str>>()
|
||||
.join(" "),
|
||||
score: entity_tokens.iter().map(|token| token.score).product(),
|
||||
label: label.to_string(),
|
||||
})
|
||||
} else {
|
||||
None
|
||||
};
|
||||
self.previous_node = None;
|
||||
entity
|
||||
}
|
||||
|
||||
fn start_new(&mut self, position: usize, tag: Tag, label: &'a str) {
|
||||
self.previous_node = Some((position, tag, label))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq)]
|
||||
enum Tag {
|
||||
Begin,
|
||||
|
Loading…
Reference in New Issue
Block a user