mirror of
https://github.com/guillaume-be/rust-bert.git
synced 2024-09-20 00:57:43 +03:00
Removed word_index and continuation for now (along with the decoding logic that computed them). They will be reintroduced once the tokenizers provide offset information (guillaume-be/rust-tokenizers#14); see issue #29.
This commit is contained in:
parent
e044a9cbd9
commit
b840cbf57a
@ -44,10 +44,10 @@
|
||||
//!# use rust_bert::pipelines::token_classification::Token;
|
||||
//!# let output =
|
||||
//! [
|
||||
//! Token { text: String::from("Amy"), score: 0.9986, label: String::from("I-PER"), sentence: 0, index: 0, word_index: Some(3), continuation: false },
|
||||
//! Token { text: String::from("Paris"), score: 0.9985, label: String::from("I-LOC"), sentence: 0, index: 9 , word_index: Some(8), continuation: false},
|
||||
//! Token { text: String::from("Paris"), score: 0.9988, label: String::from("I-LOC"), sentence: 1, index: 1, word_index: Some(0), continuation: false},
|
||||
//! Token { text: String::from("France"), score: 0.9993, label: String::from("I-LOC"), sentence: 1, index: 6, word_index: Some(5), continuation: false},
|
||||
//! Token { text: String::from("Amy"), score: 0.9986, label: String::from("I-PER"), sentence: 0, index: 0}, //, word_index: Some(3), continuation: false },
|
||||
//! Token { text: String::from("Paris"), score: 0.9985, label: String::from("I-LOC"), sentence: 0, index: 9}, // , word_index: Some(8), continuation: false},
|
||||
//! Token { text: String::from("Paris"), score: 0.9988, label: String::from("I-LOC"), sentence: 1, index: 1}, //, word_index: Some(0), continuation: false},
|
||||
//! Token { text: String::from("France"), score: 0.9993, label: String::from("I-LOC"), sentence: 1, index: 6} //, word_index: Some(5), continuation: false},
|
||||
//! ]
|
||||
//!# ;
|
||||
//! ```
|
||||
@ -87,12 +87,12 @@ pub struct Token {
|
||||
/// Token position index
|
||||
pub index: u16,
|
||||
|
||||
/// Word index, relative to the sentence index
|
||||
pub word_index: Option<u8>,
|
||||
// Word index, relative to the sentence index
|
||||
//pub word_index: Option<u8>,
|
||||
|
||||
/// Continuation marker: marks this token as a continuation of the previous one
|
||||
#[serde(default)]
|
||||
pub continuation: bool,
|
||||
// Continuation marker: marks this token as a continuation of the previous one
|
||||
//#[serde(default)]
|
||||
//pub continuation: bool,
|
||||
}
|
||||
|
||||
/// # Configuration for TokenClassificationModel
|
||||
@ -405,12 +405,12 @@ impl TokenClassificationModel {
|
||||
let mut tokens: Vec<Token> = vec!();
|
||||
for sentence_idx in 0..labels_idx.size()[0] {
|
||||
let labels = labels_idx.get(sentence_idx);
|
||||
let mut word_idx: u8 = 0;
|
||||
//let mut word_idx: u8 = 0;
|
||||
for position_idx in 0..labels.size()[0] {
|
||||
let label_id = labels.int64_value(&[position_idx]);
|
||||
let token = {
|
||||
let token_id = input_tensor.int64_value(&[sentence_idx, position_idx]);
|
||||
self.decode_token(token_id, label_id, &score, sentence_idx, position_idx, &mut word_idx)
|
||||
self.decode_token(token_id, label_id, &score, sentence_idx, position_idx) //, &mut word_idx)
|
||||
};
|
||||
if let Some(token) = token {
|
||||
if !ignore_first_label || label_id != 0 {
|
||||
@ -422,61 +422,19 @@ impl TokenClassificationModel {
|
||||
tokens
|
||||
}
|
||||
|
||||
fn decode_token(&self, token_id: i64, label_id: i64, score: &Tensor, sentence_idx: i64, position_idx: i64, word_idx: &mut u8) -> Option<Token> {
|
||||
let mut text = match self.tokenizer {
|
||||
fn decode_token(&self, token_id: i64, label_id: i64, score: &Tensor, sentence_idx: i64, position_idx: i64) -> Option<Token> {
|
||||
//, word_idx: &mut u8) -> Option<Token> {
|
||||
let text = match self.tokenizer {
|
||||
TokenizerOption::Bert(ref tokenizer) => Tokenizer::decode(tokenizer, vec!(token_id), false, false),
|
||||
TokenizerOption::Roberta(ref tokenizer) => Tokenizer::decode(tokenizer, vec!(token_id), false, false),
|
||||
};
|
||||
let special_value: bool = match self.tokenizer {
|
||||
//note: we don't count unk as a special value here, as we still want those in the output
|
||||
TokenizerOption::Bert(_) => {
|
||||
text == BertVocab::sep_value() ||
|
||||
text == BertVocab::mask_value() ||
|
||||
text == BertVocab::pad_value() ||
|
||||
text == BertVocab::cls_value()
|
||||
},
|
||||
TokenizerOption::Roberta(_) => {
|
||||
text == RobertaVocab::bos_value() ||
|
||||
text == RobertaVocab::eos_value() ||
|
||||
text == RobertaVocab::sep_value() ||
|
||||
text == RobertaVocab::mask_value() ||
|
||||
text == RobertaVocab::pad_value() ||
|
||||
text == RobertaVocab::cls_value()
|
||||
}
|
||||
};
|
||||
let continuation = !special_value && match self.tokenizer {
|
||||
TokenizerOption::Bert(_) => if text.starts_with("##") && text.len() > 2 {
|
||||
text.drain(..2); //remove the continuation tokens from the text
|
||||
true
|
||||
} else {
|
||||
false
|
||||
},
|
||||
TokenizerOption::Roberta(_) => if text.starts_with(" ") {
|
||||
text.drain(..1); //remove the leading space from the text
|
||||
false
|
||||
} else {
|
||||
match text.find(|c: char| { !c.is_alphabetic() }) {
|
||||
Some(0) => false,
|
||||
_ => true,
|
||||
}
|
||||
}
|
||||
};
|
||||
if !continuation && !special_value && !text.is_empty() {
|
||||
*word_idx += 1;
|
||||
}
|
||||
|
||||
if special_value {
|
||||
None
|
||||
} else {
|
||||
Some(Token {
|
||||
text: text,
|
||||
score: score.double_value(&[sentence_idx, position_idx, label_id]),
|
||||
label: self.label_mapping.get(&label_id).expect("Index out of vocabulary bounds.").to_owned(),
|
||||
sentence: sentence_idx as usize,
|
||||
index: position_idx as u16,
|
||||
word_index: Some(*word_idx - 1), //0 indexed
|
||||
continuation: continuation,
|
||||
})
|
||||
}
|
||||
Some(Token {
|
||||
text: text,
|
||||
score: score.double_value(&[sentence_idx, position_idx, label_id]),
|
||||
label: self.label_mapping.get(&label_id).expect("Index out of vocabulary bounds.").to_owned(),
|
||||
sentence: sentence_idx as usize,
|
||||
index: position_idx as u16,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user