Added function for preparation to model

This commit is contained in:
Guillaume B 2020-09-02 18:49:30 +02:00
parent 56747f7033
commit 76f82c0e44
4 changed files with 206 additions and 4 deletions

View File

@ -282,6 +282,36 @@ impl TokenizerOption {
}
}
/// Interface method for pair encoding
///
/// Hands the list of text pairs, together with `max_len`, `truncation_strategy`
/// and `stride`, unchanged to the `encode_pair_list` method of whichever
/// tokenizer this option wraps, and returns that tokenizer's result.
pub fn encode_pair_list(
    &self,
    text_pair_list: Vec<(&str, &str)>,
    max_len: usize,
    truncation_strategy: &TruncationStrategy,
    stride: usize,
) -> Vec<TokenizedInput> {
    // Each variant wraps a different concrete tokenizer type, so the arms
    // cannot be merged even though the forwarded call is spelled identically.
    match self {
        Self::Bert(bert) => {
            bert.encode_pair_list(text_pair_list, max_len, truncation_strategy, stride)
        }
        Self::Roberta(roberta) => {
            roberta.encode_pair_list(text_pair_list, max_len, truncation_strategy, stride)
        }
        Self::Marian(marian) => {
            marian.encode_pair_list(text_pair_list, max_len, truncation_strategy, stride)
        }
        Self::T5(t5) => t5.encode_pair_list(text_pair_list, max_len, truncation_strategy, stride),
        Self::XLMRoberta(xlm_roberta) => {
            xlm_roberta.encode_pair_list(text_pair_list, max_len, truncation_strategy, stride)
        }
        Self::Albert(albert) => {
            albert.encode_pair_list(text_pair_list, max_len, truncation_strategy, stride)
        }
    }
}
/// Interface method to tokenization
pub fn tokenize(&self, text: &str) -> Vec<String> {
match *self {

View File

@ -422,7 +422,7 @@ impl SequenceClassificationModel {
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input.extend(vec![self.tokenizer.get_pad_id(); max_len - input.len()]);
input
})
.map(|input| Tensor::of_slice(&(input)))

View File

@ -532,7 +532,7 @@ impl TokenClassificationModel {
.iter()
.map(|input| input.token_ids.clone())
.map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
input.extend(vec![self.tokenizer.get_pad_id(); max_len - input.len()]);
input
})
.map(|input| Tensor::of_slice(&(input)))

View File

