rust-bert/examples/buffer_resource.rs
Matt Weber ba57704c6f
Introduce in-memory resource abstraction (#375)
* Introduce in-memory resource abstraction

This follows from discussion in #366.

The goal of this change is to allow for weights to be loaded from a copy
of `rust_model.ot` that is already present in memory. There are two ways
in which that data might be present:

1. As a `HashMap<String, Tensor>` from previous interaction with `tch`
2. As a contiguous buffer of the file data

One or the other mechanism might be preferable depending on how user
code is using the model data. In some sense, implementing a provider
based on the second option is more of a convenience method for the user
to avoid the `tch::nn::VarStore::load_from_stream` interaction.
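
As a rough illustration of why the buffer-based option is mostly a convenience: a loader that only requires `Read + Seek` cannot tell a file from an in-memory buffer wrapped in `std::io::Cursor`. The sketch below uses only the standard library. `stream_len_and_prefix` and `load_from_buffer` are hypothetical stand-ins, not part of `tch` or `rust-bert`.

```rust
use std::io::{Cursor, Read, Seek, SeekFrom};

/// Hypothetical stand-in for a stream-based weight loader. It only needs
/// `Read + Seek`, so it works the same on a file or an in-memory buffer.
fn stream_len_and_prefix<S: Read + Seek>(mut stream: S) -> std::io::Result<(u64, Vec<u8>)> {
    // Seek to the end to learn the total length, then rewind.
    let len = stream.seek(SeekFrom::End(0))?;
    stream.seek(SeekFrom::Start(0))?;
    // Read up to the first four bytes (for example, a format signature).
    let mut prefix = vec![0u8; len.min(4) as usize];
    stream.read_exact(&mut prefix)?;
    Ok((len, prefix))
}

/// Wraps a byte slice in a `Cursor`, which provides `Read + Seek` without
/// copying the underlying data.
fn load_from_buffer(data: &[u8]) -> std::io::Result<(u64, Vec<u8>)> {
    stream_len_and_prefix(Cursor::new(data))
}
```

A provider built on this idea would hide the `Cursor` wrapping from user code, which is the convenience the paragraph above describes.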

I've changed the definition of the `ResourceProvider` trait to require
that it be both `Send` and `Sync`. There are currently certain contexts
where `dyn ResourceProvider + Send` is required, but in theory before
this change an implementation might not be `Send` (or `Sync`). The
existing providers are both `Send` and `Sync`, and it seems reasonable
(if technically incorrect) for user code to assume this to be true. I
don't see a downside to making this explicit, but that part of this
change might be better suited for separate discussion. I am not trying
to sneak it in.
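
A minimal sketch of what the `Send + Sync` requirement buys, using a simplified, hypothetical `Provider` trait rather than the real `ResourceProvider`: with the supertraits in place, a trait object can be shared across threads without repeating the bounds at every use site.

```rust
use std::path::PathBuf;
use std::sync::Arc;
use std::thread;

/// Hypothetical, simplified provider trait with `Send + Sync` supertraits.
trait Provider: Send + Sync {
    fn local_path(&self) -> PathBuf;
}

struct LocalProvider {
    path: PathBuf,
}

impl Provider for LocalProvider {
    fn local_path(&self) -> PathBuf {
        self.path.clone()
    }
}

/// Shares one provider between worker threads. This compiles only because
/// the supertraits make `dyn Provider` itself `Send + Sync`.
fn paths_from_threads(provider: Arc<dyn Provider>, workers: usize) -> Vec<PathBuf> {
    let handles: Vec<_> = (0..workers)
        .map(|_| {
            let provider = Arc::clone(&provider);
            thread::spawn(move || provider.local_path())
        })
        .collect();
    handles
        .into_iter()
        .map(|handle| handle.join().expect("worker thread panicked"))
        .collect()
}
```

Without the supertraits, the same function would need `Arc<dyn Provider + Send + Sync>`, and any implementation that was not thread-safe would fail at the call site rather than at its `impl`.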

The `enum Resource` data type is used here as a means to abstract over
the possible ways a `ResourceProvider` might represent an underlying
resource. Without this, it would be necessary to either call different
trait methods until one succeeded or implement `as_any` and downcast in
order to implement `load_weights` similarly to how it is now. Those
options seemed less preferable to creating a wrapper.
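
The wrapper idea can be sketched as follows. This is a simplified illustration under assumed names: the variants, the `Provider` trait, and `describe` are hypothetical, and the real `Resource` type may carry different payloads.

```rust
use std::path::PathBuf;

/// Simplified stand-in for the wrapper enum: one variant per way a provider
/// can represent its underlying resource.
enum Resource<'a> {
    Path(PathBuf),
    Buffer(&'a [u8]),
}

/// Hypothetical provider trait exposing a single accessor.
trait Provider {
    fn get_resource(&self) -> Resource<'_>;
}

struct FileProvider {
    path: PathBuf,
}

struct BufferProvider {
    data: Vec<u8>,
}

impl Provider for FileProvider {
    fn get_resource(&self) -> Resource<'_> {
        Resource::Path(self.path.clone())
    }
}

impl Provider for BufferProvider {
    fn get_resource(&self) -> Resource<'_> {
        Resource::Buffer(&self.data)
    }
}

/// Hypothetical consumer: a single `match` replaces calling several trait
/// methods until one succeeds, or downcasting through `as_any`.
fn describe(provider: &dyn Provider) -> String {
    match provider.get_resource() {
        Resource::Path(path) => format!("file:{}", path.display()),
        Resource::Buffer(bytes) => format!("buffer:{} bytes", bytes.len()),
    }
}
```

Because the `match` is exhaustive, adding a new variant later forces every consumer to handle it at compile time.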

While it would be possible to replace all calls to `get_local_path` with
the `get_resource` API, removing the existing function would be a major
breaking change. As such, this change also introduces
`RustBertError::UnsupportedError` to allow the two methods to coexist.
An alternative would be for the new `ResourceProvider`s to write their
resources to a temporary disk location and return an appropriate path,
but that runs counter to the purpose of the new `ResourceProvider`s, so
I chose not to implement it.
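
A sketch of how an "unsupported" error lets the path-based accessor coexist with in-memory providers. `ProviderError`, `Provider`, and both provider structs are hypothetical. Only the idea of returning an error instead of writing a temporary file comes from the message above.

```rust
use std::fmt;
use std::path::PathBuf;

/// Hypothetical error type. `Unsupported` plays the role the message gives
/// to `RustBertError::UnsupportedError`.
#[derive(Debug, PartialEq)]
enum ProviderError {
    Unsupported(&'static str),
}

impl fmt::Display for ProviderError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            ProviderError::Unsupported(what) => write!(f, "unsupported operation: {what}"),
        }
    }
}

impl std::error::Error for ProviderError {}

trait Provider {
    /// Older, path-based accessor kept for backwards compatibility.
    fn get_local_path(&self) -> Result<PathBuf, ProviderError>;
}

struct FileProvider {
    path: PathBuf,
}

struct BufferProvider {
    data: Vec<u8>,
}

impl Provider for FileProvider {
    fn get_local_path(&self) -> Result<PathBuf, ProviderError> {
        Ok(self.path.clone())
    }
}

impl Provider for BufferProvider {
    fn get_local_path(&self) -> Result<PathBuf, ProviderError> {
        // No file exists, and writing a temporary one would defeat the
        // purpose of keeping the data in memory.
        Err(ProviderError::Unsupported("get_local_path on an in-memory buffer"))
    }
}
```

Callers that still rely on paths keep working with file-backed providers and get an explicit error, rather than a silent disk write, when handed a buffer.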

* - Add `impl<T: ResourceProvider + ?Sized> ResourceProvider for Box<T>`
- Remove `Resource::NamedTensors`
- Change `BufferResource` to contain a `&[u8]` rather than `Vec<u8>`
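
The blanket `Box<T>` implementation in the first item can be illustrated with a simplified, hypothetical `Provider` trait. The `?Sized` bound is what lets the impl cover `Box<dyn Provider>` as well as boxes of concrete types.

```rust
/// Hypothetical, simplified provider trait.
trait Provider {
    fn name(&self) -> String;
}

/// Forwarding impl: any boxed provider, including `Box<dyn Provider>`, is
/// itself a provider. `?Sized` admits unsized `T` such as `dyn Provider`.
impl<T: Provider + ?Sized> Provider for Box<T> {
    fn name(&self) -> String {
        // Dereference `&Box<T>` twice to reach the inner `T`.
        (**self).name()
    }
}

struct Remote;

impl Provider for Remote {
    fn name(&self) -> String {
        "remote".to_string()
    }
}

/// Generic consumer that accepts any provider by value.
fn name_of<P: Provider>(provider: P) -> String {
    provider.name()
}
```

With this impl, generic code bounded by `P: Provider` accepts a concrete provider, a `Box<Remote>`, or a `Box<dyn Provider>` without separate overloads.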

* Further rework proposal for resources

* Use mutable references and locks

* Make model resources mutable in tests/examples

* Remove unnecessary mutability and TensorResource references

* Add `BufferResource` example
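
The example file for this commit passes the weights around as `Arc<RwLock<Vec<u8>>>`. The sketch below shows that sharing pattern using only the standard library. `SharedBuffer`, `byte_sum`, and `share_between_two` are hypothetical, and taking a read lock here is an assumption for illustration, not a statement about which lock the library takes.

```rust
use std::sync::{Arc, RwLock};

/// Hypothetical provider that shares, rather than owns, the weight bytes.
struct SharedBuffer {
    data: Arc<RwLock<Vec<u8>>>,
}

impl SharedBuffer {
    /// Holds the lock only for as long as it takes to inspect the bytes.
    fn byte_sum(&self) -> u64 {
        let guard = self.data.read().expect("lock poisoned");
        guard.iter().map(|&b| u64::from(b)).sum()
    }
}

/// Builds two providers over one allocation and returns both sums plus the
/// number of `Arc` handles alive at that point.
fn share_between_two(bytes: Vec<u8>) -> (u64, u64, usize) {
    let shared = Arc::new(RwLock::new(bytes));
    let first = SharedBuffer { data: Arc::clone(&shared) };
    let second = SharedBuffer { data: Arc::clone(&shared) };
    (first.byte_sum(), second.byte_sum(), Arc::strong_count(&shared))
}
```

Cloning the `Arc` copies a pointer, not the buffer, so several models can be built from one in-memory copy of the weights.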

---------

Co-authored-by: Guillaume Becquin <guillaume.becquin@gmail.com>
2023-05-26 18:23:28 +01:00

98 lines
5.1 KiB
Rust

// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
// Copyright 2019 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
extern crate anyhow;
use std::sync::{Arc, RwLock};
use rust_bert::bart::{
BartConfigResources, BartMergesResources, BartModelResources, BartVocabResources,
};
use rust_bert::pipelines::summarization::{SummarizationConfig, SummarizationModel};
use rust_bert::resources::{BufferResource, RemoteResource, ResourceProvider};
use tch::Device;
fn main() -> anyhow::Result<()> {
    let input = ["In findings published Tuesday in Cornell University's arXiv by a team of scientists \
from the University of Montreal and a separate report published Wednesday in Nature Astronomy by a team \
from University College London (UCL), the presence of water vapour was confirmed in the atmosphere of K2-18b, \
a planet circling a star in the constellation Leo. This is the first such discovery in a planet in its star's \
habitable zone — not too hot and not too cold for liquid water to exist. The Montreal team, led by Björn Benneke, \
used data from the NASA's Hubble telescope to assess changes in the light coming from K2-18b's star as the planet \
passed between it and Earth. They found that certain wavelengths of light, which are usually absorbed by water, \
weakened when the planet was in the way, indicating not only does K2-18b have an atmosphere, but the atmosphere \
contains water in vapour form. The team from UCL then analyzed the Montreal team's data using their own software \
and confirmed their conclusion. This was not the first time scientists have found signs of water on an exoplanet, \
but previous discoveries were made on planets with high temperatures or other pronounced differences from Earth. \
\"This is the first potentially habitable planet where the temperature is right and where we now know there is water,\" \
said UCL astronomer Angelos Tsiaras. \"It's the best candidate for habitability right now.\" \"It's a good sign\", \
said Ryan Cloutier of the Harvard-Smithsonian Center for Astrophysics, who was not one of either study's authors. \
\"Overall,\" he continued, \"the presence of water in its atmosphere certainly improves the prospect of K2-18b being \
a potentially habitable planet, but further observations will be required to say for sure. \" \
K2-18b was first identified in 2015 by the Kepler space telescope. It is about 110 light-years from Earth and larger \
but less dense. Its star, a red dwarf, is cooler than the Sun, but the planet's orbit is much closer, such that a year \
on K2-18b lasts 33 Earth days. According to The Guardian, astronomers were optimistic that NASA's James Webb space \
telescope — scheduled for launch in 2021 — and the European Space Agency's 2028 ARIEL program, could reveal more \
about exoplanets like K2-18b."];
    // Read the model weights into memory once and share them behind a lock.
    let weights = Arc::new(RwLock::new(get_weights()?));

    // First model: runs on the CPU, loading its weights from the shared buffer.
    let summarization_model = SummarizationModel::new(config(Device::Cpu, weights.clone()))?;

    // Credits: WikiNews, CC BY 2.5 license (https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b)
    let output = summarization_model.summarize(&input);
    for sentence in output {
        println!("{sentence}");
    }

    // Second model: reuses the same in-memory weights on a GPU if one is available.
    let summarization_model =
        SummarizationModel::new(config(Device::cuda_if_available(), weights))?;

    let output = summarization_model.summarize(&input);
    for sentence in output {
        println!("{sentence}");
    }

    Ok(())
}

/// Fetches the DistilBART weights through `RemoteResource` and reads the file into memory.
fn get_weights() -> anyhow::Result<Vec<u8>> {
    let model_resource = RemoteResource::from_pretrained(BartModelResources::DISTILBART_CNN_6_6);
    Ok(std::fs::read(model_resource.get_local_path()?)?)
}

/// Builds a summarization config whose weights come from the shared in-memory buffer.
fn config(device: Device, model_data: Arc<RwLock<Vec<u8>>>) -> SummarizationConfig {
    let config_resource = Box::new(RemoteResource::from_pretrained(
        BartConfigResources::DISTILBART_CNN_6_6,
    ));
    let vocab_resource = Box::new(RemoteResource::from_pretrained(
        BartVocabResources::DISTILBART_CNN_6_6,
    ));
    let merges_resource = Box::new(RemoteResource::from_pretrained(
        BartMergesResources::DISTILBART_CNN_6_6,
    ));
    // The weights are served from memory instead of a file on disk.
    let model_resource = Box::new(BufferResource { data: model_data });

    SummarizationConfig {
        model_resource,
        config_resource,
        vocab_resource,
        merges_resource: Some(merges_resource),
        num_beams: 1,
        length_penalty: 1.0,
        min_length: 56,
        max_length: Some(142),
        device,
        ..Default::default()
    }
}