mirror of
https://github.com/guillaume-be/rust-bert.git
synced 2024-08-16 16:10:25 +03:00
Addition of DeBERTa MNLI example
This commit is contained in:
parent
71f216598f
commit
e71712816e
77
examples/natural_language_inference_deberta.rs
Normal file
77
examples/natural_language_inference_deberta.rs
Normal file
@ -0,0 +1,77 @@
|
||||
extern crate anyhow;
|
||||
|
||||
use rust_bert::deberta::{
|
||||
DebertaConfig, DebertaConfigResources, DebertaForSequenceClassification,
|
||||
DebertaMergesResources, DebertaModelResources, DebertaVocabResources,
|
||||
};
|
||||
use rust_bert::resources::{RemoteResource, Resource};
|
||||
use rust_bert::Config;
|
||||
use rust_tokenizers::tokenizer::{DeBERTaTokenizer, MultiThreadedTokenizer, TruncationStrategy};
|
||||
use tch::{nn, no_grad, Device, Kind, Tensor};
|
||||
|
||||
fn main() -> anyhow::Result<()> {
|
||||
// Resources paths
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
DebertaConfigResources::DEBERTA_BASE_MNLI,
|
||||
));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
DebertaVocabResources::DEBERTA_BASE_MNLI,
|
||||
));
|
||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
DebertaMergesResources::DEBERTA_BASE_MNLI,
|
||||
));
|
||||
let model_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
DebertaModelResources::DEBERTA_BASE_MNLI,
|
||||
));
|
||||
|
||||
let config_path = config_resource.get_local_path()?;
|
||||
let vocab_path = vocab_resource.get_local_path()?;
|
||||
let merges_path = merges_resource.get_local_path()?;
|
||||
let weights_path = model_resource.get_local_path()?;
|
||||
|
||||
// Set-up model
|
||||
let device = Device::Cpu;
|
||||
let mut vs = nn::VarStore::new(device);
|
||||
let tokenizer = DeBERTaTokenizer::from_file(
|
||||
vocab_path.to_str().unwrap(),
|
||||
merges_path.to_str().unwrap(),
|
||||
false,
|
||||
)?;
|
||||
let config = DebertaConfig::from_file(config_path);
|
||||
let model = DebertaForSequenceClassification::new(&vs.root(), &config);
|
||||
vs.load(weights_path)?;
|
||||
|
||||
// Define input
|
||||
let input = [("I love you.", "I like you.")];
|
||||
|
||||
let tokenized_input = MultiThreadedTokenizer::encode_pair_list(
|
||||
&tokenizer,
|
||||
&input,
|
||||
128,
|
||||
&TruncationStrategy::LongestFirst,
|
||||
0,
|
||||
);
|
||||
let max_len = tokenized_input
|
||||
.iter()
|
||||
.map(|input| input.token_ids.len())
|
||||
.max()
|
||||
.unwrap();
|
||||
let tokenized_input = tokenized_input
|
||||
.iter()
|
||||
.map(|input| input.token_ids.clone())
|
||||
.map(|mut input| {
|
||||
input.extend(vec![0; max_len - input.len()]);
|
||||
input
|
||||
})
|
||||
.map(|input| Tensor::of_slice(&(input)))
|
||||
.collect::<Vec<_>>();
|
||||
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
|
||||
|
||||
// Forward pass
|
||||
let model_output =
|
||||
no_grad(|| model.forward_t(Some(&input_tensor), None, None, None, None, false))?;
|
||||
|
||||
model_output.logits.softmax(-1, Kind::Float).print();
|
||||
|
||||
Ok(())
|
||||
}
|
Loading…
Reference in New Issue
Block a user