@ -18,10 +18,13 @@ use crate::bart::{
};
use crate::bert::BertForSequenceClassification;
use crate::distilbert::DistilBertModelClassifier;
use crate::pipelines::common::{ConfigOption, ModelType};
use crate::resources::{RemoteResource, Resource};
use crate::pipelines::common::{ConfigOption, ModelType, TokenizerOption};
use crate::resources::{download_resource, RemoteResource, Resource};
use crate::roberta::RobertaForSequenceClassification;
use crate::RustBertError;
use rust_tokenizers::{TokenizedInput, TruncationStrategy};
use std::borrow::Borrow;
use tch::nn::VarStore;
use tch::{nn, Device, Tensor};
/// # Configuration for ZeroShotClassificationModel
@ -286,3 +289,172 @@ impl ZeroShotClassificationOption {
}
}
}
/// # ZeroShotClassificationModel for Zero Shot Classification
///
/// Groups the tokenizer, the classification model and the variable store
/// that `ZeroShotClassificationModel::new` builds from a
/// `ZeroShotClassificationConfig`.
pub struct ZeroShotClassificationModel {
/// Tokenizer wrapper used to encode (text, label sentence) pairs.
tokenizer: TokenizerOption,
/// Classification model created on the root path of `var_store`.
sequence_classifier: ZeroShotClassificationOption,
/// Variable store the model weights are loaded into; its device is the one
/// prepared input tensors are moved to.
var_store: VarStore,
}
impl ZeroShotClassificationModel {
// NOTE(review): the example below still constructs `SequenceClassificationModel`
// (carried over from the sequence classification pipeline). Replace it with this
// type once its public module path is confirmed.
/// Build a new `ZeroShotClassificationModel`
///
/// # Arguments
///
/// * `config` - `ZeroShotClassificationConfig` object containing the resource references (model, vocabulary, configuration, optional merges) and device placement (CPU/GPU)
///
/// # Errors
///
/// Returns a `RustBertError` if any of the configuration, vocabulary, weights or
/// merges resources cannot be downloaded, if the tokenizer cannot be created, or
/// if the weights cannot be loaded into the variable store.
///
/// # Panics
///
/// Panics if the vocabulary or merges path cannot be converted to a `&str`.
///
/// # Example
///
/// ```no_run
/// # fn main() -> anyhow::Result<()> {
/// use rust_bert::pipelines::sequence_classification::SequenceClassificationModel;
///
/// let model = SequenceClassificationModel::new(Default::default())?;
/// # Ok(())
/// # }
/// ```
pub fn new(
    config: ZeroShotClassificationConfig,
) -> Result<ZeroShotClassificationModel, RustBertError> {
    let config_path = download_resource(&config.config_resource)?;
    let vocab_path = download_resource(&config.vocab_resource)?;
    let weights_path = download_resource(&config.model_resource)?;
    // The merges file is optional; when present, a failed download is reported
    // to the caller like the other resources instead of aborting the process.
    let merges_path = match &config.merges_resource {
        Some(merges_resource) => Some(download_resource(merges_resource)?),
        None => None,
    };
    let device = config.device;

    let tokenizer = TokenizerOption::from_file(
        config.model_type,
        vocab_path.to_str().unwrap(),
        merges_path.map(|path| path.to_str().unwrap()),
        config.lower_case,
        config.strip_accents,
        config.add_prefix_space,
    )?;

    // The model must be created on the variable store before the weights are
    // loaded, so that the store already contains the variables to fill.
    let mut var_store = VarStore::new(device);
    let model_config = ConfigOption::from_file(config.model_type, config_path);
    let sequence_classifier =
        ZeroShotClassificationOption::new(config.model_type, &var_store.root(), &model_config);
    var_store.load(weights_path)?;

    Ok(ZeroShotClassificationModel {
        tokenizer,
        sequence_classifier,
        var_store,
    })
}
/// Encodes every (input text, label sentence) pair and stacks the padded token ids into a single tensor.
///
/// # Arguments
///
/// * `input` - texts to classify.
/// * `labels` - candidate labels; each one is turned into a label sentence.
/// * `template` - optional closure building the label sentence from a label. When `None`, the sentence is `This example is {label}.`
/// * `max_len` - maximum length forwarded to the tokenizer's `encode_pair_list` (called with `TruncationStrategy::LongestFirst` and a stride of 0).
///
/// # Returns
///
/// * `Tensor` with one row of token ids per (input, label) pair, moved to the device of the variable store.
///   Rows are ordered input-major: the pair (`input[i]`, `labels[j]`) is row `i * labels.len() + j`.
///   Rows shorter than the longest encoding are right-padded with the tokenizer's pad id.
///
/// # Panics
///
/// Panics if the tokenizer yields no encodings (for instance when `input` or `labels` is empty),
/// since there is then nothing to stack.
fn prepare_for_model<F>(
    &self,
    input: &[&str],
    labels: &[&str],
    template: Option<F>,
    max_len: usize,
) -> Tensor
where
    F: Fn(&str) -> String,
{
    let label_sentences: Vec<String> = match template {
        Some(function) => labels.iter().map(|&label| function(label)).collect(),
        None => labels
            .iter()
            .map(|&label| format!("This example is {}.", label))
            .collect(),
    };

    // Every text has to be scored against every label, so build the full
    // cross product. Zipping the two slices would pair the i-th text with the
    // i-th label only and silently drop whatever is left of the longer slice.
    let text_pair_list: Vec<(&str, &str)> = input
        .iter()
        .flat_map(|&text| {
            label_sentences
                .iter()
                .map(move |sentence| (text, sentence.as_str()))
        })
        .collect();

    let tokenized_input: Vec<TokenizedInput> = self.tokenizer.encode_pair_list(
        text_pair_list,
        max_len,
        &TruncationStrategy::LongestFirst,
        0,
    );

    let padded_len = tokenized_input
        .iter()
        .map(|encoded| encoded.token_ids.len())
        .max()
        .expect("cannot prepare an empty batch: `input` and `labels` must both be non-empty");

    // Pad with the tokenizer's own pad id, as the sequence and token
    // classification pipelines do, rather than assuming the pad id is 0.
    let pad_id = self.tokenizer.get_pad_id();
    let rows: Vec<Tensor> = tokenized_input
        .iter()
        .map(|encoded| {
            let mut token_ids = encoded.token_ids.clone();
            token_ids.resize(padded_len, pad_id);
            Tensor::of_slice(&token_ids)
        })
        .collect();

    Tensor::stack(rows.as_slice(), 0).to(self.var_store.device())
}
//
// /// Classify texts
// ///
// /// # Arguments
// ///
// /// * `input` - `&[&str]` Array of texts to classify.
// ///
// /// # Returns
// ///
// /// * `Vec<Label>` containing labels for input texts
// ///
// /// # Example
// ///
// /// ```no_run
// /// # fn main() -> anyhow::Result<()> {
// /// # use rust_bert::pipelines::sequence_classification::SequenceClassificationModel;
// ///
// /// let sequence_classification_model = SequenceClassificationModel::new(Default::default())?;
// /// let input = [
// /// "Probably my all-time favorite movie, a story of selflessness, sacrifice and dedication to a noble cause, but it's not preachy or boring.",
// /// "This film tried to be too many things all at once: stinging political satire, Hollywood blockbuster, sappy romantic comedy, family values promo...",
// /// "If you like original gut wrenching laughter you will like this movie. If you are young or old then you will love this movie, hell even my mom liked it.",
// /// ];
// /// let output = sequence_classification_model.predict(&input);
// /// # Ok(())
// /// # }
// /// ```
// pub fn predict(&self, input: &[&str]) -> Vec<Label> {
// let input_tensor = self.prepare_for_model(input.to_vec());
// let output = no_grad(|| {
// let output = self.sequence_classifier.forward_t(
// Some(input_tensor.copy()),
// None,
// None,
// None,
// None,
// false,
// );
// output.softmax(-1, Kind::Float).detach().to(Device::Cpu)
// });
// let label_indices = output.as_ref().argmax(-1, true).squeeze1(1);
// let scores = output
// .gather(1, &label_indices.unsqueeze(-1), false)
// .squeeze1(1);
// let label_indices = label_indices.iter::<i64>().unwrap().collect::<Vec<i64>>();
// let scores = scores.iter::<f64>().unwrap().collect::<Vec<f64>>();
//
// let mut labels: Vec<Label> = vec![];
// for sentence_idx in 0..label_indices.len() {
// let label_string = self
// .label_mapping
// .get(&label_indices[sentence_idx])
// .unwrap()
// .clone();
// let label = Label {
// text: label_string,
// score: scores[sentence_idx],
// id: label_indices[sentence_idx],
// sentence: sentence_idx,
// };
// labels.push(label)
// }
// labels
// }
}