rust-bert/examples/natural_language_inference_deberta.rs
Matt Weber ba57704c6f
Introduce in-memory resource abstraction (#375)
* Introduce in-memory resource abstraction

This follows from discussion in #366.

The goal of this change is to allow for weights to be loaded from a copy
of `rust_model.ot` that is already present in memory. There are two ways
in which that data might be present:

1. As a `HashMap<String, Tensor>` from previous interaction with `tch`
2. As a contiguous buffer of the file data

One or the other mechanism may be preferable depending on how user
code handles the model data. In some sense, a provider based on the
second option is mostly a convenience that spares the user from driving
the `tch::nn::VarStore::load_from_stream` interaction directly; a usage
sketch follows.
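For illustration, loading from an in-memory copy of the file data could look
roughly like the sketch below. The `BufferResource` field layout
(`Arc<RwLock<Vec<u8>>>`) follows the "Use mutable references and locks"
commit further down and is an assumption, not a verbatim copy of the final
interface:

```rust
use std::sync::{Arc, RwLock};

use rust_bert::resources::{load_weights, BufferResource};
use tch::{nn, Device};

fn main() -> anyhow::Result<()> {
    // Obtain the bytes of `rust_model.ot` however the application likes
    // (network fetch, embedded asset, shared cache, ...); reading from
    // disk here just keeps the sketch self-contained.
    let bytes: Vec<u8> = std::fs::read("path/to/rust_model.ot")?;

    // Assumed field layout: the buffer sits behind a lock so that
    // `load_weights` can take a mutable view of it.
    let model_resource = BufferResource {
        data: Arc::new(RwLock::new(bytes)),
    };

    // The model itself must be constructed from `vs.root()` before the
    // weights are loaded, exactly as in the example file below.
    let mut vs = nn::VarStore::new(Device::Cpu);
    load_weights(&model_resource, &mut vs)?;
    Ok(())
}
```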

I've changed the definition of the `ResourceProvider` trait to require
that it be both `Send` and `Sync`. There are currently certain contexts
where `dyn ResourceProvider + Send` is required, but in theory before
this change an implementation might not be `Send` (or `Sync`). The
existing providers are both `Send` and `Sync`, and it seems reasonable
(if technically incorrect) for user code to assume this to be true. I
don't see a downside to making this explicit, but that part of this
change might be better suited for separate discussion. I am not trying
to sneak it in.
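For concreteness, the bound change amounts to adding supertraits (method set
trimmed, with a hypothetical local stand-in for the error type so the sketch
compiles on its own):

```rust
use std::path::PathBuf;

// Hypothetical stand-in for `rust_bert::RustBertError`, kept local so
// the sketch compiles on its own.
pub struct RustBertError;

// Requiring `Send + Sync` as supertraits makes the (previously
// implicit) thread-safety of the built-in providers explicit: every
// implementor must be movable to and shareable between threads.
pub trait ResourceProvider: Send + Sync {
    fn get_local_path(&self) -> Result<PathBuf, RustBertError>;
}
```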

The `enum Resource` data type is used here as a means to abstract over
the possible ways a `ResourceProvider` might represent an underlying
resource. Without this, it would be necessary either to call different
trait methods until one succeeded, or to implement `as_any` and downcast
in order to implement `load_weights` much as it is implemented today.
Both options seemed less appealing than creating a wrapper; a sketch of
the wrapper and its use follows.
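A minimal sketch of such a wrapper and its consumption in `load_weights`.
The variant names and the lock-guard type are assumptions based on the
commits listed below, and the error type is a local stand-in:

```rust
use std::io::Cursor;
use std::path::PathBuf;
use std::sync::RwLockWriteGuard;

use tch::nn;

// Hypothetical local stand-in for `rust_bert::RustBertError`.
#[derive(Debug)]
pub struct RustBertError(String);

impl From<tch::TchError> for RustBertError {
    fn from(e: tch::TchError) -> Self {
        RustBertError(e.to_string())
    }
}

/// The concrete forms a provider can hand back: a path on disk (the
/// existing behavior) or a locked view of an in-memory buffer.
pub enum Resource<'a> {
    PathBuf(PathBuf),
    Buffer(RwLockWriteGuard<'a, Vec<u8>>),
}

pub trait ResourceProvider: Send + Sync {
    fn get_resource(&self) -> Result<Resource<'_>, RustBertError>;
}

/// A single `load_weights` body covers both representations by
/// matching on the wrapper, with no probing or downcasting.
pub fn load_weights(
    rp: &(impl ResourceProvider + ?Sized),
    vs: &mut nn::VarStore,
) -> Result<(), RustBertError> {
    match rp.get_resource()? {
        Resource::PathBuf(path) => Ok(vs.load(path)?),
        Resource::Buffer(mut bytes) => {
            Ok(vs.load_from_stream(Cursor::new(bytes.as_mut_slice()))?)
        }
    }
}
```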

While it would be possible to replace all calls to `get_local_path` with
the `get_resource` API, removal of the existing function would be a very
big breaking change. As such, this change also introduces
`RustBertError::UnsupportedError` to allow for the different methods to
coexist. An alternative would be for the new `ResourceProvider`s to
write their resources to a temporary disk location and return an
appropriate path, but that is counter to the purpose of the new
`ResourceProvider`s and so I chose not to implement that.
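How the two methods coexist on an in-memory provider, sketched with assumed
names (`BufferResource` and its `data` field follow the commits below; the
stand-in types mirror the `Resource` sketch above):

```rust
use std::path::PathBuf;
use std::sync::{Arc, RwLock, RwLockWriteGuard};

// Local stand-ins for the real `rust_bert` items; the variant name
// `UnsupportedError` is the one this change introduces.
pub enum RustBertError {
    UnsupportedError,
}

pub enum Resource<'a> {
    PathBuf(PathBuf),
    Buffer(RwLockWriteGuard<'a, Vec<u8>>),
}

pub trait ResourceProvider: Send + Sync {
    fn get_local_path(&self) -> Result<PathBuf, RustBertError>;
    fn get_resource(&self) -> Result<Resource<'_>, RustBertError>;
}

/// Assumed shape of the buffer-backed provider.
pub struct BufferResource {
    pub data: Arc<RwLock<Vec<u8>>>,
}

impl ResourceProvider for BufferResource {
    // There is no on-disk location to return, so the legacy path-based
    // method signals `UnsupportedError` rather than writing the buffer
    // out to a temporary file.
    fn get_local_path(&self) -> Result<PathBuf, RustBertError> {
        Err(RustBertError::UnsupportedError)
    }

    fn get_resource(&self) -> Result<Resource<'_>, RustBertError> {
        Ok(Resource::Buffer(self.data.write().unwrap()))
    }
}
```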

* - Add `impl<T: ResourceProvider + ?Sized> ResourceProvider for Box<T>` (sketched below)
  - Remove `Resource::NamedTensors`
  - Change `BufferResource` to contain a `&[u8]` rather than a `Vec<u8>`
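The forwarding impl from the first bullet might look roughly like this
(again with local stand-in types, as in the sketches above):

```rust
use std::path::PathBuf;

// Local stand-ins, as in the sketches above.
pub struct RustBertError;
pub trait ResourceProvider: Send + Sync {
    fn get_local_path(&self) -> Result<PathBuf, RustBertError>;
}

// `?Sized` lets the impl cover `Box<dyn ResourceProvider>` as well as
// boxed concrete providers; every call is forwarded to the inner value.
impl<T: ResourceProvider + ?Sized> ResourceProvider for Box<T> {
    fn get_local_path(&self) -> Result<PathBuf, RustBertError> {
        T::get_local_path(self)
    }
}
```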

* Further rework proposal for resources

* Use mutable references and locks

* Make model resources mutable in tests/examples

* Remove unnecessary mutability and TensorResource references

* Add `BufferResource` example

---------

Co-authored-by: Guillaume Becquin <guillaume.becquin@gmail.com>
2023-05-26 18:23:28 +01:00


extern crate anyhow;

use rust_bert::deberta::{
    DebertaConfig, DebertaConfigResources, DebertaForSequenceClassification,
    DebertaMergesResources, DebertaModelResources, DebertaVocabResources,
};
use rust_bert::resources::{load_weights, RemoteResource, ResourceProvider};
use rust_bert::Config;
use rust_tokenizers::tokenizer::{DeBERTaTokenizer, MultiThreadedTokenizer, TruncationStrategy};
use tch::{nn, no_grad, Device, Kind, Tensor};

fn main() -> anyhow::Result<()> {
    // Resource paths: fetch the MNLI-finetuned DeBERTa base model files
    let config_resource = Box::new(RemoteResource::from_pretrained(
        DebertaConfigResources::DEBERTA_BASE_MNLI,
    ));
    let vocab_resource = Box::new(RemoteResource::from_pretrained(
        DebertaVocabResources::DEBERTA_BASE_MNLI,
    ));
    let merges_resource = Box::new(RemoteResource::from_pretrained(
        DebertaMergesResources::DEBERTA_BASE_MNLI,
    ));
    let model_resource = Box::new(RemoteResource::from_pretrained(
        DebertaModelResources::DEBERTA_BASE_MNLI,
    ));
    let config_path = config_resource.get_local_path()?;
    let vocab_path = vocab_resource.get_local_path()?;
    let merges_path = merges_resource.get_local_path()?;

    // Set up model: tokenizer, configuration, then weights into the var store
    let device = Device::Cpu;
    let mut vs = nn::VarStore::new(device);
    let tokenizer = DeBERTaTokenizer::from_file(
        vocab_path.to_str().unwrap(),
        merges_path.to_str().unwrap(),
        false,
    )?;
    let config = DebertaConfig::from_file(config_path);
    let model = DebertaForSequenceClassification::new(vs.root(), &config)?;
    load_weights(&model_resource, &mut vs)?;

    // Define input: a premise/hypothesis pair for natural language inference
    let input = [("I love you.", "I like you.")];
    let tokenized_input = MultiThreadedTokenizer::encode_pair_list(
        &tokenizer,
        &input,
        128,
        &TruncationStrategy::LongestFirst,
        0,
    );

    // Pad every sequence to the longest in the batch, then stack into a tensor
    let max_len = tokenized_input
        .iter()
        .map(|input| input.token_ids.len())
        .max()
        .unwrap();
    let tokenized_input = tokenized_input
        .iter()
        .map(|input| input.token_ids.clone())
        .map(|mut input| {
            input.extend(vec![0; max_len - input.len()]);
            input
        })
        .map(|input| Tensor::from_slice(&input))
        .collect::<Vec<_>>();
    let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);

    // Forward pass: compute logits and print the class probabilities
    let model_output =
        no_grad(|| model.forward_t(Some(&input_tensor), None, None, None, None, false))?;
    model_output.logits.softmax(-1, Kind::Float).print();

    Ok(())
}