From e71712816e9d96e32d9bd8ae4c53e29fbbff27d2 Mon Sep 17 00:00:00 2001
From: guillaume-be
Date: Sun, 12 Dec 2021 20:14:36 +0100
Subject: [PATCH] Addition of DeBERTa MNLI example

---
 .../natural_language_inference_deberta.rs | 77 +++++++++++++++++++
 1 file changed, 77 insertions(+)
 create mode 100644 examples/natural_language_inference_deberta.rs

diff --git a/examples/natural_language_inference_deberta.rs b/examples/natural_language_inference_deberta.rs
new file mode 100644
index 0000000..6ada34c
--- /dev/null
+++ b/examples/natural_language_inference_deberta.rs
@@ -0,0 +1,77 @@
+extern crate anyhow;
+
+use rust_bert::deberta::{
+    DebertaConfig, DebertaConfigResources, DebertaForSequenceClassification,
+    DebertaMergesResources, DebertaModelResources, DebertaVocabResources,
+};
+use rust_bert::resources::{RemoteResource, Resource};
+use rust_bert::Config;
+use rust_tokenizers::tokenizer::{DeBERTaTokenizer, MultiThreadedTokenizer, TruncationStrategy};
+use tch::{nn, no_grad, Device, Kind, Tensor};
+
+fn main() -> anyhow::Result<()> {
+    // Resources paths
+    let config_resource = Resource::Remote(RemoteResource::from_pretrained(
+        DebertaConfigResources::DEBERTA_BASE_MNLI,
+    ));
+    let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
+        DebertaVocabResources::DEBERTA_BASE_MNLI,
+    ));
+    let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
+        DebertaMergesResources::DEBERTA_BASE_MNLI,
+    ));
+    let model_resource = Resource::Remote(RemoteResource::from_pretrained(
+        DebertaModelResources::DEBERTA_BASE_MNLI,
+    ));
+
+    let config_path = config_resource.get_local_path()?;
+    let vocab_path = vocab_resource.get_local_path()?;
+    let merges_path = merges_resource.get_local_path()?;
+    let weights_path = model_resource.get_local_path()?;
+
+    // Set-up model
+    let device = Device::Cpu;
+    let mut vs = nn::VarStore::new(device);
+    let tokenizer = DeBERTaTokenizer::from_file(
+        vocab_path.to_str().unwrap(),
+        merges_path.to_str().unwrap(),
+        false,
+    )?;
+    let config = DebertaConfig::from_file(config_path);
+    let model = DebertaForSequenceClassification::new(&vs.root(), &config);
+    vs.load(weights_path)?;
+
+    // Define input
+    let input = [("I love you.", "I like you.")];
+
+    let tokenized_input = MultiThreadedTokenizer::encode_pair_list(
+        &tokenizer,
+        &input,
+        128,
+        &TruncationStrategy::LongestFirst,
+        0,
+    );
+    let max_len = tokenized_input
+        .iter()
+        .map(|input| input.token_ids.len())
+        .max()
+        .unwrap();
+    let tokenized_input = tokenized_input
+        .iter()
+        .map(|input| input.token_ids.clone())
+        .map(|mut input| {
+            input.extend(vec![0; max_len - input.len()]);
+            input
+        })
+        .map(|input| Tensor::of_slice(&(input)))
+        .collect::<Vec<_>>();
+    let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
+
+    // Forward pass
+    let model_output =
+        no_grad(|| model.forward_t(Some(&input_tensor), None, None, None, None, false))?;
+
+    model_output.logits.softmax(-1, Kind::Float).print();
+
+    Ok(())
+}
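
As a usage sketch (not part of the diff above): the example only prints the raw softmax output without naming the classes. Assuming the usual three-way MNLI head and the label order contradiction / neutral / entailment (an assumption; the authoritative mapping is the id2label field of the downloaded DeBERTa config.json), the probabilities could be turned into a readable label roughly like this:

// Sketch only, not part of the patch above. The label order below is an
// assumption; the real mapping is given by the `id2label` entry of the
// downloaded DeBERTa config.
use tch::{Kind, Tensor};

fn top_mnli_label(logits: &Tensor) -> (&'static str, f64) {
    let labels = ["contradiction", "neutral", "entailment"]; // assumed order
    let probabilities = logits.softmax(-1, Kind::Float);
    // Index of the highest-probability class (0-dim tensor -> scalar).
    let best = probabilities.argmax(-1, false).int64_value(&[]);
    let probability = probabilities.double_value(&[best]);
    (labels[best as usize], probability)
}

// Possible use on the first (and only) row of the batch in the example:
// let (label, p) = top_mnli_label(&model_output.logits.get(0));
// println!("{}: {:.3}", label, p);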