mirror of
https://github.com/guillaume-be/rust-bert.git
synced 2024-10-26 14:07:25 +03:00
Updated offsets fixing overlapping spans
This commit is contained in:
parent
9cadc5d15f
commit
3ff5199376
@ -23,8 +23,8 @@ fn main() -> anyhow::Result<()> {
|
||||
|
||||
// Run model
|
||||
let output = pos_model.predict(&input);
|
||||
for pos_tag in output {
|
||||
println!("{:?}", pos_tag);
|
||||
for (pos, pos_tag) in output[0].iter().enumerate() {
|
||||
println!("{} - {:?}", pos, pos_tag);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
@ -126,6 +126,7 @@ use crate::xlnet::XLNetForTokenClassification;
|
||||
use rust_tokenizers::tokenizer::Tokenizer;
|
||||
use rust_tokenizers::{
|
||||
ConsolidatableTokens, ConsolidatedTokenIterator, Mask, Offset, TokenIdsWithOffsets, TokenTrait,
|
||||
TokenizedInput,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::borrow::Borrow;
|
||||
@ -686,20 +687,13 @@ impl TokenClassificationModel {
|
||||
.tokenizer
|
||||
.build_input_with_special_tokens(sub_encoded_input, None);
|
||||
|
||||
// set halfway through the doc_stride to be false if the feature is not the first/last
|
||||
let start_cutoff = if start_token > 0 { doc_stride / 2 } else { 0 };
|
||||
let end_cutoff = if end_token < total_length {
|
||||
doc_stride / 2
|
||||
} else {
|
||||
encoded_span.token_ids.len()
|
||||
};
|
||||
let mut reference_feature = vec![true; encoded_span.token_ids.len()];
|
||||
reference_feature[..start_cutoff]
|
||||
.iter_mut()
|
||||
.for_each(|v| *v = false);
|
||||
reference_feature[end_cutoff..]
|
||||
.iter_mut()
|
||||
.for_each(|v| *v = false);
|
||||
let reference_feature = self.get_reference_feature_flag(
|
||||
start_token,
|
||||
end_token,
|
||||
total_length,
|
||||
doc_stride,
|
||||
&encoded_span,
|
||||
);
|
||||
|
||||
let feature = InputFeature {
|
||||
input_ids: encoded_span.token_ids,
|
||||
@ -717,6 +711,51 @@ impl TokenClassificationModel {
|
||||
spans
|
||||
}
|
||||
|
||||
fn get_reference_feature_flag(
|
||||
&self,
|
||||
start_token: usize,
|
||||
end_token: usize,
|
||||
total_length: usize,
|
||||
doc_stride: usize,
|
||||
encoded_span: &TokenizedInput,
|
||||
) -> Vec<bool> {
|
||||
// set halfway through the doc_stride to be false if the feature is not the first/last
|
||||
let start_cutoff = if start_token > 0 {
|
||||
let leading_special_tokens = {
|
||||
let mut counter = 0;
|
||||
let mut masks = encoded_span.mask.iter();
|
||||
while masks.next().unwrap_or(&Mask::None) == &Mask::Special {
|
||||
counter += 1;
|
||||
}
|
||||
counter
|
||||
};
|
||||
doc_stride / 2 + leading_special_tokens
|
||||
} else {
|
||||
0
|
||||
};
|
||||
let end_cutoff = if end_token < total_length {
|
||||
let trailing_special_tokens = {
|
||||
let mut counter = 0;
|
||||
let mut masks = encoded_span.mask.iter().rev();
|
||||
while masks.next().unwrap_or(&Mask::None) == &Mask::Special {
|
||||
counter += 1;
|
||||
}
|
||||
counter
|
||||
};
|
||||
encoded_span.token_ids.len() - doc_stride / 2 - trailing_special_tokens
|
||||
} else {
|
||||
encoded_span.token_ids.len()
|
||||
};
|
||||
let mut reference_feature = vec![true; encoded_span.token_ids.len()];
|
||||
reference_feature[..start_cutoff]
|
||||
.iter_mut()
|
||||
.for_each(|v| *v = false);
|
||||
reference_feature[end_cutoff..]
|
||||
.iter_mut()
|
||||
.for_each(|v| *v = false);
|
||||
reference_feature
|
||||
}
|
||||
|
||||
/// Classify tokens in a text sequence
|
||||
///
|
||||
/// # Arguments
|
||||
@ -814,7 +853,7 @@ impl TokenClassificationModel {
|
||||
&score,
|
||||
sentence_idx,
|
||||
position_idx as i64,
|
||||
word_idx - 1,
|
||||
word_idx,
|
||||
)
|
||||
};
|
||||
example_tokens_map
|
||||
|
Loading…
Reference in New Issue
Block a user