Removed word_index and continuation for now (and the decoding part responsible for it), will be reintroduced once the tokenizers provide the offset information (guillaume-be/rust-tokenizers#14), issue #29

This commit is contained in:
Maarten van Gompel 2020-05-06 22:39:37 +02:00
parent e044a9cbd9
commit b840cbf57a

View File

@ -44,10 +44,10 @@
//!# use rust_bert::pipelines::token_classification::Token;
//!# let output =
//! [
//! Token { text: String::from("Amy"), score: 0.9986, label: String::from("I-PER"), sentence: 0, index: 0, word_index: Some(3), continuation: false },
//! Token { text: String::from("Paris"), score: 0.9985, label: String::from("I-LOC"), sentence: 0, index: 9 , word_index: Some(8), continuation: false},
//! Token { text: String::from("Paris"), score: 0.9988, label: String::from("I-LOC"), sentence: 1, index: 1, word_index: Some(0), continuation: false},
//! Token { text: String::from("France"), score: 0.9993, label: String::from("I-LOC"), sentence: 1, index: 6, word_index: Some(5), continuation: false},
//! Token { text: String::from("Amy"), score: 0.9986, label: String::from("I-PER"), sentence: 0, index: 0}, //, word_index: Some(3), continuation: false },
//! Token { text: String::from("Paris"), score: 0.9985, label: String::from("I-LOC"), sentence: 0, index: 9}, // , word_index: Some(8), continuation: false},
//! Token { text: String::from("Paris"), score: 0.9988, label: String::from("I-LOC"), sentence: 1, index: 1}, //, word_index: Some(0), continuation: false},
//! Token { text: String::from("France"), score: 0.9993, label: String::from("I-LOC"), sentence: 1, index: 6} //, word_index: Some(5), continuation: false},
//! ]
//!# ;
//! ```
@ -87,12 +87,12 @@ pub struct Token {
/// Token position index
pub index: u16,
/// Word index, relative to the sentence index
pub word_index: Option<u8>,
// Word index, relative to the sentence index
//pub word_index: Option<u8>,
/// Continuation marker: marks this token as a continuation of the previous one
#[serde(default)]
pub continuation: bool,
// Continuation marker: marks this token as a continuation of the previous one
//#[serde(default)]
//pub continuation: bool,
}
/// # Configuration for TokenClassificationModel
@ -405,12 +405,12 @@ impl TokenClassificationModel {
let mut tokens: Vec<Token> = vec!();
for sentence_idx in 0..labels_idx.size()[0] {
let labels = labels_idx.get(sentence_idx);
let mut word_idx: u8 = 0;
//let mut word_idx: u8 = 0;
for position_idx in 0..labels.size()[0] {
let label_id = labels.int64_value(&[position_idx]);
let token = {
let token_id = input_tensor.int64_value(&[sentence_idx, position_idx]);
self.decode_token(token_id, label_id, &score, sentence_idx, position_idx, &mut word_idx)
self.decode_token(token_id, label_id, &score, sentence_idx, position_idx) //, &mut word_idx)
};
if let Some(token) = token {
if !ignore_first_label || label_id != 0 {
@ -422,61 +422,19 @@ impl TokenClassificationModel {
tokens
}
fn decode_token(&self, token_id: i64, label_id: i64, score: &Tensor, sentence_idx: i64, position_idx: i64, word_idx: &mut u8) -> Option<Token> {
let mut text = match self.tokenizer {
fn decode_token(&self, token_id: i64, label_id: i64, score: &Tensor, sentence_idx: i64, position_idx: i64) -> Option<Token> {
//, word_idx: &mut u8) -> Option<Token> {
let text = match self.tokenizer {
TokenizerOption::Bert(ref tokenizer) => Tokenizer::decode(tokenizer, vec!(token_id), false, false),
TokenizerOption::Roberta(ref tokenizer) => Tokenizer::decode(tokenizer, vec!(token_id), false, false),
};
let special_value: bool = match self.tokenizer {
//note: we don't count unk as a special value here, as we still want those in the output
TokenizerOption::Bert(_) => {
text == BertVocab::sep_value() ||
text == BertVocab::mask_value() ||
text == BertVocab::pad_value() ||
text == BertVocab::cls_value()
},
TokenizerOption::Roberta(_) => {
text == RobertaVocab::bos_value() ||
text == RobertaVocab::eos_value() ||
text == RobertaVocab::sep_value() ||
text == RobertaVocab::mask_value() ||
text == RobertaVocab::pad_value() ||
text == RobertaVocab::cls_value()
}
};
let continuation = !special_value && match self.tokenizer {
TokenizerOption::Bert(_) => if text.starts_with("##") && text.len() > 2 {
text.drain(..2); //remove the continuation tokens from the text
true
} else {
false
},
TokenizerOption::Roberta(_) => if text.starts_with(" ") {
text.drain(..1); //remove the leading space from the text
false
} else {
match text.find(|c: char| { !c.is_alphabetic() }) {
Some(0) => false,
_ => true,
}
}
};
if !continuation && !special_value && !text.is_empty() {
*word_idx += 1;
}
if special_value {
None
} else {
Some(Token {
text: text,
score: score.double_value(&[sentence_idx, position_idx, label_id]),
label: self.label_mapping.get(&label_id).expect("Index out of vocabulary bounds.").to_owned(),
sentence: sentence_idx as usize,
index: position_idx as u16,
word_index: Some(*word_idx - 1), //0 indexed
continuation: continuation,
})
}
Some(Token {
text: text,
score: score.double_value(&[sentence_idx, position_idx, label_id]),
label: self.label_mapping.get(&label_id).expect("Index out of vocabulary bounds.").to_owned(),
sentence: sentence_idx as usize,
index: position_idx as u16,
})
}
}