Updated offsets fixing overlapping spans

guillaume-be 2021-08-18 09:07:23 +02:00
parent 9cadc5d15f
commit 3ff5199376
2 changed files with 56 additions and 17 deletions
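
Background for the fix: long inputs are split into spans that overlap by doc_stride tokens, and each token's prediction should be reported by exactly one span (its "reference feature"). The removed code placed the end cutoff of a non-final span at doc_stride / 2 and did not account for special tokens when counting positions, so the kept ranges of neighbouring spans did not line up. The new get_reference_feature_flag places both cutoffs halfway through each overlap, measured past the leading/trailing special tokens. Below is a minimal standalone sketch of that logic; the names are illustrative only, and a plain is_special: &[bool] mask stands in for the rust_tokenizers Mask::Special values used in the real code.

// Illustrative sketch, not the crate's API: mirrors the cutoff logic of
// the new get_reference_feature_flag.
fn reference_flags(
    start_token: usize,
    end_token: usize,
    total_length: usize,
    doc_stride: usize,
    is_special: &[bool],
) -> Vec<bool> {
    // Count special tokens padding the start/end of this span (e.g. [CLS]/[SEP]).
    let leading = is_special.iter().take_while(|&&s| s).count();
    let trailing = is_special.iter().rev().take_while(|&&s| s).count();
    // Unless this is the first span, drop the first half of the overlap;
    // unless it is the last span, drop the last half of the overlap.
    let start_cutoff = if start_token > 0 {
        doc_stride / 2 + leading
    } else {
        0
    };
    let end_cutoff = if end_token < total_length {
        is_special.len() - doc_stride / 2 - trailing
    } else {
        is_special.len()
    };
    let mut flags = vec![true; is_special.len()];
    flags[..start_cutoff].iter_mut().for_each(|v| *v = false);
    flags[end_cutoff..].iter_mut().for_each(|v| *v = false);
    flags
}

fn main() {
    // Spans of 8 content tokens with doc_stride = 4 and one special token
    // at each end of the span.
    let mask = [
        true, false, false, false, false, false, false, false, false, true,
    ];
    let first = reference_flags(0, 8, 14, 4, &mask); // first span of a document
    let middle = reference_flags(8, 16, 20, 4, &mask); // span with neighbours on both sides
    // The first span keeps everything except its trailing doc_stride / 2
    // tokens plus the trailing special token; a middle span additionally
    // drops its leading special token plus doc_stride / 2 tokens, so each
    // overlapping position is referenced exactly once.
    println!("{:?}", first); // last 3 flags false
    println!("{:?}", middle); // first 3 and last 3 flags false
}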


@@ -23,8 +23,8 @@ fn main() -> anyhow::Result<()> {
     // Run model
     let output = pos_model.predict(&input);
-    for pos_tag in output {
-        println!("{:?}", pos_tag);
+    for (pos, pos_tag) in output[0].iter().enumerate() {
+        println!("{} - {:?}", pos, pos_tag);
     }
     Ok(())


@@ -126,6 +126,7 @@ use crate::xlnet::XLNetForTokenClassification;
 use rust_tokenizers::tokenizer::Tokenizer;
 use rust_tokenizers::{
     ConsolidatableTokens, ConsolidatedTokenIterator, Mask, Offset, TokenIdsWithOffsets, TokenTrait,
+    TokenizedInput,
 };
 use serde::{Deserialize, Serialize};
 use std::borrow::Borrow;
@@ -686,20 +687,13 @@ impl TokenClassificationModel {
                 .tokenizer
                 .build_input_with_special_tokens(sub_encoded_input, None);
-            // set halfway through the doc_stride to be false if the feature is not the first/last
-            let start_cutoff = if start_token > 0 { doc_stride / 2 } else { 0 };
-            let end_cutoff = if end_token < total_length {
-                doc_stride / 2
-            } else {
-                encoded_span.token_ids.len()
-            };
-            let mut reference_feature = vec![true; encoded_span.token_ids.len()];
-            reference_feature[..start_cutoff]
-                .iter_mut()
-                .for_each(|v| *v = false);
-            reference_feature[end_cutoff..]
-                .iter_mut()
-                .for_each(|v| *v = false);
+            let reference_feature = self.get_reference_feature_flag(
+                start_token,
+                end_token,
+                total_length,
+                doc_stride,
+                &encoded_span,
+            );
 
             let feature = InputFeature {
                 input_ids: encoded_span.token_ids,
@@ -717,6 +711,51 @@ impl TokenClassificationModel {
         spans
     }
 
+    fn get_reference_feature_flag(
+        &self,
+        start_token: usize,
+        end_token: usize,
+        total_length: usize,
+        doc_stride: usize,
+        encoded_span: &TokenizedInput,
+    ) -> Vec<bool> {
+        // set halfway through the doc_stride to be false if the feature is not the first/last
+        let start_cutoff = if start_token > 0 {
+            let leading_special_tokens = {
+                let mut counter = 0;
+                let mut masks = encoded_span.mask.iter();
+                while masks.next().unwrap_or(&Mask::None) == &Mask::Special {
+                    counter += 1;
+                }
+                counter
+            };
+            doc_stride / 2 + leading_special_tokens
+        } else {
+            0
+        };
+        let end_cutoff = if end_token < total_length {
+            let trailing_special_tokens = {
+                let mut counter = 0;
+                let mut masks = encoded_span.mask.iter().rev();
+                while masks.next().unwrap_or(&Mask::None) == &Mask::Special {
+                    counter += 1;
+                }
+                counter
+            };
+            encoded_span.token_ids.len() - doc_stride / 2 - trailing_special_tokens
+        } else {
+            encoded_span.token_ids.len()
+        };
+        let mut reference_feature = vec![true; encoded_span.token_ids.len()];
+        reference_feature[..start_cutoff]
+            .iter_mut()
+            .for_each(|v| *v = false);
+        reference_feature[end_cutoff..]
+            .iter_mut()
+            .for_each(|v| *v = false);
+        reference_feature
+    }
+
     /// Classify tokens in a text sequence
     ///
     /// # Arguments
@@ -814,7 +853,7 @@ impl TokenClassificationModel {
                         &score,
                         sentence_idx,
                         position_idx as i64,
-                        word_idx - 1,
+                        word_idx,
                     )
                 };
                 example_tokens_map
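
For intuition, here is a hypothetical downstream consumer of the reference flags (the Feature struct and merge function are illustrative, not the crate's actual InputFeature handling): keeping each span's predictions only at its reference positions yields exactly one tag per original token, even where spans overlap.

// Hypothetical merge step: keep each span's predictions only where the
// span is the reference feature for that position.
struct Feature {
    labels: Vec<&'static str>,
    reference: Vec<bool>,
}

fn merge(features: &[Feature]) -> Vec<&'static str> {
    let mut merged = Vec::new();
    for feature in features {
        for (label, keep) in feature.labels.iter().zip(&feature.reference) {
            if *keep {
                merged.push(*label);
            }
        }
    }
    merged
}

fn main() {
    // Document tokens: DT NN VB IN DT NN, split into two spans of four
    // tokens overlapping by two (VB IN).
    let features = [
        Feature {
            labels: vec!["DT", "NN", "VB", "IN"],
            reference: vec![true, true, true, false],
        },
        Feature {
            labels: vec!["VB", "IN", "DT", "NN"],
            reference: vec![false, true, true, true],
        },
    ];
    // Each overlapping token is reported by exactly one span.
    assert_eq!(merge(&features), ["DT", "NN", "VB", "IN", "DT", "NN"]);
}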