- reset cache for BART model
- sentiment classifier optimization
- question answering pipeline optimization
Guillaume B 2020-04-06 13:26:50 +02:00
parent 03f642fc68
commit 793118bc94
8 changed files with 178 additions and 72 deletions

View File

@@ -37,4 +37,5 @@ serde = {version = "1.0.104", features = ["derive"]}
failure = "0.1.6"
dirs = "2.0"
itertools = "0.9.0"
-ordered-float = "1.0.2"
+ordered-float = "1.0.2"
+csv = "1.1.3"

View File

@@ -39,7 +39,7 @@ fn main() -> failure::Fallible<()> {
    // Set-up masked LM model
    let device = Device::cuda_if_available();
    let generate_config = GenerateConfig {
-        max_length: 20,
+        max_length: 30,
        do_sample: true,
        num_beams: 5,
        temperature: 1.1,

examples/sst2.rs Normal file
View File

@@ -0,0 +1,64 @@
// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
// Copyright 2019 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
extern crate failure;
extern crate dirs;
use std::path::PathBuf;
use tch::Device;
use failure::err_msg;
use rust_bert::pipelines::sentiment::{SentimentClassifier, ss2_processor};
use std::env;
fn main() -> failure::Fallible<()> {
    // Resources paths
    let mut home: PathBuf = dirs::home_dir().unwrap();
    home.push("rustbert");
    home.push("distilbert_sst2");
    let config_path = &home.as_path().join("config.json");
    let vocab_path = &home.as_path().join("vocab.txt");
    let weights_path = &home.as_path().join("model.ot");

    if !config_path.is_file() | !vocab_path.is_file() | !weights_path.is_file() {
        return Err(
            err_msg("Could not find required resources to run example. \
                     Please run ../utils/download_dependencies_sst2_sentiment.py \
                     in a Python environment with dependencies listed in ../requirements.txt"));
    }

    // Set-up classifier
    let device = Device::cuda_if_available();
    let sentiment_classifier = SentimentClassifier::new(vocab_path,
                                                        config_path,
                                                        weights_path, device)?;

    // Define input
    let mut sst2_path = PathBuf::from(env::var("SST2_PATH")
        .expect("Please set the \"SST2_PATH\" environment variable pointing to the SST-2 dataset folder"));
    sst2_path.push("train.tsv");
    let inputs = ss2_processor(sst2_path).unwrap();

    // Run model
    let batch_size = 64;
    let mut output = vec!();
    for batch in inputs.chunks(batch_size) {
        output.push(sentiment_classifier.predict(batch.iter().map(|v| v.as_str()).collect::<Vec<&str>>().as_slice()));
    }
    let mut flat_outputs = vec!();
    for batch_output in output.iter_mut() {
        flat_outputs.append(batch_output);
    }
    println!("{:?}", flat_outputs.len());
    Ok(())
}
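
As an aside on the batching idiom in this example: the chunk-then-flatten pair of loops can also be written as a single chunks + flat_map chain. A self-contained sketch, with a hypothetical stand-in closure in place of SentimentClassifier::predict:

fn main() {
    let inputs: Vec<String> = (0..10).map(|i| format!("sentence {}", i)).collect();
    let batch_size = 4;
    // Hypothetical stand-in for SentimentClassifier::predict.
    let predict = |batch: &[&str]| -> Vec<usize> { batch.iter().map(|s| s.len()).collect() };
    let flat_outputs: Vec<usize> = inputs
        .chunks(batch_size)
        .flat_map(|chunk| {
            // Borrow each String as &str so the batch matches predict's signature.
            let batch: Vec<&str> = chunk.iter().map(String::as_str).collect();
            predict(&batch)
        })
        .collect();
    println!("{}", flat_outputs.len()); // prints 10
}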

View File

@@ -71,7 +71,10 @@ about exoplanets like K2-18b."];
    // Credits: WikiNews, CC BY 2.5 license (https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b)
    let output = summarization_model.summarize(&input);
    for sentence in output {
        println!("{:?}", sentence);
    }

+    let output = summarization_model.summarize(&input);
+    for sentence in output {
+        println!("{:?}", sentence);
+    }

View File

@@ -39,6 +39,12 @@ impl LayerState {
            self.prev_key_padding_mask = Some(self.prev_key_padding_mask.as_ref().unwrap().index_select(0, new_indices));
        }
    }

+    pub(crate) fn reset_cache(&mut self) {
+        self.prev_key = None;
+        self.prev_value = None;
+        self.prev_key_padding_mask = None;
+    }
}
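
The new reset_cache clears each cached attention tensor by dropping its Option; the next forward pass then repopulates the state from scratch. A self-contained sketch of the same incremental-decoding cache pattern, with a hypothetical Cache type (plain Vecs in place of tensors) standing in for LayerState:

// Hypothetical Cache type standing in for LayerState.
struct Cache {
    prev_key: Option<Vec<f32>>,
    prev_value: Option<Vec<f32>>,
    prev_key_padding_mask: Option<Vec<bool>>,
}

impl Cache {
    // Dropping the cached values forces the next forward pass to start clean,
    // preventing state from one generated sequence leaking into the next.
    fn reset_cache(&mut self) {
        self.prev_key = None;
        self.prev_value = None;
        self.prev_key_padding_mask = None;
    }
}

fn main() {
    let mut cache = Cache {
        prev_key: Some(vec![0.1, 0.2]),
        prev_value: Some(vec![0.3]),
        prev_key_padding_mask: Some(vec![false, true]),
    };
    cache.reset_cache();
    assert!(cache.prev_key.is_none() && cache.prev_value.is_none() && cache.prev_key_padding_mask.is_none());
}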

View File

@@ -505,6 +505,13 @@ impl PrivateLanguageGenerator<BartForConditionalGeneration, RobertaVocab, RobertaTokenizer>
        };
        (None, encoder_outputs)
    }

+    fn reset_cache(&mut self) {
+        for layer in self.get_model().get_base_model().get_decoder().get_layers() {
+            layer.get_self_attention().prev_state.as_mut().unwrap().reset_cache();
+            layer.get_encoder_attention().prev_state.as_mut().unwrap().reset_cache();
+        };
+    }
}
impl LanguageGenerator<BartForConditionalGeneration, RobertaVocab, RobertaTokenizer> for BartGenerator {}
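
This override pairs with the default no-op reset_cache added to the private generation trait in the next hunk: a standard Rust pattern where the trait supplies an empty default and only stateful implementors override it. A self-contained sketch of that design, with hypothetical names:

// Hypothetical trait mirroring the reset_cache hook.
trait Generator {
    // Default is a no-op: stateless generators need no cache handling.
    fn reset_cache(&mut self) {}
}

struct GreedyGenerator; // stateless: inherits the no-op default
impl Generator for GreedyGenerator {}

struct CachedGenerator {
    cache: Option<Vec<f32>>,
}

impl Generator for CachedGenerator {
    fn reset_cache(&mut self) {
        self.cache = None; // drop state accumulated during the previous run
    }
}

fn main() {
    let mut stateless = GreedyGenerator;
    stateless.reset_cache(); // harmless no-op

    let mut cached = CachedGenerator { cache: Some(vec![1.0]) };
    cached.reset_cache();
    assert!(cached.cache.is_none());
}
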
@@ -1013,6 +1020,8 @@ mod private_generation_utils {
                None => (None, None)
            }
        }

+        fn reset_cache(&mut self) {}
    }
}
@@ -1165,7 +1174,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>: PrivateL
            (input_ids, attention_mask)
        };
+        self.reset_cache();
        let decoded = no_grad(|| {
            if num_beams > 1 {
                self.generate_beam_search(input_ids, encoder_outputs, cur_len, min_length as i64, max_length as i64, do_sample, early_stopping, temperature, top_k as i64, top_p, repetition_penalty,

View File

@@ -311,63 +311,66 @@ impl QuestionAnsweringModel {
        for (start, end) in batch_indices {
            let batch_features = &features[start..end];
-            let mut input_ids: Vec<Tensor> = vec!();
-            let mut attention_masks: Vec<Tensor> = vec!();
-            for feature in batch_features {
-                input_ids.push(Tensor::of_slice(&feature.input_ids));
-                attention_masks.push(Tensor::of_slice(&feature.attention_mask));
-            }
-            let input_ids = Tensor::stack(&input_ids, 0).to(self.var_store.device());
-            let attention_masks = Tensor::stack(&attention_masks, 0).to(self.var_store.device());
-            let (start_logits, end_logits, _, _) = no_grad(|| {
-                self.distilbert_qa.forward_t(Some(input_ids), Some(attention_masks), None, false).unwrap()
-            });
-            let start_logits = start_logits.to(Device::Cpu);
-            let end_logits = end_logits.to(Device::Cpu);
-            let example_index_to_feature_end_position: Vec<(usize, i64)> = batch_features
-                .iter()
-                .enumerate()
-                .map(|(feature_index, feature)| (feature.example_index as usize, feature_index as i64 + 1))
-                .collect();
-            let mut feature_id_start = 0;
-            for (example_id, max_feature_id) in example_index_to_feature_end_position {
-                let mut answers: Vec<Answer> = vec!();
-                let example = &examples[example_id];
-                for feature_idx in feature_id_start..max_feature_id {
-                    let feature = &batch_features[feature_idx as usize];
-                    let start = start_logits.get(feature_idx);
-                    let end = end_logits.get(feature_idx);
-                    let p_mask = (Tensor::of_slice(&feature.p_mask) - 1).abs();
-                    let start: Tensor = start.exp() / start.exp().sum(Float) * &p_mask;
-                    let end: Tensor = end.exp() / end.exp().sum(Float) * &p_mask;
-                    let (starts, ends, scores) = self.decode(&start, &end, top_k);
-                    for idx in 0..starts.len() {
-                        let start_pos = feature.token_to_orig_map[&starts[idx]] as usize;
-                        let end_pos = feature.token_to_orig_map[&ends[idx]] as usize;
-                        let answer = example.doc_tokens[start_pos..end_pos + 1].join(" ");
-                        let start = example.char_to_word_offset
-                            .iter()
-                            .position(|&v| v as usize == start_pos)
-                            .unwrap();
-                        let end = example.char_to_word_offset
-                            .iter()
-                            .rposition(|&v| v as usize == end_pos)
-                            .unwrap();
-                        answers.push(Answer { score: scores[idx], start, end, answer });
-                    }
-                }
-                feature_id_start = max_feature_id;
-                all_answers.push(answers[..(top_k as usize)].to_vec());
-            }
+            let mut input_ids = Vec::with_capacity(batch_features.len());
+            let mut attention_masks = Vec::with_capacity(batch_features.len());
+            no_grad(|| {
+                for feature in batch_features {
+                    input_ids.push(Tensor::of_slice(&feature.input_ids));
+                    attention_masks.push(Tensor::of_slice(&feature.attention_mask));
+                }
+                let input_ids = Tensor::stack(&input_ids, 0).to(self.var_store.device());
+                let attention_masks = Tensor::stack(&attention_masks, 0).to(self.var_store.device());
+                let (start_logits, end_logits, _, _) = self.distilbert_qa.forward_t(Some(input_ids), Some(attention_masks), None, false).unwrap();
+                let start_logits = start_logits.detach();
+                let end_logits = end_logits.detach();
+                let example_index_to_feature_end_position: Vec<(usize, i64)> = batch_features
+                    .iter()
+                    .enumerate()
+                    .map(|(feature_index, feature)| (feature.example_index as usize, feature_index as i64 + 1))
+                    .collect();
+                let mut feature_id_start = 0;
+                for (example_id, max_feature_id) in example_index_to_feature_end_position {
+                    let mut answers: Vec<Answer> = vec!();
+                    let example = &examples[example_id];
+                    for feature_idx in feature_id_start..max_feature_id {
+                        let feature = &batch_features[feature_idx as usize];
+                        let start = start_logits.get(feature_idx);
+                        let end = end_logits.get(feature_idx);
+                        let p_mask = (Tensor::of_slice(&feature.p_mask) - 1).abs().to_device(start.device());
+                        let start: Tensor = start.exp() / start.exp().sum(Float) * &p_mask;
+                        let end: Tensor = end.exp() / end.exp().sum(Float) * &p_mask;
+                        let (starts, ends, scores) = self.decode(&start, &end, top_k);
+                        for idx in 0..starts.len() {
+                            let start_pos = feature.token_to_orig_map[&starts[idx]] as usize;
+                            let end_pos = feature.token_to_orig_map[&ends[idx]] as usize;
+                            let answer = example.doc_tokens[start_pos..end_pos + 1].join(" ");
+                            let start = example.char_to_word_offset
+                                .iter()
+                                .position(|&v| v as usize == start_pos)
+                                .unwrap();
+                            let end = example.char_to_word_offset
+                                .iter()
+                                .rposition(|&v| v as usize == end_pos)
+                                .unwrap();
+                            answers.push(Answer { score: scores[idx], start, end, answer });
+                        }
+                    }
+                    feature_id_start = max_feature_id;
+                    all_answers.push(answers[..(top_k as usize)].to_vec());
+                }
+            });
        }
        all_answers
    }
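
The gist of this rewrite is that tensor construction, the forward pass, and answer decoding now all run inside one no_grad closure, and the logits are detached instead of being copied to CPU, so no autograd graph is kept alive across batches. A minimal tch sketch of that pattern (toy tensors, not the QA model; the no_grad, detach, and softmax calls mirror those in the diff):

use tch::{no_grad, Kind, Tensor};

fn main() {
    // The whole forward-style computation runs with gradient tracking disabled.
    let logits = no_grad(|| {
        let input = Tensor::of_slice(&[1.0f32, 2.0, 3.0]);
        // Stand-in for a model forward pass.
        let logits = input.exp();
        // detach() keeps the values but severs autograd bookkeeping.
        logits.detach()
    });
    // Post-processing outside the closure needs no graph either.
    let probs = logits.softmax(-1, Kind::Float);
    probs.print();
}
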
@@ -384,11 +387,9 @@ impl QuestionAnsweringModel {
        } else {
            candidates.argsort(0, true).slice(0, 0, top_k, 1)
        };
        let mut start: Vec<i64> = vec!();
        let mut end: Vec<i64> = vec!();
        let mut scores: Vec<f64> = vec!();
        for flat_index_position in 0..idx_sort.size()[0] {
            let flat_index = idx_sort.int64_value(&[flat_index_position]);
            scores.push(candidates.double_value(&[flat_index]));
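
The decode body shown here selects the top-k answer candidates by sorting indices in descending score order and slicing. A toy, self-contained sketch of that argsort + slice idiom (the scores below are invented):

use tch::Tensor;

fn main() {
    let candidates = Tensor::of_slice(&[0.1f64, 0.7, 0.2, 0.9]);
    let top_k = 2;
    // Indices of the scores sorted in descending order, truncated to top_k.
    let idx_sort = candidates.argsort(0, true).slice(0, 0, top_k, 1);
    for flat_index_position in 0..idx_sort.size()[0] {
        let flat_index = idx_sort.int64_value(&[flat_index_position]);
        println!("index {}: score {}", flat_index, candidates.double_value(&[flat_index]));
    }
}
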

View File

@@ -57,13 +57,15 @@
//! ```
use rust_tokenizers::bert_tokenizer::BertTokenizer;
-use std::path::Path;
+use std::path::{Path, PathBuf};
use tch::{Device, Tensor, Kind, no_grad};
use tch::nn::VarStore;
use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::{TruncationStrategy, MultiThreadedTokenizer};
use crate::distilbert::{DistilBertModelClassifier, DistilBertConfig};
use crate::Config;
+use std::fs;
+use serde::Deserialize;
+use std::error::Error;
#[derive(Debug, PartialEq)]
/// Enum with the possible sentiment polarities. Note that the pre-trained SST2 model does not include neutral sentiment.
@@ -188,23 +190,43 @@ impl SentimentClassifier {
    ///
    pub fn predict(&self, input: &[&str]) -> Vec<Sentiment> {
        let input_tensor = self.prepare_for_model(input.to_vec());
-        let (output, _, _) = no_grad(|| {
-            self.distil_bert_classifier
+        let output = no_grad(|| {
+            let (output, _, _) = self.distil_bert_classifier
                .forward_t(Some(input_tensor),
                           None,
                           None,
                           false)
-                .unwrap()
+                .unwrap();
+            output.softmax(-1, Kind::Float).detach().to(Device::Cpu)
        });
-        let output = output.softmax(-1, Kind::Float);
        let mut sentiments: Vec<Sentiment> = vec!();
-        for record_index in 0..output.size()[0] {
-            let mut score = output.double_value(&[record_index, 0]);
-            let polarity = if score < 0.5 {SentimentPolarity::Positive} else {SentimentPolarity::Negative};
-            if &SentimentPolarity::Positive == &polarity {score = 1.0 - score};
-            sentiments.push(Sentiment {polarity, score})
+        let scores = output.select(1, 0).iter::<f64>().unwrap().collect::<Vec<f64>>();
+        for score in scores {
+            let polarity = if score < 0.5 { SentimentPolarity::Positive } else { SentimentPolarity::Negative };
+            let score = if &SentimentPolarity::Positive == &polarity { 1.0 - score } else { score };
+            sentiments.push(Sentiment { polarity, score })
        };
        sentiments
    }
}

+#[derive(Debug, Deserialize)]
+struct Record {
+    sentence: String,
+    label: i8,
+}
+
+pub fn ss2_processor(file_path: PathBuf) -> Result<Vec<String>, Box<dyn Error>> {
+    let file = fs::File::open(file_path).expect("unable to open file");
+    let mut csv = csv::ReaderBuilder::new()
+        .has_headers(true)
+        .delimiter(b'\t')
+        .from_reader(file);
+    let mut records = Vec::new();
+    for result in csv.deserialize() {
+        let record: Record = result?;
+        records.push(record.sentence);
+    }
+    Ok(records)
+}
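
ss2_processor expects the SST-2 train.tsv layout: a header row, then sentence<TAB>label records. A self-contained sketch of the same csv/serde pattern against an in-memory TSV, assuming the csv and serde crates added in the Cargo.toml change above (the sample rows are invented):

use serde::Deserialize;

#[derive(Debug, Deserialize)]
struct Record {
    sentence: String,
    label: i8,
}

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // In-memory stand-in for the SST-2 train.tsv file.
    let data = "sentence\tlabel\na charming film\t1\ndull and lifeless\t0\n";
    let mut reader = csv::ReaderBuilder::new()
        .has_headers(true)
        .delimiter(b'\t')
        .from_reader(data.as_bytes());
    // Deserialize every row, fail on the first malformed record, keep sentences.
    let sentences: Vec<String> = reader
        .deserialize()
        .collect::<Result<Vec<Record>, _>>()?
        .into_iter()
        .map(|r| r.sentence)
        .collect();
    assert_eq!(sentences.len(), 2);
    Ok(())
}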