Updated token classification pipeline to use the new tokenization features (masks, offsets) and reference the original text

Guillaume B 2020-05-10 11:38:28 +02:00
parent 705489169f
commit 9d3a944051
3 changed files with 62 additions and 35 deletions

View File

@@ -20,8 +20,8 @@ fn main() -> failure::Fallible<()> {
     // Define input
     let input = [
-        "My name is Amy. I live in Paris.",
-        "Paris is a city in France."
+        "My name is Amélie. I live in 北京市.",
+        "Chongqing is a city in China."
     ];
     // Run model

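For context, here is roughly how the full example drives the pipeline after this change (a minimal sketch: the `main` scaffolding and the `Default::default()` construction are assumed from the crate's other examples, not shown in this diff):

use rust_bert::pipelines::ner::NERModel;

fn main() -> failure::Fallible<()> {
    // Assumed setup: build the NER model from its default configuration
    let ner_model = NERModel::new(Default::default())?;

    // Define input: accented and CJK text exercises the new
    // offset-based reference into the original string
    let input = [
        "My name is Amélie. I live in 北京市.",
        "Chongqing is a city in China.",
    ];

    // Run model: entities now carry text sliced from the original input
    let entities = ner_model.predict(&input);
    for entity in entities {
        println!("{:?}", entity);
    }
    Ok(())
}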
View File

@@ -45,7 +45,7 @@
 //!# ;
 //! ```
-use crate::pipelines::token_classification::{TokenClassificationModel,TokenClassificationConfig};
+use crate::pipelines::token_classification::{TokenClassificationModel, TokenClassificationConfig};

 #[derive(Debug)]
@@ -58,6 +58,7 @@ pub struct Entity {
     /// Entity label (e.g. ORG, LOC...)
     pub label: String,
 }

+// Type alias for backward compatibility
 type NERConfig = TokenClassificationConfig;
@@ -86,7 +87,7 @@ impl NERModel {
     ///
     pub fn new(ner_config: NERConfig) -> failure::Fallible<NERModel> {
         let model = TokenClassificationModel::new(ner_config)?;
-        Ok(NERModel { token_classification_model: model})
+        Ok(NERModel { token_classification_model: model })
     }

     /// Extract entities from a text
@@ -116,12 +117,16 @@ impl NERModel {
     /// ```
     ///
     pub fn predict(&self, input: &[&str]) -> Vec<Entity> {
-        self.token_classification_model.predict(input, true).into_iter().map(|token| {
-            Entity {
-                word: token.text,
-                score: token.score,
-                label: token.label,
-            }
-        }).collect()
+        self.token_classification_model
+            .predict(input, true, false)
+            .into_iter()
+            .filter(|token| token.label != "O")
+            .map(|token| {
+                Entity {
+                    word: token.text,
+                    score: token.score,
+                    label: token.label,
+                }
+            }).collect()
     }
 }

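Note the two behavioral changes in `predict`: the underlying `TokenClassificationModel::predict` now takes three arguments, and the wrapper filters out tokens labelled "O" (outside any entity), so callers receive only actual entities. Continuing the example sketch above, consumption stays simple (field names come from the `Entity` struct in this file; the printed values are illustrative):

// Only entity-labelled tokens reach this loop; the "O" label is
// filtered inside NERModel::predict
for entity in ner_model.predict(&input) {
    // e.g. "Amélie (I-PER, 0.99)" or "北京市 (I-LOC, 0.97)" - illustrative values
    println!("{} ({}, {:.2})", entity.word, entity.label, entity.score);
}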
View File

@@ -55,7 +55,7 @@
 //! ```
 use tch::nn::VarStore;
-use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::{Tokenizer, TokenizedInput, TruncationStrategy};
+use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::{Tokenizer, TokenizedInput, TruncationStrategy, Mask, Offset};
 use std::collections::HashMap;
 use tch::{Tensor, no_grad, Device};
 use tch::kind::Kind::Float;
@@ -63,12 +63,12 @@ use crate::bert::{BertForTokenClassification, BertModelResources, BertConfigReso
 use crate::roberta::RobertaForTokenClassification;
 use crate::distilbert::DistilBertForTokenClassification;
 use crate::common::resources::{Resource, RemoteResource, download_resource};
-use serde::{Serialize, Deserialize};
 use crate::pipelines::common::{ModelType, ConfigOption, TokenizerOption};
 use crate::electra::ElectraForTokenClassification;
+use itertools::Itertools;

-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Debug)]
 /// # Token generated by a `TokenClassificationModel`
 pub struct Token {
     /// String representation of the Token
@@ -81,11 +81,16 @@ pub struct Token {
     pub label: String,
     /// Sentence index
-    #[serde(default)]
     pub sentence: usize,
     /// Token position index
     pub index: u16,
+    /// Token word position index
+    pub word_index: u16,
+    /// Token offsets
+    pub offset: Option<Offset>,
 }

 /// # Configuration for TokenClassificationModel
@@ -277,7 +282,7 @@ impl TokenClassificationModel {
         Ok(TokenClassificationModel { tokenizer, token_sequence_classifier, label_mapping, var_store })
     }

-    fn prepare_for_model(&self, input: Vec<&str>) -> Tensor {
+    fn prepare_for_model(&self, input: Vec<&str>) -> (Vec<TokenizedInput>, Tensor) {
         let tokenized_input: Vec<TokenizedInput> = self.tokenizer.encode_list(input.to_vec(),
                                                                               128,
                                                                               &TruncationStrategy::LongestFirst,
@@ -293,7 +298,7 @@ impl TokenClassificationModel {
             map(|input|
                 Tensor::of_slice(&(input))).
             collect::<Vec<_>>();
-        Tensor::stack(tokenized_input_tensors.as_slice(), 0).to(self.var_store.device())
+        (tokenized_input, Tensor::stack(tokenized_input_tensors.as_slice(), 0).to(self.var_store.device()))
     }
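`prepare_for_model` now returns the `TokenizedInput`s alongside the stacked tensor because the decoding step below needs the per-token masks and character offsets. For reference, a simplified sketch of the `rust_tokenizers` types involved (a field subset assumed from their usage in this diff; the real structs and enum carry more):

// Simplified sketch of the rust_tokenizers types used here (field subset).
pub struct Offset {
    pub begin: u32, // first character (not byte) of the token in the original text
    pub end: u32,   // one past the last character
}

pub enum Mask {
    Special,             // inserted tokens such as [CLS]/[SEP]
    Continuation,        // sub-word continuation piece
    InexactContinuation, // continuation without an exact character mapping
    // ... plus variants for regular tokens
}

pub struct TokenizedInput {
    pub token_ids: Vec<i64>,                // vocabulary indices fed to the model
    pub mask: Vec<Mask>,                    // per-token kind, parallel to token_ids
    pub token_offsets: Vec<Option<Offset>>, // None when no original span exists
}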
/// Extract entities from a text
@@ -321,8 +326,8 @@ impl TokenClassificationModel {
     ///# Ok(())
     ///# }
     /// ```
-    pub fn predict(&self, input: &[&str], ignore_first_label: bool) -> Vec<Token> {
-        let input_tensor = self.prepare_for_model(input.to_vec());
+    pub fn predict(&self, input: &[&str], return_tokens: bool, return_special: bool) -> Vec<Token> {
+        let (tokenized_input, input_tensor) = self.prepare_for_model(input.to_vec());
         let (output, _, _) = no_grad(|| {
             self.token_sequence_classifier
                 .forward_t(Some(input_tensor.copy()),
@@ -335,39 +340,56 @@ impl TokenClassificationModel {
         let output = output.detach().to(Device::Cpu);
         let score: Tensor = output.exp() / output.exp().sum1(&[-1], true, Float);
         let labels_idx = &score.argmax(-1, true);
         let mut tokens: Vec<Token> = vec!();
         for sentence_idx in 0..labels_idx.size()[0] {
             let labels = labels_idx.get(sentence_idx);
-            //let mut word_idx: u8 = 0;
-            for position_idx in 0..labels.size()[0] {
-                let label_id = labels.int64_value(&[position_idx]);
+            let sentence_tokens = &tokenized_input[sentence_idx as usize];
+            let original_chars = input[sentence_idx as usize].chars().collect_vec();
+            let mut word_idx: u16 = 0;
+            for position_idx in 0..sentence_tokens.token_ids.len() {
+                let mask = sentence_tokens.mask[position_idx];
+                if (mask == Mask::Special) & (!return_special) {
+                    continue;
+                }
                 let token = {
-                    let token_id = input_tensor.int64_value(&[sentence_idx, position_idx]);
-                    self.decode_token(token_id, label_id, &score, sentence_idx, position_idx) //, &mut word_idx)
+                    self.decode_token(&original_chars, sentence_tokens, &input_tensor, &labels, &score, sentence_idx, position_idx as i64, word_idx)
                 };
-                if let Some(token) = token {
-                    if !ignore_first_label || label_id != 0 {
-                        tokens.push(token);
-                    }
-                }
+                tokens.push(token);
+                if !(mask == Mask::Continuation) && !(mask == Mask::InexactContinuation) {
+                    word_idx += 1;
+                }
             }
         }
         tokens
     }

-    fn decode_token(&self, token_id: i64, label_id: i64, score: &Tensor, sentence_idx: i64, position_idx: i64) -> Option<Token> {
-        let text = match self.tokenizer {
-            TokenizerOption::Bert(ref tokenizer) => Tokenizer::decode(tokenizer, vec!(token_id), false, false),
-            TokenizerOption::Roberta(ref tokenizer) => Tokenizer::decode(tokenizer, vec!(token_id), false, false),
-        };
-        Some(Token {
+    fn decode_token(&self, original_sentence_chars: &Vec<char>, sentence_tokens: &TokenizedInput, input_tensor: &Tensor,
+                    labels: &Tensor, score: &Tensor, sentence_idx: i64, position_idx: i64, word_index: u16) -> Token {
+        let label_id = labels.int64_value(&[position_idx as i64]);
+        let token_id = input_tensor.int64_value(&[sentence_idx, position_idx as i64]);
+        let offsets = &sentence_tokens.token_offsets[position_idx as usize];
+        let text = match offsets {
+            None => match self.tokenizer {
+                TokenizerOption::Bert(ref tokenizer) => Tokenizer::decode(tokenizer, vec!(token_id), false, false),
+                TokenizerOption::Roberta(ref tokenizer) => Tokenizer::decode(tokenizer, vec!(token_id), false, false),
+            },
+            Some(offsets) => {
+                let (start_char, end_char) = (offsets.begin as usize, offsets.end as usize);
+                let text = original_sentence_chars[start_char..end_char].iter().collect();
+                text
+            }
+        };
+        Token {
             text,
             score: score.double_value(&[sentence_idx, position_idx, label_id]),
             label: self.label_mapping.get(&label_id).expect("Index out of vocabulary bounds.").to_owned(),
             sentence: sentence_idx as usize,
             index: position_idx as u16,
-        })
+            word_index,
+            offset: offsets.to_owned(),
+        }
     }
 }
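The heart of the change is in `decode_token`: when the tokenizer reports an `Offset`, the token text is sliced out of the original sentence's characters rather than re-decoded from the vocabulary, which preserves text the tokenizer cannot round-trip (case or accent normalization, out-of-vocabulary CJK, and so on). A standalone sketch of that recovery step, with hand-computed offsets for illustration:

// Minimal sketch of offset-based text recovery, assuming character
// (not byte) offsets as in the Offset usage above.
fn text_for_offset(original: &str, begin: usize, end: usize) -> String {
    original.chars().skip(begin).take(end - begin).collect()
}

fn main() {
    let sentence = "My name is Amélie. I live in 北京市.";
    // Hand-computed character span covering "北京市"
    assert_eq!(text_for_offset(sentence, 29, 32), "北京市");
    // Byte slicing (&sentence[29..32]) would panic instead: "é" and the
    // CJK characters are multi-byte in UTF-8, so 32 is not a char boundary
}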